diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py
index 85354fac87f38c6a0ae424f3aeec24a72a36aad0..0c800afa3021c22a4c7840d53c4c4cedcc0a72d5 100644
--- a/src/finn/transformation/streamline/reorder.py
+++ b/src/finn/transformation/streamline/reorder.py
@@ -360,3 +360,90 @@ class MakeMaxPoolNHWC(Transformation):
                         graph.node.insert(node_ind - 1, consumer)
                         graph_modified = True
         return (model, graph_modified)
+
+
+class MoveOpPastFork(Transformation):
+    """Move node operations past graph forks. Used when a node before a fork
+     can be merged with nodes in the branches
+    """
+
+    def __init__(self, op_name_list):
+        super().__init__()
+        self.ops_to_move = op_name_list
+
+    def apply(self, model):
+        graph = model.graph
+        graph_modified = False
+        nodes = [n for n in graph.node]
+        node_ind = 0
+        for n in nodes:
+            node_ind += 1
+            if (
+                n.op_type in self.ops_to_move
+                and model.is_fork_node(n)
+                and not model.is_join_node(n)
+            ):
+
+                # Restrict this transform to operations with constant parameters
+                # Assuming parameters is in input 1
+                op_init_param = model.get_initializer(n.input[1])
+                if op_init_param is None:
+                    continue
+
+                # Check case when branches are empty and go
+                # to the same node
+                consumers = model.find_consumers(n.output[0])
+                unique_consumer = True
+                for consum_node in consumers[1:]:
+                    if consumers[0] != consum_node:
+                        unique_consumer = False
+                        break
+
+                if unique_consumer:
+                    continue
+
+                for consumer_node in consumers[1:]:
+                    # create new node
+                    new_param_name = model.make_new_valueinfo_name()
+                    new_output_tensor_name = model.make_new_valueinfo_name()
+                    new_node = oh.make_node(
+                        n.op_type,
+                        [n.input[0], new_param_name],
+                        [new_output_tensor_name],
+                    )
+                    graph.node.insert(node_ind, new_node)
+                    node_ind += 1
+                    model.set_initializer(new_param_name, op_init_param)
+
+                    # change consumer input tensor
+                    graph.node.remove(consumer_node)
+                    for idx, consumer_input in enumerate(consumer_node.input):
+                        if consumer_input == n.output[0]:
+                            consumer_node.input[idx] = new_output_tensor_name
+                            break
+                    else:
+                        raise Exception(
+                            "Consumer should have the current node output as input"
+                        )
+
+                    graph.node.insert(node_ind, consumer_node)
+
+                graph_modified = True
+
+        model = model.transform(InferShapes())
+        return (model, graph_modified)
+
+
+class MoveAddPastFork(MoveOpPastFork):
+    def __init__(self):
+        super().__init__(["Add"])
+
+
+class MoveMulPastFork(MoveOpPastFork):
+    def __init__(self):
+        super().__init__(["Mul"])
+
+
+class MoveLinearPastFork(MoveOpPastFork):
+    def __init__(self):
+        super().__init__(["Add", "Mul"])