diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py
index 7d192aebfd4a7a8c52c479cf060f4d539797934d..3fe1ef242364b1070726681495ae3632660f2904 100644
--- a/src/finn/transformation/streamline/reorder.py
+++ b/src/finn/transformation/streamline/reorder.py
@@ -27,6 +27,7 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 import numpy as np
+import warnings
 from onnx import helper as oh
 
 from finn.transformation import Transformation
@@ -546,16 +547,10 @@ class MoveMaxPoolPastMultiThreshold(Transformation):
             if n.op_type == "MaxPool" and not model.is_fork_node(n):
                 consumer = model.find_consumer(n.output[0])
                 if consumer is not None and consumer.op_type == "MultiThreshold":
-                    is_signed = True
-                    for attr in consumer.attribute:
-                        if (
-                            attr.name == "out_dtype"
-                            and len(attr.s) >= 5
-                            and attr.s[:4] == b"UINT"
-                        ):
-                            is_signed = False
-
-                    if is_signed:
+                    mt_out = consumer.output[0]
+                    mt_odt = model.get_tensor_datatype(mt_out)
+                    if mt_odt.signed():
+                        warnings.warn("Skipping signed-output MultiThreshold")
                         continue
 
                     # remove old nodes
diff --git a/tests/transformation/test_move_maxpool_past_multithreshold.py b/tests/transformation/test_move_maxpool_past_multithreshold.py
index 0378d956f4bc3160cd93e4a1e27151c71cd40a89..0a40c83f4bd3c35e5bf8d31917b2019479b67268 100644
--- a/tests/transformation/test_move_maxpool_past_multithreshold.py
+++ b/tests/transformation/test_move_maxpool_past_multithreshold.py
@@ -5,6 +5,7 @@ import finn.core.onnx_exec as oxe
 from finn.core.modelwrapper import ModelWrapper
 from finn.transformation.streamline.reorder import MoveMaxPoolPastMultiThreshold
 from finn.transformation.infer_shapes import InferShapes
+from finn.transformation.infer_datatypes import InferDataTypes
 
 
 def get_multithreshold_rand_params(channels, num_of_thres, seed=None):
@@ -79,6 +80,7 @@ def test_move_maxpool_past_multithreshold():
     )
     model = ModelWrapper(modelproto)
     model = model.transform(InferShapes())
+    model = model.transform(InferDataTypes())
 
     model.set_initializer("thres1", np.array([[0]]))
     model.set_initializer(