diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py
index 96046602efb32a9262a4cf0bbb21a8367d719910..1886c785705161c3a13493de44dc3f3f86463f4f 100644
--- a/src/finn/transformation/streamline/reorder.py
+++ b/src/finn/transformation/streamline/reorder.py
@@ -34,8 +34,6 @@ from finn.transformation.infer_shapes import InferShapes
 from finn.core.onnx_exec import execute_node
 from finn.util.basic import get_by_name
 
-def is_scalar(x):
-    return np.prod(x.shape) == 1
 
 class MoveAddPastMul(Transformation):
     """Move add operations past multiply operations. The aim is to have them
@@ -273,12 +271,12 @@ class MoveScalarMulPastConv(Transformation):
         return (model, graph_modified)
 
 
-class MoveScalarLinearPastEltwiseAdd(Transformation):
-    """Move scalar linear operations (mul, add) past elementwise add operations where possible. Specifically,
-       matches and transforms the following patterns:
+class MoveLinearPastEltwiseAdd(Transformation):
+    """Move linear operations (mul, add) past elementwise add operations where possible.
+       Specifically,matches and transforms the following patterns:
        (x*C) + (y*C) -> (x + y) * C
        (x+A) + (y+B) -> (x + y) + (A + B)
-       where x and y are dynamic inputs, A, B, C are constants.
+       where x and y are dynamic inputs, A, B, C are constant tensors (in general).
     """
 
     def move_node(self, graph, n, prod0, prod1, node_ind):
@@ -305,7 +303,8 @@ class MoveScalarLinearPastEltwiseAdd(Transformation):
         graph = model.graph
         node_ind = 0
         graph_modified = False
-        for n in graph.node:
+        nodes = [n for n in graph.node]
+        for n in nodes:
             node_ind += 1
             if n.op_type == "Add":
                 # check for tensors on both inputs (eltwise add)
@@ -321,17 +320,16 @@ class MoveScalarLinearPastEltwiseAdd(Transformation):
                 # check for mul with same initializer on both inputs
                 prod0 = model.find_producer(in0)
                 prod1 = model.find_producer(in1)
-                if prod0 is None or prod1 is None:
+                # Also check case when both branches are empty and come
+                # from the same node: (prod0 == prod1)
+                # Other transform should handle that
+                if prod0 is None or prod1 is None or (prod0 == prod1):
                     continue
                 init0 = model.get_initializer(prod0.input[1])
                 init1 = model.get_initializer(prod1.input[1])
                 # if either initializer is None, skip
                 if init0 is None or init1 is None:
                     continue
-                # if either initializer is non-scalar, skip
-                # TODO relax this to 1D tensors?
-                if (not is_scalar(init0)) or (not is_scalar(init1)):
-                    continue
                 if prod0.op_type == "Mul" and prod1.op_type == "Mul":
                     if np.array_equal(init0, init1):
                         self.move_node(graph, n, prod0, prod1, node_ind)