Skip to content
Snippets Groups Projects
Unverified Commit 316ca19d authored by auphelia's avatar auphelia Committed by GitHub
Browse files

Merge pull request #121 from quetric/feature/absorb_transform_only_if_linear

Feature/absorb transform only if linear
parents 762c12a1 280c9baa
No related branches found
No related tags found
No related merge requests found
......@@ -46,7 +46,11 @@ class AbsorbAddIntoMultiThreshold(Transformation):
graph_modified = False
for n in graph.node:
node_ind += 1
if n.op_type == "Add":
if (
n.op_type == "Add"
and not model.is_fork_node(n)
and not model.is_join_node(n)
):
consumer = model.find_consumer(n.output[0])
if consumer is not None and consumer.op_type == "MultiThreshold":
add_weight_name = n.input[1]
......@@ -83,7 +87,11 @@ class AbsorbMulIntoMultiThreshold(Transformation):
graph_modified = False
for n in graph.node:
node_ind += 1
if n.op_type == "Mul":
if (
n.op_type == "Mul"
and not model.is_fork_node(n)
and not model.is_join_node(n)
):
mul_weight_name = n.input[1]
A = model.get_initializer(mul_weight_name)
assert A is not None, "Initializer for mul weights is not set."
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment