diff --git a/src/finn/transformation/streamline/absorb.py b/src/finn/transformation/streamline/absorb.py
index 0299c4f4d89d1fdd94434db77c77a0e529c86d26..a983e67750a0a860eeeb4b429f7d6b181fc84fe3 100644
--- a/src/finn/transformation/streamline/absorb.py
+++ b/src/finn/transformation/streamline/absorb.py
@@ -473,7 +473,7 @@ class AbsorbConsecutiveTransposes(Transformation):
     """Remove (Transpose -> Transpose) patterns when the input and output
     of the pattern have the same layout."""
 
-    def Are_opposite_permutations(self, perms1, perms2):
+    def are_opposite_permutations(self, perms1, perms2):
         if len(perms1) != len(perms2):
             return False
         assert 0 <= max(perms2) < len(perms2), "invalid permutation"
@@ -488,72 +488,40 @@ class AbsorbConsecutiveTransposes(Transformation):
     def apply(self, model):
         graph = model.graph
         graph_modified = False
-        for n in graph.node:
-            if n.op_type == "Transpose":
-                if model.is_fork_node(n):
-                    next_nodes = model.find_direct_successors(n)
-                    perms1 = list(get_by_name(n.attribute, "perm").ints)
-
-                    # check if all nodes after fork are opposite transposes
-                    all_opposite_transposes = True
-                    for next_node in next_nodes:
-                        if next_node is not None and next_node.op_type == "Transpose":
-                            perms2 = list(get_by_name(next_node.attribute, "perm").ints)
-                            if not self.Are_opposite_permutations(perms1, perms2):
-                                all_opposite_transposes = False
-                                break
-                        else:
-                            all_opposite_transposes = False
-                            break
-
-                    if not all_opposite_transposes:
-                        continue
-
-                    prod = model.find_producer(n.input[0])
-                    for next_node in next_nodes:
-                        # connect next_node's consumer input to n's producer output
-                        # TODO implement this to allow for forks as producers and
-                        # joins as consumers
-                        cons = model.find_consumer(next_node.output[0])
-                        cons.input[0] = prod.output[0]
-
-                        # remove consumer transpose
-                        graph.node.remove(next_node)
-
-                    # remove producer transpose
-                    graph.node.remove(n)
-                    graph_modified = True
-
-                else:
-                    next_node = model.find_consumer(n.output[0])
+        for node in graph.node:
+            if node.op_type == "Transpose":
+                next_nodes = model.find_consumers(node.output[0])
+                # nothing to absorb if the Transpose has no consumers
+                # (e.g. its output is a top-level graph output)
+                if not next_nodes:
+                    continue
+                perms1 = list(get_by_name(node.attribute, "perm").ints)
+                # check if all consumers are opposite transposes
+                all_opposite_transposes = True
+                for next_node in next_nodes:
                     if next_node is not None and next_node.op_type == "Transpose":
-                        perms1 = list(get_by_name(n.attribute, "perm").ints)
                         perms2 = list(get_by_name(next_node.attribute, "perm").ints)
-                        if self.Are_opposite_permutations(perms1, perms2):
-
-                            # connect next_node's consumer input to n's producer output
-                            # TODO implement this to allow for forks as producers
-                            consumers = model.find_direct_successors(next_node)
-                            prod = model.find_producer(n.input[0])
-                            if prod is not None:
-                                for cons in consumers:
-                                    for cons_in in cons.input:
-                                        if cons_in == next_node.output[0]:
-                                            prod.output[0] = cons_in
-                                            break
-                            else:
-                                # n.input[0] is top-level graph input
-                                # wire consumers directly to that
-                                for cons in consumers:
-                                    for i, iname in enumerate(cons.input):
-                                        if iname == next_node.output[0]:
-                                            cons.input[i] = n.input[0]
-
-                            # remove both transposes
-                            graph.node.remove(n)
-                            graph.node.remove(next_node)
+                        if not self.are_opposite_permutations(perms1, perms2):
+                            all_opposite_transposes = False
+                            break
+                    else:
+                        all_opposite_transposes = False
+                        break
+                if not all_opposite_transposes:
+                    continue
+                source_tensor = node.input[0]
+                for next_node in next_nodes:
+                    # rewire each consumer of next_node to read from node's input
+                    # TODO how to handle top-level outputs if any?
+                    nextnode_out = next_node.output[0]
+                    assert nextnode_out not in [
+                        x.name for x in model.graph.output
+                    ], "not yet supported: Transpose output is a top-level output"
+                    consumers = model.find_consumers(nextnode_out)
+                    for cons in consumers:
+                        for i, iname in enumerate(cons.input):
+                            if iname == nextnode_out:
+                                cons.input[i] = source_tensor
+                    # remove consumer transpose
+                    graph.node.remove(next_node)
+                # remove producer transpose
+                graph.node.remove(node)
+                graph_modified = True
 
-                            graph_modified = True
         if graph_modified:
             model = model.transform(InferDataTypes())
         return (model, graph_modified)
diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py
index 9ff8a2173ce81e2a19c56bbd20a326759c3b9df2..3e815c1537353cc2be970a2068d4ded30cc48bc8 100644
--- a/src/finn/transformation/streamline/reorder.py
+++ b/src/finn/transformation/streamline/reorder.py
@@ -553,6 +553,8 @@ class MoveLinearPastEltwiseAdd(Transformation):
                 # Other transform should handle that
                 if prod0 is None or prod1 is None or (prod0 == prod1):
                     continue
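+                # producers without a second (parameter) input cannot match
+                # the linear-op pattern assumed here (params in input 1)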
+                if len(prod0.input) < 2 or len(prod1.input) < 2:
+                    continue
                 init0 = model.get_initializer(prod0.input[1])
                 init1 = model.get_initializer(prod1.input[1])
                 # if either initializer is None, skip
@@ -728,9 +730,10 @@ class MoveOpPastFork(Transformation):
     can be merged with nodes in the branches
     """
 
-    def __init__(self, op_name_list):
+    def __init__(self, op_name_list, get_attrs_fxn=lambda x: {}):
         super().__init__()
         self.ops_to_move = op_name_list
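+        # get_attrs_fxn maps a node to the attribute kwargs needed to
+        # re-create it on each branch (e.g. perm for Transpose nodes)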
+        self.get_attrs_fxn = get_attrs_fxn
 
     def apply(self, model):
         graph = model.graph
@@ -747,9 +750,10 @@ class MoveOpPastFork(Transformation):
 
                 # Restrict this transform to operations with constant parameters
                 # Assuming parameters is in input 1
-                op_init_param = model.get_initializer(n.input[1])
-                if op_init_param is None:
-                    continue
+                if len(n.input) > 1:
+                    op_init_param = model.get_initializer(n.input[1])
+                else:
+                    op_init_param = None
 
                 # Check case when branches are empty and go
                 # to the same node
@@ -766,16 +770,20 @@ class MoveOpPastFork(Transformation):
 
                 for consumer_node in consumers[1:]:
                     # create new node
-                    new_param_name = model.make_new_valueinfo_name()
                     new_output_tensor_name = model.make_new_valueinfo_name()
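+                    # replicate the op for this consumer; parameterless ops
+                    # (e.g. Transpose) carry only the data input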
+                    if op_init_param is None:
+                        new_inp_list = [n.input[0]]
+                    else:
+                        new_param_name = model.make_new_valueinfo_name()
+                        new_inp_list = [n.input[0], new_param_name]
+                        model.set_initializer(new_param_name, op_init_param)
+                    attrs = self.get_attrs_fxn(n)
+                    # TODO use copy of original node instead to get attrs?
                     new_node = oh.make_node(
-                        n.op_type,
-                        [n.input[0], new_param_name],
-                        [new_output_tensor_name],
+                        n.op_type, new_inp_list, [new_output_tensor_name], **attrs
                     )
                     graph.node.insert(node_ind, new_node)
                     node_ind += 1
-                    model.set_initializer(new_param_name, op_init_param)
 
                     # change consumer input tensor
                     graph.node.remove(consumer_node)
@@ -811,6 +819,13 @@ class MoveLinearPastFork(MoveOpPastFork):
         super().__init__(["Add", "Mul"])
 
 
+class MoveTransposePastFork(MoveOpPastFork):
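+    """Move a Transpose node past a fork by replicating it on each branch."""
+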
+    def __init__(self):
+        super().__init__(
+            ["Transpose"], lambda x: {"perm": get_by_name(x.attribute, "perm").ints}
+        )
+
+
 class MoveMaxPoolPastMultiThreshold(Transformation):
     """Move MaxPool nodes past MultiThreshold nodes on linear segments of the graph."""
 
diff --git a/tests/transformation/streamline/test_absorb_opposite_transposes.py b/tests/transformation/streamline/test_absorb_opposite_transposes.py
index 51ea5edfc420bf935de3e196df1b150934782a91..88cbd5657e2ae6c0946e59186c25d935595ad2ff 100644
--- a/tests/transformation/streamline/test_absorb_opposite_transposes.py
+++ b/tests/transformation/streamline/test_absorb_opposite_transposes.py
@@ -29,8 +29,7 @@
 import pytest
 
 import numpy as np
-import onnx.helper as oh
-from onnx import TensorProto
+import onnx.parser as oprs
 from qonnx.core.modelwrapper import ModelWrapper
 from qonnx.transformation.infer_shapes import InferShapes
 
@@ -41,39 +40,44 @@ from finn.transformation.streamline.absorb import AbsorbConsecutiveTransposes
 @pytest.mark.streamline
 def test_absorb_opposite_transposes():
     np.random.seed(0)
-    input_shape = [1, 3, 4, 2]
-    top_in = oh.make_tensor_value_info("top_in", TensorProto.FLOAT, input_shape)
-    top_out = oh.make_tensor_value_info("top_out", TensorProto.FLOAT, input_shape)
-    value_info = [oh.make_tensor_value_info("add_param_0", TensorProto.FLOAT, [1])]
-    value_info += [oh.make_tensor_value_info("add_param_1", TensorProto.FLOAT, [1])]
-    value_info += [oh.make_tensor_value_info("mul_param_0", TensorProto.FLOAT, [1])]
-    modelproto = oh.make_model(
-        oh.make_graph(
-            name="test",
-            inputs=[top_in],
-            outputs=[top_out],
-            value_info=value_info,
-            nodes=[
-                oh.make_node("Add", ["top_in", "add_param_0"], ["t0"]),
-                oh.make_node("Transpose", ["t0"], ["t1"], perm=[0, 2, 3, 1]),
-                oh.make_node("Transpose", ["t1"], ["t2"], perm=[0, 3, 1, 2]),
-                oh.make_node("Add", ["t2", "add_param_1"], ["t3"]),
-                oh.make_node("Transpose", ["t3"], ["t4"], perm=[0, 2, 3, 1]),
-                oh.make_node("Transpose", ["t4"], ["t5"], perm=[0, 3, 1, 2]),
-                oh.make_node("Add", ["t5", "t2"], ["t6"]),
-                oh.make_node("Mul", ["t6", "mul_param_0"], ["top_out"]),
-            ],
-        )
-    )
-    model = ModelWrapper(modelproto)
+    shp = [1, 3, 4, 2]
+    shp_str = str(shp)
+    input = f"""
+    <
+        ir_version: 7,
+        opset_import: ["" : 9]
+    >
+    agraph (float{shp_str} in0) => (float{shp_str} out0)
+    <
+        float[1] add0_param = {{1.0}},
+        float[1] add1_param = {{3.0}},
+        float[1] mul0_param = {{2.0}}
+    >
+    {{
+        add0_out = Add(in0, add0_param)
+        t0_out = Transpose<perm=[0,2,3,1]>(add0_out)
+        t1_out = Transpose<perm=[0,3,1,2]>(t0_out)
+        add1_out = Add(t1_out, add1_param)
+        t2_out = Transpose<perm=[0,2,3,1]>(add1_out)
+        t3_out = Transpose<perm=[0,3,1,2]>(t2_out)
+        add2_out = Add(t1_out, t3_out)
+        t4_out = Transpose<perm=[0,2,3,1]>(add2_out)
+        t5_out = Transpose<perm=[0,3,1,2]>(t4_out)
+        t6_out = Transpose<perm=[0,3,1,2]>(t4_out)
+        m0_out = Mul(t5_out, mul0_param)
+        m1_out = Mul(t6_out, mul0_param)
+        out0 = Mul(m0_out, m1_out)
+    }}
+    """
+    model = oprs.parse_model(input)
+    model = ModelWrapper(model)
     model = model.transform(InferShapes())
-    model.set_initializer("add_param_0", np.asarray([1], dtype=np.float32))
-    model.set_initializer("add_param_1", np.asarray([3], dtype=np.float32))
-    model.set_initializer("mul_param_0", np.asarray([2], dtype=np.float32))
+    model.save("dbg.onnx")
     new_model = model.transform(AbsorbConsecutiveTransposes())
     new_model = new_model.transform(InferShapes())
-    inp_dict = {"top_in": np.random.rand(*input_shape).astype(np.float32)}
+    inp_dict = {"top_in": np.random.rand(*shp).astype(np.float32)}
     assert ox.compare_execution(model, model, inp_dict)
-    assert len(new_model.graph.node) == 4
+    assert len(new_model.graph.node) == 6
     for n in new_model.graph.node:
-        assert new_model.graph.node[0].op_type != "Transpose"
+        assert n.op_type != "Transpose"
diff --git a/tests/transformation/streamline/test_move_past_fork.py b/tests/transformation/streamline/test_move_past_fork.py
index 5064fa3fca869a245c87cf0c1680d1357e5de60b..7e77d7f9b3502429f08c40558e330b6261d0dbad 100644
--- a/tests/transformation/streamline/test_move_past_fork.py
+++ b/tests/transformation/streamline/test_move_past_fork.py
@@ -28,80 +28,113 @@
 import pytest
 
 import numpy as np
-from onnx import TensorProto, helper
+import onnx.parser as oprs
 from qonnx.core.modelwrapper import ModelWrapper
+from qonnx.transformation.general import GiveUniqueNodeNames
 from qonnx.transformation.infer_shapes import InferShapes
+from qonnx.util.basic import get_by_name
 
 import finn.core.onnx_exec as oxe
-from finn.transformation.streamline.reorder import MoveLinearPastFork
+from finn.transformation.streamline.reorder import (
+    MoveLinearPastFork,
+    MoveTransposePastFork,
+)
+
+
+@pytest.mark.streamline
+def test_move_past_fork_transpose():
+    shp = [1, 3, 32, 32]
+    shp_str = str(shp)
+    input = f"""
+    <
+        ir_version: 7,
+        opset_import: ["" : 9]
+    >
+    agraph (float{shp_str} in0) => (float{shp_str} out0)
+    {{
+        t0_out = Transpose<perm=[0,2,3,1]>(in0)
+        t1_out = Transpose<perm=[0,3,1,2]>(t0_out)
+        t2_out = Transpose<perm=[0,3,1,2]>(t0_out)
+        out0 = Add(t1_out, t2_out)
+    }}
+    """
+    model = oprs.parse_model(input)
+    model = ModelWrapper(model)
+    model = model.transform(InferShapes())
+    new_model = model.transform(MoveTransposePastFork())
+    new_model = new_model.transform(GiveUniqueNodeNames())
+    nodes = new_model.graph.node
+    assert oxe.compare_execution(
+        model, new_model, {"in0": np.random.rand(*shp).astype(np.float32)}
+    )
+    assert len(nodes) == 5
+    assert not new_model.is_fork_node(get_by_name(nodes, "Transpose_0"))
 
 
 @pytest.mark.streamline
 @pytest.mark.parametrize("ch", [64, 1])
 # ifmdim
 @pytest.mark.parametrize("ifmdim", [-1, 7])
-def test_move_past_fork(ch, ifmdim):
-    # generate test vectors of correct shape
+def test_move_past_fork_linear(ch, ifmdim):
     if ifmdim == -1:
-        input_shape = (1, ch)
+        shp = [1, ch]
     else:
-        input_shape = (1, ch, ifmdim, ifmdim)
+        shp = [1, ch, ifmdim, ifmdim]
+    shp_str = str(shp)
+    input = f"""
+    <
+        ir_version: 7,
+        opset_import: ["" : 9]
+    >
+    agraph (float{shp_str} in0) => (float{shp_str} out0)
+    <
+        float{shp_str} add0_param,
+        float{shp_str} mul_shared_param,
+        float{shp_str} add2_param,
+        float{shp_str} mul2_param,
+        float{shp_str} add3_param,
+        float{shp_str} add4_param,
+        float{shp_str} mul3_param,
+        float{shp_str} add6_param
+    >
+    {{
 
-    top_in = helper.make_tensor_value_info("top_in", TensorProto.FLOAT, input_shape)
-    top_out = helper.make_tensor_value_info("top_out", TensorProto.FLOAT, input_shape)
-
-    num_of_params = 8
-    value_info = []
-    for i in range(num_of_params):
-        value_info += [
-            helper.make_tensor_value_info("p" + str(i), TensorProto.FLOAT, input_shape)
-        ]
-
-    add_1_to_move = helper.make_node("Add", ["top_in", "p0"], ["fork1"])
-    mul_1_to_move = helper.make_node("Mul", ["t5", "p4"], ["fork2"])
-    add_2_to_move = helper.make_node("Add", ["fork2", "p5"], ["t6"])
-    mul_1_not_to_move = helper.make_node("Mul", ["t8", "p7"], ["fork3"])
-    modelproto = helper.make_model(
-        helper.make_graph(
-            name="test",
-            inputs=[top_in],
-            outputs=[top_out],
-            value_info=value_info,
-            nodes=[
-                # fork1
-                add_1_to_move,
-                helper.make_node("Mul", ["fork1", "p1"], ["t2"]),
-                helper.make_node("Mul", ["fork1", "p2"], ["t3"]),
-                helper.make_node("Add", ["t2", "t3"], ["t4"]),
-                helper.make_node("Add", ["t4", "p3"], ["t5"]),
-                # fork2
-                mul_1_to_move,
-                add_2_to_move,
-                helper.make_node("Add", ["fork2", "p6"], ["t7"]),
-                helper.make_node("Add", ["t6", "t7"], ["t8"]),
-                # empty branches: do nothing
-                mul_1_not_to_move,
-                helper.make_node("Add", ["fork3", "fork3"], ["top_out"]),
-            ],
-        )
-    )
-    model = ModelWrapper(modelproto)
+        add0_out = Add(in0, add0_param)
+        mul0_out = Mul(add0_out, mul_shared_param)
+        mul1_out = Mul(add0_out, mul_shared_param)
+        add1_out = Add(mul0_out, mul1_out)
+        add2_out = Add(add1_out, add2_param)
+        mul2_out = Mul(add2_out, mul2_param)
+        add3_out = Add(mul2_out, add3_param)
+        add4_out = Add(mul2_out, add4_param)
+        add5_out = Add(add3_out, add4_out)
+        mul3_out = Mul(add5_out, mul3_param)
+        out0 = Add(mul3_out, add6_param)
+    }}
+    """
+    model = oprs.parse_model(input)
+    model = ModelWrapper(model)
     model = model.transform(InferShapes())
 
     np.random.seed(0)
-    for i in range(num_of_params):
-        model.set_initializer(
-            "p" + str(i), np.random.rand(*input_shape).astype(np.float32)
-        )
-
+    for tensor_name in model.get_all_tensor_names():
+        if tensor_name.endswith("_param"):
+            pshape = model.get_tensor_shape(tensor_name)
+            model.set_initializer(
+                tensor_name, np.random.rand(*pshape).astype(np.float32)
+            )
+    model = model.transform(GiveUniqueNodeNames())
     # Transform
     new_model = model.transform(MoveLinearPastFork())
-    inp_dict = {"top_in": np.random.rand(*input_shape).astype(np.float32)}
-
+    new_model = new_model.transform(GiveUniqueNodeNames())
+    inp_dict = {"top_in": np.random.rand(*shp).astype(np.float32)}
     # Test
     assert oxe.compare_execution(model, new_model, inp_dict)
-    assert not new_model.is_fork_node(add_1_to_move)
-    assert not new_model.is_fork_node(mul_1_to_move)
-    assert not new_model.is_fork_node(add_2_to_move)
-    assert new_model.is_fork_node(mul_1_not_to_move)
+    nodes = new_model.graph.node
+    assert len(new_model.get_nodes_by_op_type("Add")) == 9
+    assert len(new_model.get_nodes_by_op_type("Mul")) == 5
+    assert not new_model.is_fork_node(get_by_name(nodes, "Add_0"))
+    assert new_model.is_join_node(get_by_name(nodes, "Add_2"))
+    assert not new_model.is_fork_node(get_by_name(nodes, "Mul_2"))
+    assert not new_model.is_join_node(get_by_name(nodes, "Add_5"))
     assert len(new_model.graph.node) == 14