From ddd52d2dd66e3f01506cfc050a10df94c441bcde Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <maltanar@gmail.com> Date: Tue, 19 May 2020 00:05:26 +0100 Subject: [PATCH] [Transform] scalar chk in MoveScalarLinearPastEltwiseAdd --- src/finn/transformation/streamline/reorder.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py index 254f9ab0e..96046602e 100644 --- a/src/finn/transformation/streamline/reorder.py +++ b/src/finn/transformation/streamline/reorder.py @@ -34,6 +34,8 @@ from finn.transformation.infer_shapes import InferShapes from finn.core.onnx_exec import execute_node from finn.util.basic import get_by_name +def is_scalar(x): + return np.prod(x.shape) == 1 class MoveAddPastMul(Transformation): """Move add operations past multiply operations. The aim is to have them @@ -326,6 +328,10 @@ class MoveScalarLinearPastEltwiseAdd(Transformation): # if either initializer is None, skip if init0 is None or init1 is None: continue + # if either initializer is non-scalar, skip + # TODO relax this to 1D tensors? + if (not is_scalar(init0)) or (not is_scalar(init1)): + continue if prod0.op_type == "Mul" and prod1.op_type == "Mul": if np.array_equal(init0, init1): self.move_node(graph, n, prod0, prod1, node_ind) -- GitLab