diff --git a/src/finn/core/modelwrapper.py b/src/finn/core/modelwrapper.py
index cdf99dc3bd8b698bec60d79ef6e34640ac3b740c..ed32426abcc8ea71428a7f746a99454e8e4a2c17 100644
--- a/src/finn/core/modelwrapper.py
+++ b/src/finn/core/modelwrapper.py
@@ -333,6 +333,22 @@ class ModelWrapper:
         else:
             return None
 
+    def is_fork_node(self, node):
+        """Checks if the given node is a fork, that is, the node has multiple
+        direct successors"""
+        direct_successors = self.find_direct_successors(node)
+        is_fork = False if direct_successors is None else (len(direct_successors) > 1)
+        return is_fork
+
+    def is_join_node(self, node):
+        """Checks if the given node is a join, that is, the node has multiple
+        direct predecessors"""
+        direct_predecessors = self.find_direct_predecessors(node)
+        is_join = (
+            False if direct_predecessors is None else (len(direct_predecessors) > 1)
+        )
+        return is_join
+
     def get_all_tensor_names(self):
         """Returns a list of all (input, output and value_info) tensor names
         in the graph."""
diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py
index 1886c785705161c3a13493de44dc3f3f86463f4f..b91ffdb3f731d27d9a6ba68b090f3881e6d7293a 100644
--- a/src/finn/transformation/streamline/reorder.py
+++ b/src/finn/transformation/streamline/reorder.py
@@ -36,8 +36,9 @@ from finn.util.basic import get_by_name
 
 
 class MoveAddPastMul(Transformation):
-    """Move add operations past multiply operations. The aim is to have them
-    next to each other such that they can be collapsed into a single add."""
+    """Move add operations past multiply operations on linear segments of the graph.
+    The aim is to have them next to each other such that they can be collapsed into
+    a single add."""
 
     def apply(self, model):
         graph = model.graph
@@ -45,9 +46,17 @@ class MoveAddPastMul(Transformation):
         graph_modified = False
         for n in graph.node:
             node_ind += 1
-            if n.op_type == "Add":
+            if (
+                n.op_type == "Add"
+                and not model.is_fork_node(n)
+                and not model.is_join_node(n)
+            ):
                 consumer = model.find_consumer(n.output[0])
-                if consumer is not None and consumer.op_type == "Mul":
+                if (
+                    consumer is not None
+                    and consumer.op_type == "Mul"
+                    and not model.is_join_node(consumer)
+                ):
                     # have: (x) -> add(,B) -> (x+B) -> mul(,A) -> (xA+BA)
                     # want: (x) -> mul(,A) -> (xA) -> add(,BA) -> (xA+BA)
                     # assume input 0 is from the previous layer, input 1 is the
@@ -63,12 +72,16 @@ class MoveAddPastMul(Transformation):
                     end_name = consumer.output[0]
                     # compute new param value for add
                     BA = B * A
+
                     # make and insert new nodes
                     new_mul = oh.make_node(
-                        "Mul", [start_name, mul_weight_name], [middle_name]
+                        "Mul",
+                        [start_name, mul_weight_name],
+                        [middle_name],
+                        name=consumer.name,
                     )
                     new_add = oh.make_node(
-                        "Add", [middle_name, add_weight_name], [end_name]
+                        "Add", [middle_name, add_weight_name], [end_name], name=n.name
                     )
                     graph.node.insert(node_ind, new_mul)
                     graph.node.insert(node_ind + 1, new_add)
@@ -78,6 +91,7 @@ class MoveAddPastMul(Transformation):
                     graph.node.remove(n)
                     graph.node.remove(consumer)
                     graph_modified = True
+
         model = model.transform(InferShapes())
         return (model, graph_modified)
 
@@ -92,9 +106,17 @@ class MoveScalarMulPastMatMul(Transformation):
         graph_modified = False
         for n in graph.node:
             node_ind += 1
-            if n.op_type == "Mul":
+            if (
+                n.op_type == "Mul"
+                and not model.is_fork_node(n)
+                and not model.is_join_node(n)
+            ):
                 consumer = model.find_consumer(n.output[0])
-                if consumer is not None and consumer.op_type == "MatMul":
+                if (
+                    consumer is not None
+                    and consumer.op_type == "MatMul"
+                    and not model.is_join_node(consumer)
+                ):
                     mul_weight_name = n.input[1]
                     matmul_weight_name = consumer.input[1]
                     A = model.get_initializer(mul_weight_name)
@@ -109,10 +131,16 @@ class MoveScalarMulPastMatMul(Transformation):
                         # if the mul is scalar, we can simply swap the order of ops
                         # make and insert new nodes
                         new_matmul = oh.make_node(
-                            "MatMul", [start_name, matmul_weight_name], [middle_name]
+                            "MatMul",
+                            [start_name, matmul_weight_name],
+                            [middle_name],
+                            name=consumer.name,
                         )
                         new_mul = oh.make_node(
-                            "Mul", [middle_name, mul_weight_name], [end_name]
+                            "Mul",
+                            [middle_name, mul_weight_name],
+                            [end_name],
+                            name=n.name,
                         )
                         graph.node.insert(node_ind, new_matmul)
                         graph.node.insert(node_ind + 1, new_mul)
@@ -135,9 +163,17 @@ class MoveScalarAddPastMatMul(Transformation):
         graph_modified = False
         for n in graph.node:
             node_ind += 1
-            if n.op_type == "Add":
+            if (
+                n.op_type == "Add"
+                and not model.is_fork_node(n)
+                and not model.is_join_node(n)
+            ):
                 consumer = model.find_consumer(n.output[0])
-                if consumer is not None and consumer.op_type == "MatMul":
+                if (
+                    consumer is not None
+                    and consumer.op_type == "MatMul"
+                    and not model.is_join_node(consumer)
+                ):
                     add_weight_name = n.input[1]
                     matmul_weight_name = consumer.input[1]
                     A = model.get_initializer(add_weight_name)
@@ -155,10 +191,16 @@ class MoveScalarAddPastMatMul(Transformation):
                         # update the add weight
                         model.set_initializer(add_weight_name, Anew)
                         new_matmul = oh.make_node(
-                            "MatMul", [start_name, matmul_weight_name], [middle_name]
+                            "MatMul",
+                            [start_name, matmul_weight_name],
+                            [middle_name],
+                            name=consumer.name,
                         )
                         new_add = oh.make_node(
-                            "Add", [middle_name, add_weight_name], [end_name]
+                            "Add",
+                            [middle_name, add_weight_name],
+                            [end_name],
+                            name=n.name,
                         )
                         graph.node.insert(node_ind, new_matmul)
                         graph.node.insert(node_ind + 1, new_add)
@@ -181,9 +223,17 @@ class MoveScalarAddPastConv(Transformation):
         graph_modified = False
         for n in graph.node:
             node_ind += 1
-            if n.op_type == "Add":
+            if (
+                n.op_type == "Add"
+                and not model.is_fork_node(n)
+                and not model.is_join_node(n)
+            ):
                 consumer = model.find_consumer(n.output[0])
-                if consumer is not None and consumer.op_type == "Conv":
+                if (
+                    consumer is not None
+                    and consumer.op_type == "Conv"
+                    and not model.is_join_node(consumer)
+                ):
                     conv_node = consumer
                     add_node = n
                     add_weight_name = n.input[1]
@@ -238,9 +288,17 @@ class MoveScalarMulPastConv(Transformation):
         graph_modified = False
         for n in graph.node:
             node_ind += 1
-            if n.op_type == "Mul":
+            if (
+                n.op_type == "Mul"
+                and not model.is_fork_node(n)
+                and not model.is_join_node(n)
+            ):
                 consumer = model.find_consumer(n.output[0])
-                if consumer is not None and consumer.op_type == "Conv":
+                if (
+                    consumer is not None
+                    and consumer.op_type == "Conv"
+                    and not model.is_join_node(consumer)
+                ):
                     mul_weight_name = n.input[1]
                     A = model.get_initializer(mul_weight_name)
                     assert A is not None, "Initializer for mul weights is not set."
diff --git a/tests/core/test_modelwrapper.py b/tests/core/test_modelwrapper.py
index d1da6934a5db07aabe41a9ca40b5de497b6460a1..4bd9385536bc6721c66726169dfa4c69e5f06772 100644
--- a/tests/core/test_modelwrapper.py
+++ b/tests/core/test_modelwrapper.py
@@ -127,3 +127,45 @@ def test_modelwrapper_graph_order():
     assert model.get_node_index(Round_node) == 1
     assert model.get_node_index(Ceil_node) == 2
     assert model.get_node_index(Add_node) == 3
+
+
+def test_modelwrapper_detect_forks_n_joins():
+    # create small network with properties to be tested
+    Neg_node = onnx.helper.make_node("Neg", inputs=["in1"], outputs=["neg1"])
+    Round_node = onnx.helper.make_node("Round", inputs=["neg1"], outputs=["round1"])
+
+    Ceil_node = onnx.helper.make_node("Ceil", inputs=["neg1"], outputs=["ceil1"])
+    Add_node = onnx.helper.make_node(
+        "Add", inputs=["round1", "ceil1"], outputs=["out1"]
+    )
+
+    in1 = onnx.helper.make_tensor_value_info("in1", onnx.TensorProto.FLOAT, [4, 4])
+    out1 = onnx.helper.make_tensor_value_info("out1", onnx.TensorProto.FLOAT, [4, 4])
+
+    graph = onnx.helper.make_graph(
+        nodes=[Neg_node, Round_node, Ceil_node, Add_node],
+        name="simple_graph",
+        inputs=[in1],
+        outputs=[out1],
+        value_info=[
+            onnx.helper.make_tensor_value_info("neg1", onnx.TensorProto.FLOAT, [4, 4]),
+            onnx.helper.make_tensor_value_info(
+                "round1", onnx.TensorProto.FLOAT, [4, 4]
+            ),
+            onnx.helper.make_tensor_value_info("ceil1", onnx.TensorProto.FLOAT, [4, 4]),
+        ],
+    )
+
+    onnx_model = onnx.helper.make_model(graph, producer_name="simple-model")
+    model = ModelWrapper(onnx_model)
+
+    # test
+    assert model.is_fork_node(Neg_node)
+    assert not model.is_fork_node(Round_node)
+    assert not model.is_fork_node(Ceil_node)
+    assert not model.is_fork_node(Add_node)
+
+    assert not model.is_join_node(Neg_node)
+    assert not model.is_join_node(Round_node)
+    assert not model.is_join_node(Ceil_node)
+    assert model.is_join_node(Add_node)
diff --git a/tests/transformation/test_move_add_past_mul.py b/tests/transformation/test_move_add_past_mul.py
index a0516d6fb2ff985fc112185ce99ad8facd841caf..163b9d310a5f12bd0b854f9aa46f53a549bf109e 100644
--- a/tests/transformation/test_move_add_past_mul.py
+++ b/tests/transformation/test_move_add_past_mul.py
@@ -60,6 +60,9 @@ def test_move_add_past_mul_single():
     new_model = model.transform(MoveAddPastMul())
     inp_dict = {"top_in": np.asarray([-1.0, 1.0], dtype=np.float32)}
     assert ox.compare_execution(model, new_model, inp_dict)
+    assert new_model.graph.node[0].op_type == "Mul"
+    assert new_model.graph.node[1].op_type == "Add"
+    assert new_model.graph.node[0].output[0] == new_model.graph.node[1].input[0]
 
 
 def test_move_add_past_mul_multi():
@@ -92,3 +95,50 @@ def test_move_add_past_mul_multi():
     new_model = model.transform(MoveAddPastMul())
     inp_dict = {"top_in": np.asarray([-1.0, 1.0], dtype=np.float32)}
     assert ox.compare_execution(model, new_model, inp_dict)
