diff --git a/tests/transformation/test_scalar_past_eltwise.py b/tests/transformation/test_linear_past_eltwise.py
similarity index 69%
rename from tests/transformation/test_scalar_past_eltwise.py
rename to tests/transformation/test_linear_past_eltwise.py
index e845f32176a9293046b297b7d9e2ab64fabc1791..b77f59779a5e8559f80e017d13b66bcb67249830 100644
--- a/tests/transformation/test_scalar_past_eltwise.py
+++ b/tests/transformation/test_linear_past_eltwise.py
@@ -35,7 +35,7 @@ import finn.core.onnx_exec as oxe
 from finn.core.modelwrapper import ModelWrapper
 from finn.transformation.fold_constants import FoldConstants
 from finn.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames
-from finn.transformation.streamline.reorder import MoveScalarLinearPastEltwiseAdd
+from finn.transformation.streamline.reorder import MoveLinearPastEltwiseAdd
 from finn.transformation.infer_shapes import InferShapes
 from finn.transformation.double_to_single_float import DoubleToSingleFloat
 
@@ -95,7 +95,7 @@ def make_model(shape):
 @pytest.mark.parametrize("ch", [64])
 # ifmdim
 @pytest.mark.parametrize("ifmdim", [-1, 7])
-def test_scalar_past_eltwise(ch, ifmdim):
+def test_linear_past_eltwise_add(ch, ifmdim):
     # generate test vectors of correct shape
     if ifmdim == -1:
         input_tensor_shape = (1, ch)
@@ -124,7 +124,7 @@ def test_scalar_past_eltwise(ch, ifmdim):
     assert len(model.get_nodes_by_op_type("Add")) == 3
     assert len(model.get_nodes_by_op_type("Mul")) == 2
 
-    model = model.transform(MoveScalarLinearPastEltwiseAdd())
+    model = model.transform(MoveLinearPastEltwiseAdd())
 
     # verify again, to check we didnt break anything
     output_dict = oxe.execute_onnx(model, input_dict, True)
@@ -134,3 +134,68 @@ def test_scalar_past_eltwise(ch, ifmdim):
     assert len(model.get_nodes_by_op_type("Mul")) == 1
 
     os.remove(export_onnx_path)
+
+
+@pytest.mark.parametrize("ch", [64, 1])
+# ifmdim
+@pytest.mark.parametrize("ifmdim", [-1, 7])
+def test_linear_past_eltwise_add_multiple_forks(ch, ifmdim):
+    # generate test vectors of correct shape
+    if ifmdim == -1:
+        input_shape = (1, ch)
+    else:
+        input_shape = (1, ch, ifmdim, ifmdim)
+
+    top_in = helper.make_tensor_value_info("top_in", TensorProto.FLOAT, input_shape)
+    top_out = helper.make_tensor_value_info("top_out", TensorProto.FLOAT, input_shape)
+
+    num_of_params = 6
+    value_info = []
+    for i in range(num_of_params):
+        value_info += [
+            helper.make_tensor_value_info("p" + str(i), TensorProto.FLOAT, input_shape)
+        ]
+
+    modelproto = helper.make_model(
+        helper.make_graph(
+            name="test",
+            inputs=[top_in],
+            outputs=[top_out],
+            value_info=value_info,
+            nodes=[
+                helper.make_node("Add", ["top_in", "p0"], ["fork1"]),
+                helper.make_node("Mul", ["fork1", "p1"], ["t2"]),
+                helper.make_node("Mul", ["fork1", "p2"], ["t3"]),
+                helper.make_node("Add", ["t2", "t3"], ["t4"]),
+                helper.make_node("Mul", ["t4", "p3"], ["fork2"]),
+                helper.make_node("Add", ["fork2", "p4"], ["t5"]),
+                helper.make_node("Add", ["fork2", "p5"], ["t6"]),
+                helper.make_node("Add", ["t5", "t6"], ["top_out"]),
+            ],
+        )
+    )
+    model = ModelWrapper(modelproto)
+    model = model.transform(InferShapes())
+
+    np.random.seed(0)
+    for i in range(num_of_params):
+        model.set_initializer(
+            "p" + str(i), np.random.rand(*input_shape).astype(np.float32)
+        )
+
+    # need equal mults:
+    model.set_initializer("p2", model.get_initializer("p1"))
+
+    # Transform
+    new_model = model.transform(MoveLinearPastEltwiseAdd())
+    inp_dict = {"top_in": np.random.rand(*input_shape).astype(np.float32)}
+
+    # Test
+    assert oxe.compare_execution(model, new_model, inp_dict)
+    assert new_model.graph.node[0].op_type == "Add"
+    assert new_model.graph.node[1].op_type == "Add"
+    assert new_model.graph.node[2].op_type == "Mul"
+    assert new_model.graph.node[3].op_type == "Mul"
+    assert new_model.graph.node[4].op_type == "Add"
+    assert new_model.graph.node[5].op_type == "Add"
+    assert len(new_model.graph.node) == 6