diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py
index b23f9f14909a5bd93ae24b34ef65304dafc7e0c1..7163a95c4dbbe5c8bcee4ebeea87c5e9611c179e 100644
--- a/src/finn/transformation/streamline/reorder.py
+++ b/src/finn/transformation/streamline/reorder.py
@@ -40,6 +40,7 @@ from finn.core.datatype import DataType
 from finn.core.onnx_exec import execute_node
 from finn.util.basic import get_by_name
 from finn.custom_op.registry import getCustomOp
+from finn.transformation.general import SortGraph
 
 
 class MoveAddPastMul(Transformation):
@@ -1039,3 +1040,77 @@ class MoveTransposePastScalarMul(Transformation):
             model = model.transform(InferDataLayouts())
             model = model.transform(InferShapes())
         return (model, graph_modified)
+
+
+class MoveIdenticalOpPastJoinOp(Transformation):
+    """
+    Move identical operations on different branches past the common join node.
+    This transformation assumes that the identical operations only change the
+    data layout. For linear operations, see the transformation MoveLinearPastEltwiseAdd.
+    Specifically, this transformation matches and transforms the following pattern:
+    f(x) + f(y) -> f(x + y)
+    where f(.) currently only supports 'Transpose' and the join node is an
+    'Add' node.
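+
+    In graph terms: two parallel 'Transpose' nodes feeding an 'Add' node are
+    replaced by a single 'Transpose' applied to the output of the 'Add'.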
+    """
+
+    def __init__(self, identical_op_list, join_node_list):
+        super().__init__()
+        self.ops_to_move = identical_op_list
+        self.join_node_op = join_node_list
+
+    def move_node(self, model, n, prod0, prod1):
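+        """
+        Rewires the graph around join node n: n consumes the inputs of the two
+        identical producer nodes prod0 and prod1 directly, prod0 is moved
+        behind n, and the now redundant prod1 is removed.
+        """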
+        # Pattern found: move one of the identical ops past the join node
+        # and remove the other one
+        identical_op0_in0 = prod0.input[0]
+        identical_op1_in0 = prod1.input[0]
+        add_in0 = n.input[0]
+        add_out = n.output[0]
+
+        # Rewire: the join node now consumes the inputs of the two identical
+        # producer nodes directly
+        n.input[0] = identical_op0_in0
+        n.input[1] = identical_op1_in0
+
+        # Output tensor of the join node must now have the same shape as its
+        # rewired inputs, i.e. the shape before the moved op was applied
+        new_shape = model.get_tensor_shape(identical_op0_in0)
+
+        # Set new tensor shape
+        model.set_tensor_shape(tensor_name=add_in0, tensor_shape=new_shape)
+
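+        # Move the remaining op (prod0) behind the join node: the join node
+        # writes into the former intermediate tensor and prod0 now produces
+        # the original output tensor, so downstream consumers are unaffected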
+        n.output[0] = add_in0
+        prod0.input[0] = add_in0
+        prod0.output[0] = add_out
+
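+        # The second op is now redundant, since its output is no longer consumed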
+        model.graph.node.remove(prod1)
+
+    def apply(self, model):
+        graph = model.graph
+        graph_modified = False
+        for n in graph.node:
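+            # Only consider join nodes whose op type is in the given list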
+            if n.op_type in self.join_node_op and model.is_join_node(n):
+                in0 = n.input[0]
+                in1 = n.input[1]
+                if in0 is None or in1 is None:
+                    continue
+
+                prod0 = model.find_producer(in0)
+                prod1 = model.find_producer(in1)
+                # The two producers must be distinct nodes (and must exist);
+                # a single fork node feeding both inputs cannot be moved
+                if prod0 is None or prod1 is None or prod0 == prod1:
+                    continue
+
+                # Both producers must perform the identical operation
+                identical_op = prod0.op_type == prod1.op_type
+
+                if identical_op and prod0.op_type in self.ops_to_move:
+                    self.move_node(model, n, prod0, prod1)
+                    graph_modified = True
+
+        if graph_modified:
+            model = model.transform(SortGraph(), make_deepcopy=False, cleanup=False)
+
+        return (model, graph_modified)
+
+
+class MoveTransposePastJoinAdd(MoveIdenticalOpPastJoinOp):
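+    """Moves two identical 'Transpose' nodes on the branches of a joining
+    'Add' node past the 'Add' node."""
+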
+    def __init__(self):
+        super().__init__(["Transpose"], ["Add"])
diff --git a/tests/transformation/streamline/test_move_identical_op_past_join_op.py b/tests/transformation/streamline/test_move_identical_op_past_join_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..94eb52835b1800a839e5a9792e9cf1d7be1e681d
--- /dev/null
+++ b/tests/transformation/streamline/test_move_identical_op_past_join_op.py
@@ -0,0 +1,94 @@
+import pytest
+
+from onnx import helper as oh
+from onnx import TensorProto
+
+from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.streamline.reorder import MoveTransposePastJoinAdd
+from finn.util.basic import gen_finn_dt_tensor
+import finn.core.onnx_exec as oxe
+
+
+def create_model(perm):
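+    """Creates a test graph with two parallel Transpose nodes (using the given
+    permutation) that feed a joining Add node."""
+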
+    if perm == [0, 3, 1, 2]:
+        in_shape = [1, 128, 1, 256]
+        out_shape = [1, 256, 128, 1]
+    elif perm == [0, 2, 3, 1]:
+        in_shape = [1, 256, 128, 1]
+        out_shape = [1, 128, 1, 256]
+
+    Transpose1_node = oh.make_node(
+        "Transpose", inputs=["in_transpose1"], outputs=["out_transpose1"], perm=perm
+    )
+
+    Transpose2_node = oh.make_node(
+        "Transpose", inputs=["in_transpose2"], outputs=["out_transpose2"], perm=perm
+    )
+
+    Join1_node = oh.make_node(
+        "Add", inputs=["out_transpose1", "out_transpose2"], outputs=["out_join1"]
+    )
+
+    in_transpose1 = oh.make_tensor_value_info(
+        "in_transpose1", TensorProto.FLOAT, in_shape
+    )
+    in_transpose2 = oh.make_tensor_value_info(
+        "in_transpose2", TensorProto.FLOAT, in_shape
+    )
+    out_transpose1 = oh.make_tensor_value_info(
+        "out_transpose1", TensorProto.FLOAT, out_shape
+    )
+    out_transpose2 = oh.make_tensor_value_info(
+        "out_transpose2", TensorProto.FLOAT, out_shape
+    )
+    out_join1 = oh.make_tensor_value_info("out_join1", TensorProto.FLOAT, out_shape)
+
+    graph = oh.make_graph(
+        nodes=[Transpose1_node, Transpose2_node, Join1_node],
+        name="test_graph",
+        inputs=[in_transpose1, in_transpose2],
+        outputs=[out_join1],
+        value_info=[
+            out_transpose1,
+            out_transpose2,
+        ],
+    )
+
+    onnx_model = oh.make_model(graph, producer_name="test_model")
+    model = ModelWrapper(onnx_model)
+
+    return model
+
+
+# Permutation of transpose node
+@pytest.mark.parametrize("perm", [[0, 3, 1, 2], [0, 2, 3, 1]])
+def test_move_identical_op_past_join_op(perm):
+    model = create_model(perm)
+
+    # Create input data
+    input0_tensor_name = model.graph.input[0].name
+    input1_tensor_name = model.graph.input[1].name
+
+    # Note: it is assumed that both tensors have the same shape and data type
+    input_shape = model.get_tensor_shape(input0_tensor_name)
+    input_dtype = model.get_tensor_datatype(input0_tensor_name)
+    input_val = gen_finn_dt_tensor(input_dtype, input_shape)
+    input_dict = {}
+    input_dict[input0_tensor_name] = input_val
+    input_dict[input1_tensor_name] = input_val
+
+    model_transformed = model.transform(MoveTransposePastJoinAdd())
+
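+    # The transformed model must be functionally equivalent to the original one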
+    assert oxe.compare_execution(model, model_transformed, input_dict)
+
+    # Check that the node order changed: each graph input should now feed a
+    # node of a different op type than before the transformation
+    node0_input0_model = model.find_consumers(model.graph.input[0].name)[0].op_type
+    node1_input1_model = model.find_consumers(model.graph.input[1].name)[0].op_type
+    node0_input0_model_transformed = model_transformed.find_consumers(
+        model_transformed.graph.input[0].name
+    )[0].op_type
+    node1_input1_model_transformed = model_transformed.find_consumers(
+        model_transformed.graph.input[1].name
+    )[0].op_type
+    assert node0_input0_model != node0_input0_model_transformed
+    assert node1_input1_model != node1_input1_model_transformed