diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py
index ec58084df1589aaa4db5e154832b3fbc4eddb9de..0b6259a61d3eb67b7b38d4c6939019ce2893a875 100644
--- a/src/finn/transformation/streamline/reorder.py
+++ b/src/finn/transformation/streamline/reorder.py
@@ -444,3 +444,90 @@ class MakeMaxPoolNHWC(Transformation):
                         graph.node.insert(node_ind - 1, consumer)
                         graph_modified = True
         return (model, graph_modified)
+
+
+class MoveOpPastFork(Transformation):
+    """Move node operations past graph forks. Used when a node before a fork
+     can be merged with nodes in the branches
+    """
+
+    def __init__(self, op_name_list):
+        super().__init__()
+        self.ops_to_move = op_name_list
+
+    def apply(self, model):
+        graph = model.graph
+        graph_modified = False
+        nodes = [n for n in graph.node]
+        node_ind = 0
+        for n in nodes:
+            node_ind += 1
+            if (
+                n.op_type in self.ops_to_move
+                and model.is_fork_node(n)
+                and not model.is_join_node(n)
+            ):
+
+                # Restrict this transform to operations with constant parameters
+                # Assuming parameters is in input 1
+                op_init_param = model.get_initializer(n.input[1])
+                if op_init_param is None:
+                    continue
+
+                # Check case when branches are empty and go
+                # to the same node
+                consumers = model.find_consumers(n.output[0])
+                unique_consumer = True
+                for consum_node in consumers[1:]:
+                    if consumers[0] != consum_node:
+                        unique_consumer = False
+                        break
+
+                if unique_consumer:
+                    continue
+
+                for consumer_node in consumers[1:]:
+                    # create new node
+                    new_param_name = model.make_new_valueinfo_name()
+                    new_output_tensor_name = model.make_new_valueinfo_name()
+                    new_node = oh.make_node(
+                        n.op_type,
+                        [n.input[0], new_param_name],
+                        [new_output_tensor_name],
+                    )
+                    graph.node.insert(node_ind, new_node)
+                    node_ind += 1
+                    model.set_initializer(new_param_name, op_init_param)
+
+                    # change consumer input tensor
+                    graph.node.remove(consumer_node)
+                    for idx, consumer_input in enumerate(consumer_node.input):
+                        if consumer_input == n.output[0]:
+                            consumer_node.input[idx] = new_output_tensor_name
+                            break
+                    else:
+                        raise Exception(
+                            "Consumer should have the current node output as input"
+                        )
+
+                    graph.node.insert(node_ind, consumer_node)
+
+                graph_modified = True
+
+        model = model.transform(InferShapes())
+        return (model, graph_modified)
+
+
+class MoveAddPastFork(MoveOpPastFork):
+    def __init__(self):
+        super().__init__(["Add"])
+
+
+class MoveMulPastFork(MoveOpPastFork):
+    def __init__(self):
+        super().__init__(["Mul"])
+
+
+class MoveLinearPastFork(MoveOpPastFork):
+    def __init__(self):
+        super().__init__(["Add", "Mul"])
diff --git a/tests/transformation/test_move_past_fork.py b/tests/transformation/test_move_past_fork.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3d37bd60c9e2580ca4499daafa8693f39fec810
--- /dev/null
+++ b/tests/transformation/test_move_past_fork.py
@@ -0,0 +1,79 @@
+from onnx import TensorProto, helper
+import numpy as np
+
+import finn.core.onnx_exec as oxe
+from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.streamline.reorder import MoveLinearPastFork
+from finn.transformation.infer_shapes import InferShapes
+
+import pytest
+
+
+@pytest.mark.parametrize("ch", [64, 1])
+# ifmdim
+@pytest.mark.parametrize("ifmdim", [-1, 7])
+def test_move_past_fork(ch, ifmdim):
+    # generate test vectors of correct shape
+    if ifmdim == -1:
+        input_shape = (1, ch)
+    else:
+        input_shape = (1, ch, ifmdim, ifmdim)
+
+    top_in = helper.make_tensor_value_info("top_in", TensorProto.FLOAT, input_shape)
+    top_out = helper.make_tensor_value_info("top_out", TensorProto.FLOAT, input_shape)
+
+    num_of_params = 8
+    value_info = []
+    for i in range(num_of_params):
+        value_info += [
+            helper.make_tensor_value_info("p" + str(i), TensorProto.FLOAT, input_shape)
+        ]
+
+    add_1_to_move = helper.make_node("Add", ["top_in", "p0"], ["fork1"])
+    mul_1_to_move = helper.make_node("Mul", ["t5", "p4"], ["fork2"])
+    add_2_to_move = helper.make_node("Add", ["fork2", "p5"], ["t6"])
+    mul_1_not_to_move = helper.make_node("Mul", ["t8", "p7"], ["fork3"])
+    modelproto = helper.make_model(
+        helper.make_graph(
+            name="test",
+            inputs=[top_in],
+            outputs=[top_out],
+            value_info=value_info,
+            nodes=[
+                # fork1
+                add_1_to_move,
+                helper.make_node("Mul", ["fork1", "p1"], ["t2"]),
+                helper.make_node("Mul", ["fork1", "p2"], ["t3"]),
+                helper.make_node("Add", ["t2", "t3"], ["t4"]),
+                helper.make_node("Add", ["t4", "p3"], ["t5"]),
+                # fork2
+                mul_1_to_move,
+                add_2_to_move,
+                helper.make_node("Add", ["fork2", "p6"], ["t7"]),
+                helper.make_node("Add", ["t6", "t7"], ["t8"]),
+                # empty branches: do nothing
+                mul_1_not_to_move,
+                helper.make_node("Add", ["fork3", "fork3"], ["top_out"]),
+            ],
+        )
+    )
+    model = ModelWrapper(modelproto)
+    model = model.transform(InferShapes())
+
+    np.random.seed(0)
+    for i in range(num_of_params):
+        model.set_initializer(
+            "p" + str(i), np.random.rand(*input_shape).astype(np.float32)
+        )
+
+    # Transform
+    new_model = model.transform(MoveLinearPastFork())
+    inp_dict = {"top_in": np.random.rand(*input_shape).astype(np.float32)}
+
+    # Test
+    assert oxe.compare_execution(model, new_model, inp_dict)
+    assert not new_model.is_fork_node(add_1_to_move)
+    assert not new_model.is_fork_node(mul_1_to_move)
+    assert not new_model.is_fork_node(add_2_to_move)
+    assert new_model.is_fork_node(mul_1_not_to_move)
+    assert len(new_model.graph.node) == 14