From 188b6ac5709ef14d13d64b6eb498c0f42d73c6b2 Mon Sep 17 00:00:00 2001 From: Tobi-Alonso <tobi.alonso@gmail.com> Date: Tue, 19 May 2020 16:10:19 +0100 Subject: [PATCH] [TRANSFORM] Generalize from past scalar to past any linear op (this was supported, but restricted). Fix node sorting problem, by iterating over a copy list of the node instead of iteration over graph.node. Fix corner case when both branches are empty and come from the same node --- src/finn/transformation/streamline/reorder.py | 22 +++++++++---------- 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py index 96046602e..1886c7857 100644 --- a/src/finn/transformation/streamline/reorder.py +++ b/src/finn/transformation/streamline/reorder.py @@ -34,8 +34,6 @@ from finn.transformation.infer_shapes import InferShapes from finn.core.onnx_exec import execute_node from finn.util.basic import get_by_name -def is_scalar(x): - return np.prod(x.shape) == 1 class MoveAddPastMul(Transformation): """Move add operations past multiply operations. The aim is to have them @@ -273,12 +271,12 @@ class MoveScalarMulPastConv(Transformation): return (model, graph_modified) -class MoveScalarLinearPastEltwiseAdd(Transformation): - """Move scalar linear operations (mul, add) past elementwise add operations where possible. Specifically, - matches and transforms the following patterns: +class MoveLinearPastEltwiseAdd(Transformation): + """Move linear operations (mul, add) past elementwise add operations where possible. + Specifically,matches and transforms the following patterns: (x*C) + (y*C) -> (x + y) * C (x+A) + (y+B) -> (x + y) + (A + B) - where x and y are dynamic inputs, A, B, C are constants. + where x and y are dynamic inputs, A, B, C are constant tensors (in general). """ def move_node(self, graph, n, prod0, prod1, node_ind): @@ -305,7 +303,8 @@ class MoveScalarLinearPastEltwiseAdd(Transformation): graph = model.graph node_ind = 0 graph_modified = False - for n in graph.node: + nodes = [n for n in graph.node] + for n in nodes: node_ind += 1 if n.op_type == "Add": # check for tensors on both inputs (eltwise add) @@ -321,17 +320,16 @@ class MoveScalarLinearPastEltwiseAdd(Transformation): # check for mul with same initializer on both inputs prod0 = model.find_producer(in0) prod1 = model.find_producer(in1) - if prod0 is None or prod1 is None: + # Also check case when both branches are empty and come + # from the same node: (prod0 == prod1) + # Other transform should handle that + if prod0 is None or prod1 is None or (prod0 == prod1): continue init0 = model.get_initializer(prod0.input[1]) init1 = model.get_initializer(prod1.input[1]) # if either initializer is None, skip if init0 is None or init1 is None: continue - # if either initializer is non-scalar, skip - # TODO relax this to 1D tensors? - if (not is_scalar(init0)) or (not is_scalar(init1)): - continue if prod0.op_type == "Mul" and prod1.op_type == "Mul": if np.array_equal(init0, init1): self.move_node(graph, n, prod0, prod1, node_ind) -- GitLab