Skip to content
Snippets Groups Projects
Commit 188b6ac5 authored by Tobi-Alonso's avatar Tobi-Alonso
Browse files

[TRANSFORM] Generalize from past scalar to past any linear op (this was...

[TRANSFORM] Generalize from past scalar to past any linear op (this was supported, but restricted). Fix node sorting problem, by iterating over a copy list of the node instead of iteration over graph.node. Fix corner case when both branches are empty and come from the same node
parent 4062647c
No related branches found
No related tags found
No related merge requests found
...@@ -34,8 +34,6 @@ from finn.transformation.infer_shapes import InferShapes ...@@ -34,8 +34,6 @@ from finn.transformation.infer_shapes import InferShapes
from finn.core.onnx_exec import execute_node from finn.core.onnx_exec import execute_node
from finn.util.basic import get_by_name from finn.util.basic import get_by_name
def is_scalar(x):
return np.prod(x.shape) == 1
class MoveAddPastMul(Transformation): class MoveAddPastMul(Transformation):
"""Move add operations past multiply operations. The aim is to have them """Move add operations past multiply operations. The aim is to have them
...@@ -273,12 +271,12 @@ class MoveScalarMulPastConv(Transformation): ...@@ -273,12 +271,12 @@ class MoveScalarMulPastConv(Transformation):
return (model, graph_modified) return (model, graph_modified)
class MoveScalarLinearPastEltwiseAdd(Transformation): class MoveLinearPastEltwiseAdd(Transformation):
"""Move scalar linear operations (mul, add) past elementwise add operations where possible. Specifically, """Move linear operations (mul, add) past elementwise add operations where possible.
matches and transforms the following patterns: Specifically,matches and transforms the following patterns:
(x*C) + (y*C) -> (x + y) * C (x*C) + (y*C) -> (x + y) * C
(x+A) + (y+B) -> (x + y) + (A + B) (x+A) + (y+B) -> (x + y) + (A + B)
where x and y are dynamic inputs, A, B, C are constants. where x and y are dynamic inputs, A, B, C are constant tensors (in general).
""" """
def move_node(self, graph, n, prod0, prod1, node_ind): def move_node(self, graph, n, prod0, prod1, node_ind):
...@@ -305,7 +303,8 @@ class MoveScalarLinearPastEltwiseAdd(Transformation): ...@@ -305,7 +303,8 @@ class MoveScalarLinearPastEltwiseAdd(Transformation):
graph = model.graph graph = model.graph
node_ind = 0 node_ind = 0
graph_modified = False graph_modified = False
for n in graph.node: nodes = [n for n in graph.node]
for n in nodes:
node_ind += 1 node_ind += 1
if n.op_type == "Add": if n.op_type == "Add":
# check for tensors on both inputs (eltwise add) # check for tensors on both inputs (eltwise add)
...@@ -321,17 +320,16 @@ class MoveScalarLinearPastEltwiseAdd(Transformation): ...@@ -321,17 +320,16 @@ class MoveScalarLinearPastEltwiseAdd(Transformation):
# check for mul with same initializer on both inputs # check for mul with same initializer on both inputs
prod0 = model.find_producer(in0) prod0 = model.find_producer(in0)
prod1 = model.find_producer(in1) prod1 = model.find_producer(in1)
if prod0 is None or prod1 is None: # Also check case when both branches are empty and come
# from the same node: (prod0 == prod1)
# Other transform should handle that
if prod0 is None or prod1 is None or (prod0 == prod1):
continue continue
init0 = model.get_initializer(prod0.input[1]) init0 = model.get_initializer(prod0.input[1])
init1 = model.get_initializer(prod1.input[1]) init1 = model.get_initializer(prod1.input[1])
# if either initializer is None, skip # if either initializer is None, skip
if init0 is None or init1 is None: if init0 is None or init1 is None:
continue continue
# if either initializer is non-scalar, skip
# TODO relax this to 1D tensors?
if (not is_scalar(init0)) or (not is_scalar(init1)):
continue
if prod0.op_type == "Mul" and prod1.op_type == "Mul": if prod0.op_type == "Mul" and prod1.op_type == "Mul":
if np.array_equal(init0, init1): if np.array_equal(init0, init1):
self.move_node(graph, n, prod0, prod1, node_ind) self.move_node(graph, n, prod0, prod1, node_ind)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment