diff --git a/src/finn/transformation/streamline/collapse_repeated.py b/src/finn/transformation/streamline/collapse_repeated.py
index aa059747b602bc6b659bc8b53b1f18988bba1ef0..67824ad4f633983b93e3178d03118927a1ddd85b 100644
--- a/src/finn/transformation/streamline/collapse_repeated.py
+++ b/src/finn/transformation/streamline/collapse_repeated.py
@@ -48,9 +48,17 @@ class CollapseRepeatedOp(Transformation):
         graph_modified = False
         for n in graph.node:
             node_ind += 1
-            if n.op_type == self.op_name:
+            if (
+                n.op_type == self.op_name
+                and not model.is_fork_node(n)
+                and not model.is_join_node(n)
+            ):
                 consumer = model.find_consumer(n.output[0])
-                if consumer is not None and consumer.op_type == self.op_name:
+                if (
+                    consumer is not None
+                    and consumer.op_type == self.op_name
+                    and not model.is_join_node(consumer)
+                ):
                     op0_param_name = n.input[1]
                     op1_param_name = consumer.input[1]
                     op0_param = model.get_initializer(op0_param_name)
diff --git a/tests/transformation/test_collapse_repeated_op.py b/tests/transformation/test_collapse_repeated_op.py
index 01d932ece0be4b0beb7ad6094284ec3efb1e525e..b74d868f9b921c35ff9f596c811583f45f761374 100644
--- a/tests/transformation/test_collapse_repeated_op.py
+++ b/tests/transformation/test_collapse_repeated_op.py
@@ -34,6 +34,7 @@ import finn.core.onnx_exec as ox
 from finn.core.modelwrapper import ModelWrapper
 from finn.transformation.infer_shapes import InferShapes
 from finn.transformation.streamline import CollapseRepeatedAdd, CollapseRepeatedMul
+import pytest
 
 
 def test_collapse_repeated_op():
@@ -67,3 +68,60 @@ def test_collapse_repeated_op():
     new_model = new_model.transform(CollapseRepeatedMul())
     inp_dict = {"top_in": np.asarray([-1.0, 1.0], dtype=np.float32)}
     assert ox.compare_execution(model, new_model, inp_dict)
+    assert len(new_model.graph.node) == 2
+    assert new_model.graph.node[0].op_type == "Add"
+    assert new_model.graph.node[1].op_type == "Mul"
+
+
+@pytest.mark.parametrize(
+    "test_args", [("Add", CollapseRepeatedAdd()), ("Mul", CollapseRepeatedMul())],
+)
+def test_collapse_repeated_only_if_linear(test_args):
+    scalar_op = test_args[0]
+    transf_fxn = test_args[1]
+
+    input_shape = [4, 4]
+    output_shape = input_shape
+
+    top_in = oh.make_tensor_value_info("top_in", TensorProto.FLOAT, input_shape)
+    top_out = oh.make_tensor_value_info("top_out", TensorProto.FLOAT, output_shape)
+
+    value_info = [oh.make_tensor_value_info("p1", TensorProto.FLOAT, [1])]
+    value_info += [oh.make_tensor_value_info("p2", TensorProto.FLOAT, [1])]
+    value_info += [oh.make_tensor_value_info("p3", TensorProto.FLOAT, [1])]
+    value_info += [oh.make_tensor_value_info("p4", TensorProto.FLOAT, [1])]
+    value_info += [oh.make_tensor_value_info("p5", TensorProto.FLOAT, [1])]
+
+    modelproto = oh.make_model(
+        oh.make_graph(
+            name="test",
+            inputs=[top_in],
+            outputs=[top_out],
+            value_info=value_info,
+            nodes=[
+                oh.make_node(scalar_op, ["top_in", "p2"], ["t1"]),
+                oh.make_node(scalar_op, ["t1", "p1"], ["t2"]),
+                oh.make_node(scalar_op, ["t2", "p3"], ["t3"]),
+                oh.make_node(scalar_op, ["t2", "p4"], ["t4"]),
+                oh.make_node(scalar_op, ["t3", "t4"], ["t5"]),
+                oh.make_node(scalar_op, ["t5", "p5"], ["top_out"]),
+            ],
+        )
+    )
+    model = ModelWrapper(modelproto)
+    model = model.transform(InferShapes())
+
+    np.random.seed(0)
+    model.set_initializer("p1", *np.random.rand(1).astype(np.float32))
+    model.set_initializer("p2", *np.random.rand(1).astype(np.float32))
+    model.set_initializer("p3", *np.random.rand(1).astype(np.float32))
+    model.set_initializer("p4", *np.random.rand(1).astype(np.float32))
+    model.set_initializer("p5", *np.random.rand(1).astype(np.float32))
+
+    # Transform
+    new_model = model.transform(transf_fxn)
+
+    # Test
+    inp_dict = {"top_in": np.random.rand(*input_shape).astype(np.float32)}
+    assert ox.compare_execution(model, new_model, inp_dict)
+    assert len(new_model.graph.node) == 5