diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py
index 0b6259a61d3eb67b7b38d4c6939019ce2893a875..b46b82c77a3f1b70a3b05d87cd3c48fc1d94fd45 100644
--- a/src/finn/transformation/streamline/reorder.py
+++ b/src/finn/transformation/streamline/reorder.py
@@ -27,12 +27,14 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 import numpy as np
+import warnings
 from onnx import helper as oh
 
 from finn.transformation import Transformation
 from finn.transformation.infer_shapes import InferShapes
 from finn.core.onnx_exec import execute_node
 from finn.util.basic import get_by_name
+from finn.custom_op.registry import getCustomOp
 
 
 class MoveAddPastMul(Transformation):
@@ -531,3 +533,67 @@ class MoveMulPastFork(MoveOpPastFork):
 class MoveLinearPastFork(MoveOpPastFork):
     def __init__(self):
         super().__init__(["Add", "Mul"])
+
+
+class MoveMaxPoolPastMultiThreshold(Transformation):
+    """Move MaxPool nodes past MultiThreshold nodes on linear segments of the graph."""
+
+    def apply(self, model):
+        graph = model.graph
+        node_ind = 0
+        graph_modified = False
+        nodes = [n for n in graph.node]
+        for n in nodes:
+            node_ind += 1
+            if n.op_type == "MaxPool" and not model.is_fork_node(n):
+                consumer = model.find_consumer(n.output[0])
+                pads = get_by_name(n.attribute, "pads")
+                has_padding = False
+                if pads is not None:
+                    pads = list(pads.ints)
+                    has_padding = np.prod(pads) != 0
+                if consumer is not None and consumer.op_type == "MultiThreshold":
+                    mt_out = consumer.output[0]
+                    mt_odt = model.get_tensor_datatype(mt_out)
+                    if mt_odt.signed() and has_padding:
+                        warnings.warn(
+                            "Skipping padded MaxPool + signed-output MultiThreshold"
+                        )
+                        continue
+                    # check for non-decreasing thresholds and nonnegative
+                    # scale factor in MultiThreshold
+                    # otherwise we cannot do the reordering
+                    T = model.get_initializer(consumer.input[1])
+                    T_sorted = np.sort(T, axis=1)
+                    assert (
+                        T == T_sorted
+                    ).all(), "MultiThreshold must have non-decreasing thresholds"
+                    mt_inst = getCustomOp(consumer)
+                    if mt_inst.get_nodeattr("out_scale") < 0:
+                        warnings.warn("Skipping MultiThreshold with negative out_scale")
+                        continue
+
+                    # remove old nodes
+                    graph.node.remove(n)
+                    graph.node.remove(consumer)
+
+                    # swap conections
+                    group_in = n.input[0]
+                    # new tensor because dims change
+                    group_middle = model.make_new_valueinfo_name()
+                    group_out = consumer.output[0]
+
+                    consumer.input[0] = group_in
+                    consumer.output[0] = group_middle
+
+                    n.input[0] = group_middle
+                    n.output[0] = group_out
+
+                    # insert them back in
+                    graph.node.insert(node_ind - 1, consumer)
+                    graph.node.insert(node_ind, n)
+
+                    graph_modified = True
+
+        model = model.transform(InferShapes())
+        return (model, graph_modified)
diff --git a/tests/transformation/test_move_maxpool_past_multithreshold.py b/tests/transformation/test_move_maxpool_past_multithreshold.py
new file mode 100644
index 0000000000000000000000000000000000000000..01005ddd16d2545f46b6b0f1f61538e9618e2a7b
--- /dev/null
+++ b/tests/transformation/test_move_maxpool_past_multithreshold.py
@@ -0,0 +1,101 @@
+from onnx import TensorProto, helper
+import numpy as np
+
+import finn.core.onnx_exec as oxe
+from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.streamline.reorder import MoveMaxPoolPastMultiThreshold
+from finn.transformation.infer_shapes import InferShapes
+from finn.transformation.infer_datatypes import InferDataTypes
+
+
+def get_multithreshold_rand_params(channels, num_of_thres, seed=None):
+    if seed is not None:
+        np.random.seed(seed)
+    steps = np.random.rand(channels, 1) * 2
+    bias = np.random.rand(channels, 1) * 10
+    thres = [np.arange(num_of_thres) for chn in range(channels)]
+    thres = ((thres - bias) * steps).astype(np.float32)
+    return thres
+
+
+def test_move_maxpool_past_multithreshold():
+    # generate test vectors of correct shape
+    ch = 64
+    ifmdim = 16
+    ofmdim = 16 // 4
+    input_shape = (1, ch, ifmdim, ifmdim)
+    output_shape = (1, ch, ofmdim, ofmdim)
+
+    top_in = helper.make_tensor_value_info("top_in", TensorProto.FLOAT, input_shape)
+    top_out = helper.make_tensor_value_info("top_out", TensorProto.FLOAT, output_shape)
+
+    maxpool_config = {}
+    maxpool_config["pads"] = [1, 1, 1, 1]
+    maxpool_config["kernel_shape"] = [3, 3]
+    maxpool_config["strides"] = [2, 2]
+
+    value_info = []
+    thres1_shape = [1, 1]
+    value_info += [
+        helper.make_tensor_value_info("thres1", TensorProto.FLOAT, thres1_shape)
+    ]
+
+    thres2_shape = [ch, 14]
+    value_info += [
+        helper.make_tensor_value_info("thres2", TensorProto.FLOAT, thres2_shape)
+    ]
+
+    nodes = []
+    nodes += [helper.make_node("MaxPool", ["top_in"], ["t1"], **maxpool_config)]
+    nodes += [
+        helper.make_node(
+            "MultiThreshold",
+            ["t1", "thres1"],
+            ["t2"],
+            domain="finn",
+            out_dtype="BIPOLAR",
+            out_bias=-1.0,
+            out_scale=1.0,
+        )
+    ]
+    nodes += [helper.make_node("MaxPool", ["t2"], ["t3"], **maxpool_config)]
+    nodes += [
+        helper.make_node(
+            "MultiThreshold",
+            ["t3", "thres2"],
+            ["top_out"],
+            domain="finn",
+            out_dtype="UINT4",
+            out_scale=-1.0,
+        )
+    ]
+
+    modelproto = helper.make_model(
+        helper.make_graph(
+            name="test",
+            inputs=[top_in],
+            outputs=[top_out],
+            value_info=value_info,
+            nodes=nodes,
+        )
+    )
+    model = ModelWrapper(modelproto)
+    model = model.transform(InferShapes())
+    model = model.transform(InferDataTypes())
+
+    model.set_initializer("thres1", np.array([[0]]))
+    model.set_initializer(
+        "thres2", get_multithreshold_rand_params(*thres2_shape, seed=0)
+    )
+
+    # Transform
+    new_model = model.transform(MoveMaxPoolPastMultiThreshold())
+    inp_dict = {"top_in": np.random.rand(*input_shape).astype(np.float32)}
+
+    # Test
+    assert oxe.compare_execution(model, new_model, inp_dict)
+    assert new_model.graph.node[0].op_type == "MaxPool"
+    assert new_model.graph.node[1].op_type == "MultiThreshold"
+    assert new_model.graph.node[2].op_type == "MultiThreshold"
+    assert new_model.graph.node[3].op_type == "MaxPool"
+    assert len(new_model.graph.node) == 4