From ddd52d2dd66e3f01506cfc050a10df94c441bcde Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Tue, 19 May 2020 00:05:26 +0100
Subject: [PATCH] [Transform] scalar chk in MoveScalarLinearPastEltwiseAdd

---
 src/finn/transformation/streamline/reorder.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py
index 254f9ab0e..96046602e 100644
--- a/src/finn/transformation/streamline/reorder.py
+++ b/src/finn/transformation/streamline/reorder.py
@@ -34,6 +34,8 @@ from finn.transformation.infer_shapes import InferShapes
 from finn.core.onnx_exec import execute_node
 from finn.util.basic import get_by_name
 
+def is_scalar(x):
+    return np.prod(x.shape) == 1
 
 class MoveAddPastMul(Transformation):
     """Move add operations past multiply operations. The aim is to have them
@@ -326,6 +328,10 @@ class MoveScalarLinearPastEltwiseAdd(Transformation):
                 # if either initializer is None, skip
                 if init0 is None or init1 is None:
                     continue
+                # if either initializer is non-scalar, skip
+                # TODO relax this to 1D tensors?
+                if (not is_scalar(init0)) or (not is_scalar(init1)):
+                    continue
                 if prod0.op_type == "Mul" and prod1.op_type == "Mul":
                     if np.array_equal(init0, init1):
                         self.move_node(graph, n, prod0, prod1, node_ind)
-- 
GitLab