diff --git a/tests/transformation/test_absorb_transp_into_flatten.py b/tests/transformation/test_absorb_transp_into_flatten.py
new file mode 100644
index 0000000000000000000000000000000000000000..609f28139f30a19fe54f857c2fa14018a4b2211d
--- /dev/null
+++ b/tests/transformation/test_absorb_transp_into_flatten.py
@@ -0,0 +1,99 @@
+import pytest
+
+import numpy as np
+from onnx import TensorProto, helper
+
+from finn.core.modelwrapper import ModelWrapper
+import finn.core.data_layout as DataLayout
+from finn.transformation.infer_shapes import InferShapes
+from finn.transformation.infer_datatypes import InferDataTypes
+from finn.transformation.infer_data_layouts import InferDataLayouts
+from finn.transformation.general import GiveUniqueNodeNames, GiveReadableTensorNames
+from finn.transformation.streamline.absorb import AbsorbTransposeIntoFlatten
+import finn.core.onnx_exec as oxe
+
+# permutation of transpose node
+@pytest.mark.parametrize("perm", [[0, 2, 3, 1], [0, 1, 3, 2], [3, 2, 0, 1]])
+# reshape or flatten
+@pytest.mark.parametrize("shape", [None, [1, -1], [-1, 1]])
+# input shape
+@pytest.mark.parametrize("ishape", [[1, 1, 1, 4], [2, 4, 1, 1], [1, 2, 2, 4]])
+# datalayout
+@pytest.mark.parametrize("data_layout", ["NCHW", "NHWC"])
+def test_absorb_transp_into_flatten(perm, shape, ishape, data_layout):
+    inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, ishape)
+    transp_node = helper.make_node("Transpose", ["inp"], ["transp_out"], perm=perm)
+    dummy_in = np.random.uniform(low=0, high=1, size=tuple(ishape)).astype(np.float32)
+    if shape is None:
+        shape_node = helper.make_node("Flatten", ["transp_out"], ["outp"])
+        dummy_in = dummy_in.transpose(tuple(perm))
+        oshape = dummy_in.reshape(dummy_in.shape[0], -1).shape
+        outp = helper.make_tensor_value_info("outp", TensorProto.FLOAT, oshape)
+        shape0 = None
+    else:
+        shape0 = helper.make_tensor_value_info("shape0", TensorProto.FLOAT, shape)
+        shape_node = helper.make_node("Reshape", ["transp_out", "shape0"], ["outp"])
+        oshape = dummy_in.transpose(tuple(perm)).reshape(tuple(shape)).shape
+        outp = helper.make_tensor_value_info("outp", TensorProto.FLOAT, oshape)
+
+    graph = helper.make_graph(
+        nodes=[transp_node, shape_node],
+        name="absorb-transpose-graph",
+        inputs=[inp],
+        outputs=[outp],
+    )
+
+    model = helper.make_model(graph, producer_name="absorb_transpose_model")
+    model = ModelWrapper(model)
+    if shape is not None:
+        model.graph.value_info.append(shape0)
+        model.set_initializer("shape0", np.asarray(shape))
+    if data_layout == "NCHW":
+        model.set_tensor_layout("inp", DataLayout.NCHW)
+    else:
+        model.set_tensor_layout("inp", DataLayout.NHWC)
+    model = model.transform(InferShapes())
+    model = model.transform(InferDataTypes())
+    model = model.transform(InferDataLayouts())
+    model = model.transform(GiveUniqueNodeNames())
+    model = model.transform(GiveReadableTensorNames())
+    model.save("test.onnx")
+    model_transformed = model.transform(AbsorbTransposeIntoFlatten())
+    model_transformed.save("test2.onnx")
+
+    # verify transformation
+    inp_values = np.random.uniform(low=-1, high=1, size=tuple(ishape)).astype(
+        np.float32
+    )
+    idict = {"inp": inp_values}
+    assert oxe.compare_execution(model, model_transformed, idict)
+
+    # only some of the parameter combinations lead to a graph that will be changed when
+    # AbsorbTransposeIntoFlatten is applied
+
+    if shape == [-1, 1]:  # not a flatten operation, so the graph will not be changed
+        assert model.graph == model_transformed.graph
+
+    elif perm == [
+        3,
+        2,
+        0,
+        1,
+    ]:  # the first dimension is also part of the transpose operation
+        # so the graph will not be changed
+        assert model.graph == model_transformed.graph
+
+    # the following cases are the ones in which the model is transformed
+    # because we tested the parameters shape and perm befire we can only consider ishape
+    # and data_layout (the transformed model should only contain a "Flatten" node)
+    elif ishape == [1, 1, 1, 4] and data_layout == "NHWC":
+        assert model_transformed.graph.node[0].op_type == "Flatten"
+
+    elif ishape == [2, 4, 1, 1] and data_layout == "NCHW" and shape is None:
+        # If the first  dimension of the input tensor is not 1, flatten and
+        # reshape (with shape = [1, -1]) would lead to different results
+        assert model_transformed.graph.node[0].op_type == "Flatten"
+
+    # all other cases lead to an unchanged model
+    else:
+        assert model.graph == model_transformed.graph