diff --git a/src/finn/builder/build_dataflow_steps.py b/src/finn/builder/build_dataflow_steps.py
index 4c9b076b6acb1e725421d5067ac74ad42c3dbafb..50cd9ed4ff663044433c5d8997f2fcd9337a0dd6 100644
--- a/src/finn/builder/build_dataflow_steps.py
+++ b/src/finn/builder/build_dataflow_steps.py
@@ -46,10 +46,8 @@ from finn.transformation.streamline import Streamline
 from finn.transformation.infer_data_layouts import InferDataLayouts
 from finn.transformation.move_reshape import RemoveCNVtoFCFlatten
 from finn.transformation.lower_convs_to_matmul import LowerConvsToMatMul
-from finn.transformation.streamline.reorder import (
-    MakeMaxPoolNHWC,
-    MoveScalarLinearPastInvariants,
-)
+from finn.transformation.streamline.reorder import MakeMaxPoolNHWC
+
 from shutil import copy, copytree
 from finn.transformation.fpgadataflow.insert_dwc import InsertDWC
 from finn.transformation.fpgadataflow.insert_fifo import InsertFIFO
@@ -159,13 +157,13 @@ def step_streamline(model: ModelWrapper, cfg: DataflowBuildConfig):
     """
 
     model = model.transform(absorb.AbsorbSignBiasIntoMultiThreshold())
-    model = model.transform(MoveScalarLinearPastInvariants())
     model = model.transform(Streamline())
     need_lowering = len(model.get_nodes_by_op_type("Conv")) > 0
     if need_lowering:
         model = model.transform(LowerConvsToMatMul())
         model = model.transform(MakeMaxPoolNHWC())
         model = model.transform(absorb.AbsorbTransposeIntoMultiThreshold())
+        model = model.transform(MakeMaxPoolNHWC())
     model = model.transform(ConvertBipolarMatMulToXnorPopcount())
     model = model.transform(Streamline())
     # absorb final add-mul nodes into TopK
@@ -182,7 +180,7 @@ def step_streamline(model: ModelWrapper, cfg: DataflowBuildConfig):
 def step_convert_to_hls(model: ModelWrapper, cfg: DataflowBuildConfig):
     """Convert eligible nodes to `HLSCustomOp` subclasses that represent HLS
     layers. Which nodes and particular configurations can be converted to HLS
-    is limited, see the source code of the `convert_to_hls` module for more. """
+    is limited, see the source code of the `convert_to_hls` module for more."""
 
     mem_mode = cfg.default_mem_mode.value
     if cfg.standalone_thresholds:
diff --git a/src/finn/transformation/streamline/__init__.py b/src/finn/transformation/streamline/__init__.py
index 876f8892dbc9c42189ee8dc06ff5eb407f7a0946..97cd957ce166255c1442a544f6e865b42c33a1df 100644
--- a/src/finn/transformation/streamline/__init__.py
+++ b/src/finn/transformation/streamline/__init__.py
@@ -60,6 +60,7 @@ from finn.transformation.streamline.reorder import (
     MoveAddPastConv,
     MoveScalarMulPastConv,
     MoveMulPastMaxPool,
+    MoveScalarLinearPastInvariants,
 )
 
 from finn.transformation.streamline.round_thresholds import RoundAndClipThresholds
@@ -78,6 +79,7 @@ class Streamline(Transformation):
             BatchNormToAffine(),
             ConvertSignToThres(),
             MoveMulPastMaxPool(),
+            MoveScalarLinearPastInvariants(),
             AbsorbSignBiasIntoMultiThreshold(),
             MoveAddPastMul(),
             MoveScalarAddPastMatMul(),
diff --git a/src/finn/transformation/streamline/absorb.py b/src/finn/transformation/streamline/absorb.py
index 9b842162f7f751c60b18bbd288ff96ef28d3aa88..9c6472d4f7fcebb6680b5f78b08bb4538d36df8f 100644
--- a/src/finn/transformation/streamline/absorb.py
+++ b/src/finn/transformation/streamline/absorb.py
@@ -308,9 +308,8 @@ class Absorb1BitMulIntoConv(Transformation):
 
 
 class AbsorbTransposeIntoMultiThreshold(Transformation):
-    """Change (NHWCTranpose -> MultiThreshold -> NCHWTranspose) to (MultiThreshold)
-    with NHWC mode. For (NHWCTranpose -> MultiThreshold -> Flatten), move Transpose
-    past MultiThreshold to prepare for the RemoveCNVtoFCFlatten() transformation."""
+    """Change (NCHWTranspose -> MultiThreshold -> NHWCTranspose) to (MultiThreshold)
+    with NHWC mode. For (NCHWTranspose -> MultiThreshold), move Transpose past MT."""
 
     def apply(self, model):
         graph = model.graph
@@ -339,35 +338,26 @@ class AbsorbTransposeIntoMultiThreshold(Transformation):
                                 graph.node.remove(n)
                                 graph.node.remove(final_t_cand)
                                 graph_modified = True
