diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py
index 96046602efb32a9262a4cf0bbb21a8367d719910..1886c785705161c3a13493de44dc3f3f86463f4f 100644
--- a/src/finn/transformation/streamline/reorder.py
+++ b/src/finn/transformation/streamline/reorder.py
@@ -34,8 +34,6 @@ from finn.transformation.infer_shapes import InferShapes
 from finn.core.onnx_exec import execute_node
 from finn.util.basic import get_by_name
 
-def is_scalar(x):
-    return np.prod(x.shape) == 1
 
 class MoveAddPastMul(Transformation):
     """Move add operations past multiply operations. The aim is to have them
@@ -273,12 +271,12 @@ class MoveScalarMulPastConv(Transformation):
         return (model, graph_modified)
 
 
-class MoveScalarLinearPastEltwiseAdd(Transformation):
-    """Move scalar linear operations (mul, add) past elementwise add operations where possible. Specifically,
-       matches and transforms the following patterns:
+class MoveLinearPastEltwiseAdd(Transformation):
+    """Move linear operations (mul, add) past elementwise add operations where possible.
+       Specifically, matches and transforms the following patterns:
        (x*C) + (y*C) -> (x + y) * C
        (x+A) + (y+B) -> (x + y) + (A + B)
-       where x and y are dynamic inputs, A, B, C are constants.
+       where x and y are dynamic inputs and A, B, C are constant tensors (not necessarily scalar).
     """
 
     def move_node(self, graph, n, prod0, prod1, node_ind):
@@ -305,7 +303,8 @@ class MoveScalarLinearPastEltwiseAdd(Transformation):
         graph = model.graph
         node_ind = 0
         graph_modified = False
-        for n in graph.node:
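+        # iterate over a static copy of the node list, since move_node()
+        # rewires the graph while we traverse it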
+        nodes = list(graph.node)
+        for n in nodes:
             node_ind += 1
             if n.op_type == "Add":
                 # check for tensors on both inputs (eltwise add)
@@ -321,17 +320,16 @@ class MoveScalarLinearPastEltwiseAdd(Transformation):
                 # check for mul with same initializer on both inputs
                 prod0 = model.find_producer(in0)
                 prod1 = model.find_producer(in1)
-                if prod0 is None or prod1 is None:
+                # also skip the case where both branches come from the
+                # same producer node (prod0 == prod1); a different
+                # transform should handle that case
+                if prod0 is None or prod1 is None or (prod0 == prod1):
                     continue
                 init0 = model.get_initializer(prod0.input[1])
                 init1 = model.get_initializer(prod1.input[1])
                 # if either initializer is None, skip
                 if init0 is None or init1 is None:
                     continue
-                # if either initializer is non-scalar, skip
-                # TODO relax this to 1D tensors?
-                if (not is_scalar(init0)) or (not is_scalar(init1)):
-                    continue
                 if prod0.op_type == "Mul" and prod1.op_type == "Mul":
                     if np.array_equal(init0, init1):
                         self.move_node(graph, n, prod0, prod1, node_ind)
diff --git a/tests/transformation/test_scalar_past_eltwise.py b/tests/transformation/test_linear_past_eltwise.py
similarity index 69%
rename from tests/transformation/test_scalar_past_eltwise.py
rename to tests/transformation/test_linear_past_eltwise.py
index e845f32176a9293046b297b7d9e2ab64fabc1791..b77f59779a5e8559f80e017d13b66bcb67249830 100644
--- a/tests/transformation/test_scalar_past_eltwise.py
+++ b/tests/transformation/test_linear_past_eltwise.py
@@ -35,7 +35,7 @@ import finn.core.onnx_exec as oxe
 from finn.core.modelwrapper import ModelWrapper
 from finn.transformation.fold_constants import FoldConstants
 from finn.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames
-from finn.transformation.streamline.reorder import MoveScalarLinearPastEltwiseAdd
+from finn.transformation.streamline.reorder import MoveLinearPastEltwiseAdd
 from finn.transformation.infer_shapes import InferShapes
 from finn.transformation.double_to_single_float import DoubleToSingleFloat
 
@@ -95,7 +95,7 @@ def make_model(shape):
 @pytest.mark.parametrize("ch", [64])
 # ifmdim
 @pytest.mark.parametrize("ifmdim", [-1, 7])
-def test_scalar_past_eltwise(ch, ifmdim):
+def test_linear_past_eltwise_add(ch, ifmdim):
     # generate test vectors of correct shape
     if ifmdim == -1:
         input_tensor_shape = (1, ch)
@@ -124,7 +124,7 @@ def test_scalar_past_eltwise(ch, ifmdim):
     assert len(model.get_nodes_by_op_type("Add")) == 3
     assert len(model.get_nodes_by_op_type("Mul")) == 2
 
-    model = model.transform(MoveScalarLinearPastEltwiseAdd())
+    model = model.transform(MoveLinearPastEltwiseAdd())
 
     # verify again, to check we didn't break anything
     output_dict = oxe.execute_onnx(model, input_dict, True)
@@ -134,3 +134,68 @@ def test_scalar_past_eltwise(ch, ifmdim):
     assert len(model.get_nodes_by_op_type("Mul")) == 1
 
     os.remove(export_onnx_path)
+
+
+@pytest.mark.parametrize("ch", [64, 1])
+# ifmdim
+@pytest.mark.parametrize("ifmdim", [-1, 7])
+def test_linear_past_eltwise_add_multiple_forks(ch, ifmdim):
+    # generate test vectors of correct shape
+    if ifmdim == -1:
+        input_shape = (1, ch)
+    else:
+        input_shape = (1, ch, ifmdim, ifmdim)
+
+    top_in = helper.make_tensor_value_info("top_in", TensorProto.FLOAT, input_shape)
+    top_out = helper.make_tensor_value_info("top_out", TensorProto.FLOAT, input_shape)
+
+    num_of_params = 6
+    value_info = []
+    for i in range(num_of_params):
+        value_info += [
+            helper.make_tensor_value_info("p" + str(i), TensorProto.FLOAT, input_shape)
+        ]
+
+    modelproto = helper.make_model(
+        helper.make_graph(
+            name="test",
+            inputs=[top_in],
+            outputs=[top_out],
+            value_info=value_info,
+            nodes=[
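+                # topology: Add -> fork1 -> (Mul, Mul) -> Add -> Mul
+                #           -> fork2 -> (Add, Add) -> Add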
+                helper.make_node("Add", ["top_in", "p0"], ["fork1"]),
+                helper.make_node("Mul", ["fork1", "p1"], ["t2"]),
+                helper.make_node("Mul", ["fork1", "p2"], ["t3"]),
+                helper.make_node("Add", ["t2", "t3"], ["t4"]),
+                helper.make_node("Mul", ["t4", "p3"], ["fork2"]),
+                helper.make_node("Add", ["fork2", "p4"], ["t5"]),
+                helper.make_node("Add", ["fork2", "p5"], ["t6"]),
+                helper.make_node("Add", ["t5", "t6"], ["top_out"]),
+            ],
+        )
+    )
+    model = ModelWrapper(modelproto)
+    model = model.transform(InferShapes())
+
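+    # fixed seed so the random initializers are reproducible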
+    np.random.seed(0)
+    for i in range(num_of_params):
+        model.set_initializer(
+            "p" + str(i), np.random.rand(*input_shape).astype(np.float32)
+        )
+
+    # the Mul initializers must match on both branches for the transform to apply:
+    model.set_initializer("p2", model.get_initializer("p1"))
+
+    # Transform
+    new_model = model.transform(MoveLinearPastEltwiseAdd())
+    inp_dict = {"top_in": np.random.rand(*input_shape).astype(np.float32)}
+
+    # Test
+    assert oxe.compare_execution(model, new_model, inp_dict)
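+    # after the transform, each fork reduces to an eltwise Add followed
+    # by a single linear op, leaving 6 nodes in total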
+    assert new_model.graph.node[0].op_type == "Add"
+    assert new_model.graph.node[1].op_type == "Add"
+    assert new_model.graph.node[2].op_type == "Mul"
+    assert new_model.graph.node[3].op_type == "Mul"
+    assert new_model.graph.node[4].op_type == "Add"
+    assert new_model.graph.node[5].op_type == "Add"
+    assert len(new_model.graph.node) == 6