diff --git a/src/finn/transformation/streamline/absorb.py b/src/finn/transformation/streamline/absorb.py
index dbcf97361017144174f9fbfca35a84361b5abd26..f426a2f631e3fd65bb3bb6a7efc8662e0375a615 100644
--- a/src/finn/transformation/streamline/absorb.py
+++ b/src/finn/transformation/streamline/absorb.py
@@ -28,8 +28,10 @@
 
 import numpy as np
 from onnx import helper as oh
+import warnings
 
 from finn.core.datatype import DataType
+import finn.core.data_layout as DataLayout
 from finn.transformation import Transformation
 from finn.util.basic import get_by_name
 from finn.custom_op.registry import getCustomOp
@@ -290,3 +292,65 @@ class AbsorbTransposeIntoMultiThreshold(Transformation):
         if graph_modified:
             model = model.transform(InferDataTypes())
         return (model, graph_modified)
+
+
+class AbsorbTransposeIntoFlatten(Transformation):
+    """Absorb transpose node into succeeding flatten node, if H=W=1 and the first
+    dimension stays the same. Can also be applied if flatten is implemented implicitly
+    by a reshape node with shape [1, -1] and the first input dimension is 1"""
+
+    def apply(self, model):
+        graph = model.graph
+        graph_modified = False
+        node_ind = 0
+        for n in graph.node:
+            node_ind += 1
+            if (
+                n.op_type == "Reshape"
+                and (model.get_initializer(n.input[1]) == [1, -1]).all()
+            ) or n.op_type == "Flatten":
+                prod = model.find_producer(n.input[0])
+                if (
+                    prod is not None
+                    and prod.op_type == "Transpose"
+                    # we ensure that the first dimension is not changed from the
+                    # transpose operation
+                    and get_by_name(prod.attribute, "perm").ints[0] == 0
+                ):
+                    data_layout = model.get_tensor_layout(prod.input[0])
+                    # check for the data layout to interpret input shape correctly
+                    if data_layout is None:
+                        warnings.warn(
+                            """Datalayout for input tensor of transpose node is not set.
+                                To use AbsorbTransposeIntoFlatten transformation
+                                please set tensor data layout."""
+                        )
+                        continue
+                    elif data_layout == DataLayout.NCHW:
+                        (b, c, h, w) = model.get_tensor_shape(prod.input[0])
+                        # if h=w=1 the transposition can be absorbed, otherwise
+                        # the absorption would lead to an error in the behavior
+                        if h != 1 or w != 1:
+                            continue
+                        # the flatten node from onnx keeps by default the first
+                        # dim and flattens the rest, that is why this transformation
+                        # can only works with b != 1 if the model contains already a
+                        # flatten node and not a reshape node with shape = [1, -1].
+                        # If the first  dim of the input tensor is not 1, flatten and
+                        # reshape (with shape = [1, -1]) would lead to different results
+                        if n.op_type == "Reshape" and b != 1:
+                            continue
+                    elif data_layout == DataLayout.NHWC:
+                        (b, h, w, c) = model.get_tensor_shape(prod.input[0])
+                        if h != 1 or w != 1:
+                            continue
+                        if n.op_type == "Reshape" and b != 1:
+                            continue
+                    # create single flatten node and remove obsolete nodes
+                    node = oh.make_node("Flatten", [prod.input[0]], [n.output[0]])
+                    graph.node.remove(n)
+                    graph.node.remove(prod)
+                    graph.node.insert(node_ind, node)
+                    graph_modified = True
+        model = model.transform(InferDataTypes())
+        return (model, graph_modified)