-                        # also support implicit flatten via reshape, e.g. reshape(1,-1)
-                        elif (
-                            final_t_cand.op_type == "Flatten"
-                            or final_t_cand.op_type == "Reshape"
-                        ):
-                            ishape = model.get_tensor_shape(final_t_cand.input[0])
-                            oshape = model.get_tensor_shape(final_t_cand.output[0])
-                            if len(oshape) == 2 and ishape[0] == oshape[0]:
-                                # transition to FC part, can still use NHWC
-                                mt = getCustomOp(mt_cand)
-                                mt.set_nodeattr("data_layout", "NHWC")
-                                # get rid of first tranpose node
-                                mt_cand.input[0] = n.input[0]
-                                graph.node.remove(n)
-                                # fix output shape for MultiThreshold
-                                mt_ishape = model.get_tensor_shape(mt_cand.input[0])
-                                (b, h, w, c) = mt_ishape
-                                model.set_tensor_shape(mt_cand.output[0], mt_ishape)
-                                # re-insert Transpose behind MultiThreshold
-                                transpose_output = model.make_new_valueinfo_name()
-                                new_transpose = oh.make_node(
-                                    "Transpose",
-                                    [mt_cand.output[0]],
-                                    [transpose_output],
-                                    perm=[0, 3, 1, 2],
-                                )
-                                graph.node.insert(node_ind + 1, new_transpose)
-                                final_t_cand.input[0] = transpose_output
-                                graph_modified = True
+                        else:
+                            mt = getCustomOp(mt_cand)
+                            mt.set_nodeattr("data_layout", "NHWC")
+                            # get rid of first tranpose node
+                            mt_cand.input[0] = n.input[0]
+                            graph.node.remove(n)
+                            # fix output shape for MultiThreshold
+                            mt_ishape = model.get_tensor_shape(mt_cand.input[0])
+                            model.set_tensor_shape(mt_cand.output[0], mt_ishape)
+                            # re-insert Transpose behind MultiThreshold
+                            transpose_output = model.make_new_valueinfo_name()
+                            new_transpose = oh.make_node(
+                                "Transpose",
+                                [mt_cand.output[0]],
+                                [transpose_output],
+                                perm=[0, 3, 1, 2],
+                            )
+                            graph.node.insert(node_ind + 1, new_transpose)
+                            final_t_cand.input[0] = transpose_output
+                            graph_modified = True
         if graph_modified:
             model = model.transform(InferDataTypes())
         return (model, graph_modified)
diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py
index 7163a95c4dbbe5c8bcee4ebeea87c5e9611c179e..4049d7bc8b555dce6f3ab6f9475f6aebf41fbc2c 100644
--- a/src/finn/transformation/streamline/reorder.py
+++ b/src/finn/transformation/streamline/reorder.py
@@ -645,7 +645,8 @@ class MoveScalarLinearPastInvariants(Transformation):
 
 
 class MakeMaxPoolNHWC(Transformation):
-    """Convert (MaxPool, NHWCTranpose) into (MaxPoolNHWC)."""
+    """Convert (MaxPool, NHWCTranspose) into (NHWCTranspose, MaxPoolNHWC)
+    and (NCHWTranspose, MaxPool) into (MaxPoolNHWC, NCHWTranspose)."""
 
     def apply(self, model):
         graph = model.graph
@@ -655,6 +656,7 @@ class MakeMaxPoolNHWC(Transformation):
             node_ind += 1
             if n.op_type == "MaxPool":
                 consumer = model.find_consumer(n.output[0])
+                producer = model.find_producer(n.input[0])
                 if consumer is not None and consumer.op_type == "Transpose":
                     perms = list(get_by_name(consumer.attribute, "perm").ints)
                     if perms == [0, 2, 3, 1]:
@@ -674,6 +676,25 @@ class MakeMaxPoolNHWC(Transformation):
                         graph.node.remove(consumer)
                         graph.node.insert(node_ind - 1, consumer)
                         graph_modified = True
+                elif producer is not None and producer.op_type == "Transpose":
+                    perms = list(get_by_name(producer.attribute, "perm").ints)
+                    if perms == [0, 3, 1, 2]:
+                        n.op_type = "MaxPoolNHWC"
+                        n.domain = "finn.custom_op.general"
+                        start_name = producer.input[0]
+                        mid_name = n.input[0]
+                        end_name = n.output[0]
+                        (b, hi, wi, c) = model.get_tensor_shape(start_name)
+                        (b, c, ho, wo) = model.get_tensor_shape(end_name)
+                        producer.input[0] = mid_name
+                        producer.output[0] = end_name
+                        n.input[0] = start_name
+                        n.output[0] = mid_name
+                        model.set_tensor_shape(mid_name, (b, ho, wo, c))
+                        model.set_tensor_shape(end_name, (b, c, ho, wo))
+                        graph.node.remove(producer)
+                        graph.node.insert(node_ind, producer)
+                        graph_modified = True
         return (model, graph_modified)