diff --git a/src/finn/transformation/streamline/absorb.py b/src/finn/transformation/streamline/absorb.py
index dc01eea411fc1f640e481c9be02a92acdd59533f..f089275c221f769daace3e9628a00bf87b4e5457 100644
--- a/src/finn/transformation/streamline/absorb.py
+++ b/src/finn/transformation/streamline/absorb.py
@@ -31,6 +31,7 @@ from onnx import helper as oh
 import warnings
 
 from finn.core.datatype import DataType
+import finn.core.data_layout as DataLayout
 from finn.transformation import Transformation
 from finn.util.basic import get_by_name
 from finn.custom_op.registry import getCustomOp
@@ -357,7 +358,79 @@ class AbsorbTransposeIntoMultiThreshold(Transformation):
             model = model.transform(InferDataTypes())
         return (model, graph_modified)
 
+
+class AbsorbTransposeIntoFlatten(Transformation):
+    """Absorb a Transpose node into a succeeding Flatten node if H=W=1 and the
+    first dimension stays the same. Also applies when flatten is implemented
+    implicitly by a Reshape node with shape [1, -1] and the first dimension is 1."""
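+    # illustrative example (hypothetical shapes): for an NHWC input of shape
+    # (1, 1, 1, 4), Transpose with perm=[0, 3, 1, 2] followed by Reshape([1, -1])
+    # produces the same (1, 4) result as a single Flatten on the original input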
 
+    def apply(self, model):
+        graph = model.graph
+        graph_modified = False
+        node_ind = 0
+        for n in graph.node:
+            node_ind += 1
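+            # a Reshape with constant target shape [1, -1] acts as an implicit Flatten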
+            if (
+                n.op_type == "Reshape"
+                and model.get_initializer(n.input[1]) is not None
+                and (model.get_initializer(n.input[1]) == [1, -1]).all()
+            ) or n.op_type == "Flatten":
+                prod = model.find_producer(n.input[0])
+                if (
+                    prod is not None
+                    and prod.op_type == "Transpose"
+                    # we ensure that the first dimension is not changed from the
+                    # transpose operation
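+                    # e.g. perm = [0, 2, 3, 1] keeps the batch dimension in
+                    # place, while perm = [3, 2, 0, 1] moves it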
+                    and get_by_name(prod.attribute, "perm").ints[0] == 0
+                ):
+                    data_layout = model.get_tensor_layout(prod.input[0])
+                    # check for the data layout to interpret input shape correctly
+                    if data_layout is None:
+                        warnings.warn(
+                            "Data layout for the input tensor of the Transpose "
+                            "node is not set. To use the AbsorbTransposeIntoFlatten "
+                            "transformation, please set the tensor data layout."
+                        )
+                        continue
+                    elif data_layout == DataLayout.NCHW:
+                        (b, c, h, w) = model.get_tensor_shape(prod.input[0])
+                        # if h=w=1 the transposition can be absorbed, otherwise
+                        # the absorption would lead to an error in the behavior
+                        if h != 1 or w != 1:
+                            continue
+                        # the ONNX Flatten node by default keeps the first dim
+                        # and flattens the rest, which is why this transformation
+                        # can only handle b != 1 if the model already contains a
+                        # Flatten node rather than a Reshape node with shape
+                        # [1, -1]. If the first dim of the input tensor is not 1,
+                        # Flatten and Reshape would produce different results
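+                        # e.g. for an input of shape (2, 4, 1, 1): Flatten
+                        # yields (2, 4) but Reshape([1, -1]) yields (1, 8)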
+                        if n.op_type == "Reshape" and b != 1:
+                            continue
+                    elif data_layout == DataLayout.NHWC:
+                        (b, h, w, c) = model.get_tensor_shape(prod.input[0])
+                        if h != 1 or w != 1:
+                            continue
+                        if n.op_type == "Reshape" and b != 1:
+                            continue
+                    # create single flatten node and remove obsolete nodes
+                    node = oh.make_node("Flatten", [prod.input[0]], [n.output[0]])
+                    graph.node.remove(n)
+                    graph.node.remove(prod)
+                    graph.node.insert(node_ind, node)
+                    graph_modified = True
+        if graph_modified:
+            model = model.transform(InferDataTypes())
+        return (model, graph_modified)
+
+
 class AbsorbScalarMulIntoTopK(Transformation):
     """Absorb a mul node into a suceeding topk node if the mul is scalar."""
 
diff --git a/tests/transformation/test_absorb_transp_into_flatten.py b/tests/transformation/test_absorb_transp_into_flatten.py
new file mode 100644
index 0000000000000000000000000000000000000000..fbfa15277717c554da01e38608601997407803b2
--- /dev/null
+++ b/tests/transformation/test_absorb_transp_into_flatten.py
@@ -0,0 +1,95 @@
+import pytest
+
+import numpy as np
+from onnx import TensorProto, helper
+
+from finn.core.modelwrapper import ModelWrapper
+import finn.core.data_layout as DataLayout
+from finn.transformation.infer_shapes import InferShapes
+from finn.transformation.infer_datatypes import InferDataTypes
+from finn.transformation.infer_data_layouts import InferDataLayouts
+from finn.transformation.general import GiveUniqueNodeNames, GiveReadableTensorNames
+from finn.transformation.streamline.absorb import AbsorbTransposeIntoFlatten
+import finn.core.onnx_exec as oxe
+
+# permutation of transpose node
+@pytest.mark.parametrize("perm", [[0, 2, 3, 1], [0, 1, 3, 2], [3, 2, 0, 1]])
+# reshape or flatten
+@pytest.mark.parametrize("shape", [None, [1, -1], [-1, 1]])
+# input shape
+@pytest.mark.parametrize("ishape", [[1, 1, 1, 4], [2, 4, 1, 1], [1, 2, 2, 4]])
+# datalayout
+@pytest.mark.parametrize("data_layout", ["NCHW", "NHWC"])
+def test_absorb_transp_into_flatten(perm, shape, ishape, data_layout):
+    inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, ishape)
+    transp_node = helper.make_node("Transpose", ["inp"], ["transp_out"], perm=perm)
+    dummy_in = np.random.uniform(low=0, high=1, size=tuple(ishape)).astype(np.float32)
+    if shape is None:
+        shape_node = helper.make_node("Flatten", ["transp_out"], ["outp"])
+        dummy_in = dummy_in.transpose(tuple(perm))
+        oshape = dummy_in.reshape(dummy_in.shape[0], -1).shape
+        outp = helper.make_tensor_value_info("outp", TensorProto.FLOAT, oshape)
+        shape0 = None
+    else:
+        shape0 = helper.make_tensor_value_info("shape0", TensorProto.INT64, shape)
+        shape_node = helper.make_node("Reshape", ["transp_out", "shape0"], ["outp"])
+        oshape = dummy_in.transpose(tuple(perm)).reshape(tuple(shape)).shape
+        outp = helper.make_tensor_value_info("outp", TensorProto.FLOAT, oshape)
+
+    graph = helper.make_graph(
+        nodes=[transp_node, shape_node],
+        name="absorb-transpose-graph",
+        inputs=[inp],
+        outputs=[outp],
+    )
+
+    model = helper.make_model(graph, producer_name="absorb_transpose_model")
+    model = ModelWrapper(model)
+    if shape is not None:
+        model.graph.value_info.append(shape0)
+        model.set_initializer("shape0", np.asarray(shape, dtype=np.int64))
+    if data_layout == "NCHW":
+        model.set_tensor_layout("inp", DataLayout.NCHW)
+    else:
+        model.set_tensor_layout("inp", DataLayout.NHWC)
+    model = model.transform(InferShapes())
+    model = model.transform(InferDataTypes())
+    model = model.transform(InferDataLayouts())
+    model = model.transform(GiveUniqueNodeNames())
+    model = model.transform(GiveReadableTensorNames())
+    model_transformed = model.transform(AbsorbTransposeIntoFlatten())
+
+    # verify transformation
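+    # (oxe.compare_execution executes both models on the same input dict
+    # and checks that their outputs match)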
+    inp_values = np.random.uniform(low=-1, high=1, size=tuple(ishape)).astype(
+        np.float32
+    )
+    idict = {model.graph.input[0].name: inp_values}
+    assert oxe.compare_execution(model, model_transformed, idict)
+
+    # only some of the parameter combinations lead to a graph that will be changed when
+    # AbsorbTransposeIntoFlatten is applied
+
+    if shape == [-1, 1]:  # not a flatten operation, so the graph will not be changed
+        assert model.graph == model_transformed.graph
+
+    elif perm == [3, 2, 0, 1]:
+        # the first dimension is moved by the transpose operation,
+        # so the graph will not be changed
+        assert model.graph == model_transformed.graph
+
+    # the following cases are the ones in which the model is transformed;
+    # since shape and perm were already tested above, we only need to consider
+    # ishape and data_layout (the transformed model should contain only a "Flatten" node)
+    elif ishape == [1, 1, 1, 4] and data_layout == "NHWC":
+        assert model_transformed.graph.node[0].op_type == "Flatten"
+
+    elif ishape == [2, 4, 1, 1] and data_layout == "NCHW" and shape is None:
+        # since the first dimension of the input tensor is not 1, only a true
+        # Flatten node can be absorbed; Reshape([1, -1]) would give different results
+        assert model_transformed.graph.node[0].op_type == "Flatten"
+
+    # all other cases lead to an unchanged model
+    else:
+        assert model.graph == model_transformed.graph