+    assert new_model.graph.node[0].op_type == "Mul"
+    assert new_model.graph.node[1].op_type == "Mul"
+    assert new_model.graph.node[2].op_type == "Add"
+    assert new_model.graph.node[3].op_type == "Add"
+    for i in range(len(new_model.graph.node) - 1):
+        assert new_model.graph.node[i].output[0] == new_model.graph.node[i + 1].input[0]
+
+
+def test_move_add_past_mul_only_if_linear():
+    top_in = oh.make_tensor_value_info("top_in", TensorProto.FLOAT, [2])
+    top_out = oh.make_tensor_value_info("top_out", TensorProto.FLOAT, [2])
+
+    value_info = [oh.make_tensor_value_info("add1_param", TensorProto.FLOAT, [1])]
+    value_info += [oh.make_tensor_value_info("mul1_param", TensorProto.FLOAT, [1])]
+    value_info += [oh.make_tensor_value_info("mul2_param", TensorProto.FLOAT, [1])]
+    value_info += [oh.make_tensor_value_info("mul3_param", TensorProto.FLOAT, [1])]
+    modelproto = oh.make_model(
+        oh.make_graph(
+            name="test",
+            inputs=[top_in],
+            outputs=[top_out],
+            value_info=value_info,
+            nodes=[
+                oh.make_node("Add", ["top_in", "add1_param"], ["t1"]),
+                oh.make_node("Mul", ["t1", "mul1_param"], ["fork"]),
+                oh.make_node("Mul", ["fork", "mul2_param"], ["t3"]),
+                oh.make_node("Add", ["t3", "fork"], ["t4"]),
+                oh.make_node("Mul", ["t4", "mul3_param"], ["top_out"]),
+            ],
+        )
+    )
+    model = ModelWrapper(modelproto)
+    model = model.transform(InferShapes())
+
+    np.random.seed(0)
+    model.set_initializer("add1_param", np.random.rand(2).astype(np.float32))
+    model.set_initializer("mul1_param", np.random.rand(2).astype(np.float32))
+    model.set_initializer("mul2_param", np.random.rand(2).astype(np.float32))
+    model.set_initializer("mul3_param", np.random.rand(2).astype(np.float32))
+    new_model = model.transform(MoveAddPastMul())
+    inp_dict = {"top_in": np.random.rand(2).astype(np.float32)}
+    assert ox.compare_execution(model, new_model, inp_dict)
+    assert new_model.graph.node[0].op_type == "Mul"
+    assert new_model.graph.node[1].op_type == "Add"
+    assert new_model.graph.node[2].op_type == "Mul"
+    assert new_model.graph.node[3].op_type == "Add"
+    assert new_model.graph.node[4].op_type == "Mul"
diff --git a/tests/transformation/test_move_scalar_past_conv.py b/tests/transformation/test_move_scalar_past_conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..9992d17b96ab5f419f3ac495f126ddfa736349a2
--- /dev/null
+++ b/tests/transformation/test_move_scalar_past_conv.py
@@ -0,0 +1,87 @@
+import numpy as np
+import onnx.helper as oh
+import pytest
+from onnx import TensorProto
+
+import finn.core.onnx_exec as ox
+from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.infer_shapes import InferShapes
+from finn.transformation.streamline import (
+    MoveScalarAddPastConv,
+    MoveScalarMulPastConv,
+)
+
+
+@pytest.mark.parametrize(
+    "test_args", [("Add", MoveScalarAddPastConv()), ("Mul", MoveScalarMulPastConv())],
+)
+def test_move_scalar_past_conv_only_if_linear(test_args):
+    scalar_op = test_args[0]
+    transf_fxn = test_args[1]
+
+    in_feature_dim = 7
+    in_chn = 1
+    padding = False
+    stages = 3
+    kernel_size = 3
+
+    out_feature_dim = (
+        in_feature_dim if padding else in_feature_dim - (kernel_size // 2 * 2) * stages
+    )
+
+    input_shape = [1, in_chn, in_feature_dim, in_feature_dim]
+    output_shape = [1, in_chn, out_feature_dim, out_feature_dim]
+
+    conv_param_shape = [in_chn, in_chn, kernel_size, kernel_size]
+
+    conv_config = {}
+    conv_config["dilations"] = [1, 1]
+    conv_config["group"] = 1
+    conv_config["kernel_shape"] = [kernel_size, kernel_size]
+    conv_config["pads"] = [0, 0, 0, 0]
+    conv_config["strides"] = [1, 1]
+
+    top_in = oh.make_tensor_value_info("top_in", TensorProto.FLOAT, input_shape)
+    top_out = oh.make_tensor_value_info("top_out", TensorProto.FLOAT, output_shape)
+
+    value_info = [oh.make_tensor_value_info("p1", TensorProto.FLOAT, [1])]
+    value_info += [oh.make_tensor_value_info("p2", TensorProto.FLOAT, conv_param_shape)]
+    value_info += [oh.make_tensor_value_info("p3", TensorProto.FLOAT, conv_param_shape)]
+    value_info += [oh.make_tensor_value_info("p4", TensorProto.FLOAT, conv_param_shape)]
+    value_info += [oh.make_tensor_value_info("p5", TensorProto.FLOAT, conv_param_shape)]
+
+    modelproto = oh.make_model(
+        oh.make_graph(
+            name="test",
+            inputs=[top_in],
+            outputs=[top_out],
+            value_info=value_info,
+            nodes=[
+                oh.make_node("Conv", ["top_in", "p2"], ["t1"], **conv_config),
+                oh.make_node(scalar_op, ["t1", "p1"], ["t2"]),
+                oh.make_node("Conv", ["t2", "p3"], ["t3"], **conv_config),
+                oh.make_node("Conv", ["t2", "p4"], ["t4"], **conv_config),
+                oh.make_node(scalar_op, ["t3", "t4"], ["t5"]),
+                oh.make_node("Conv", ["t5", "p5"], ["top_out"], **conv_config),
+            ],
+        )
+    )
+    model = ModelWrapper(modelproto)
+    model = model.transform(InferShapes())
+
+    np.random.seed(0)
+    model.set_initializer("p1", *np.random.rand(1).astype(np.float32))
+    model.set_initializer("p2", np.random.rand(*conv_param_shape).astype(np.float32))
+    model.set_initializer("p3", np.random.rand(*conv_param_shape).astype(np.float32))
+    model.set_initializer("p4", np.random.rand(*conv_param_shape).astype(np.float32))
+    model.set_initializer("p5", np.random.rand(*conv_param_shape).astype(np.float32))
+    new_model = model.transform(transf_fxn)
+    inp_dict = {"top_in": np.random.rand(*input_shape).astype(np.float32)}
+
+    assert ox.compare_execution(model, new_model, inp_dict)
+    assert new_model.graph.node[0].op_type == "Conv"
+    assert new_model.graph.node[1].op_type == scalar_op
+    assert new_model.graph.node[2].op_type == "Conv"
+    assert new_model.graph.node[3].op_type == "Conv"
+    assert new_model.graph.node[4].op_type == scalar_op
+    assert new_model.graph.node[5].op_type == "Conv"
diff --git a/tests/transformation/test_move_scalar_past_matmul.py b/tests/transformation/test_move_scalar_past_matmul.py
index 896527e82d8cfa869cb979d1102904c70703a14c..e432dbf4ec1a38551609e5914e2d44968a020908 100644
--- a/tests/transformation/test_move_scalar_past_matmul.py
+++ b/tests/transformation/test_move_scalar_past_matmul.py
@@ -27,6 +27,7 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 import numpy as np
+import pytest
 import onnx.helper as oh
 from onnx import TensorProto
 
@@ -99,3 +100,56 @@ def test_move_scalar_add_past_matmul():
     assert new_model.graph.node[0].op_type == "MatMul"
     assert new_model.graph.node[1].op_type == "Add"
     assert new_model.graph.node[0].output[0] == new_model.graph.node[1].input[0]
+
+
+@pytest.mark.parametrize(
+    "test_args",
+    [("Add", MoveScalarAddPastMatMul()), ("Mul", MoveScalarMulPastMatMul())],
+)
+def test_move_scalar_past_matmul_only_if_linear(test_args):
+    scalar_op = test_args[0]
+    transf_fxn = test_args[1]
+    input_shape = [1, 2]
+    matmul_shape = [2, 2]
+    top_in = oh.make_tensor_value_info("top_in", TensorProto.FLOAT, input_shape)
+    top_out = oh.make_tensor_value_info("top_out", TensorProto.FLOAT, input_shape)
+
+    p1 = oh.make_tensor_value_info("p1", TensorProto.FLOAT, [1, 1])
+    p2 = oh.make_tensor_value_info("p2", TensorProto.FLOAT, matmul_shape)
+    p3 = oh.make_tensor_value_info("p3", TensorProto.FLOAT, matmul_shape)
+    p4 = oh.make_tensor_value_info("p4", TensorProto.FLOAT, matmul_shape)
+    modelproto = oh.make_model(
+        oh.make_graph(
+            name="test",
+            inputs=[top_in],
+            outputs=[top_out],
+            value_info=[p1, p2, p3, p4],
+            nodes=[
+                oh.make_node(scalar_op, ["top_in", "p1"], ["t1"]),
+                oh.make_node("MatMul", ["t1", "p2"], ["fork"]),
+                oh.make_node("MatMul", ["fork", "p3"], ["t3"]),
+                oh.make_node(scalar_op, ["t3", "fork"], ["t4"]),
+                oh.make_node("MatMul", ["t4", "p4"], ["top_out"]),
+            ],
+        )
+    )
+    model = ModelWrapper(modelproto)
+    model = model.transform(InferShapes())
+
+    np.random.seed(0)
+    model.set_initializer("p1", np.random.rand(1, 1).astype(np.float32))
+    model.set_initializer("p2", np.random.rand(*matmul_shape).astype(np.float32))
+    model.set_initializer("p3", np.random.rand(*matmul_shape).astype(np.float32))
+    model.set_initializer("p4", np.random.rand(*matmul_shape).astype(np.float32))
+
+    # Transform
+    new_model = model.transform(transf_fxn)
+
+    # Test
+    inp_dict = {"top_in": np.random.rand(*input_shape).astype(np.float32)}
+    assert ox.compare_execution(model, new_model, inp_dict)
+    assert new_model.graph.node[0].op_type == "MatMul"
+    assert new_model.graph.node[1].op_type == scalar_op
+    assert new_model.graph.node[2].op_type == "MatMul"
+    assert new_model.graph.node[3].op_type == scalar_op
+    assert new_model.graph.node[4].op_type == "MatMul"