diff --git a/src/finn/core/datatype.py b/src/finn/core/datatype.py
index b71922d479b62e5eba07ddd159e2e3243fa77d95..42a366aafcc002a433d0e03c03ef6a9bed6adede 100644
--- a/src/finn/core/datatype.py
+++ b/src/finn/core/datatype.py
@@ -125,3 +125,8 @@ class DataType(Enum):
     def signed(self):
         """Return whether this DataType can represent negative numbers."""
         return self.min() < 0
+
+    def is_integer(self):
+        """Return whether this DataType represents integer values only."""
+        # only FLOAT32 is noninteger for now
+        return self != DataType.FLOAT32
diff --git a/src/finn/transformation/streamline.py b/src/finn/transformation/streamline.py
index 309177e548be490a7de05aad6006b838b207f3bc..5053b23c3e09bc5f32807cdf16a997c4647dd165 100644
--- a/src/finn/transformation/streamline.py
+++ b/src/finn/transformation/streamline.py
@@ -363,3 +363,24 @@ def absorb_1bit_mul_into_matmul(model):
                     graph.node.remove(consumer)
                     graph_modified = True
     return (model, graph_modified)
+
+
+def round_thresholds(model):
+    """For MultiThreshold nodes operating on integer inputs, round up
+    thresholds values to the nearest integer."""
+    graph = model.graph
+    graph_modified = False
+    for n in graph.node:
+        if n.op_type == "MultiThreshold":
+            idtype = model.get_tensor_datatype(n.input[0])
+            T = model.get_initializer(n.input[1])
+            Tnew = np.ceil(T)
+            if idtype.is_integer() and (T != Tnew).any():
+                # round the thresholds up to the nearest integer
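+                # for an integer input x, x >= T holds iff x >= ceil(T), so
+                # rounding up should leave the MultiThreshold outputs unchanged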
+                model.set_initializer(n.input[1], Tnew)
+                # use same datatype as inputs for thresholds
+                model.set_tensor_datatype(n.input[1], idtype)
+                graph_modified = True
+    return (model, graph_modified)
diff --git a/tests/test_round_thresholds.py b/tests/test_round_thresholds.py
new file mode 100644
index 0000000000000000000000000000000000000000..28add7313201f1e30ecaf98d841f762a3f859835
--- /dev/null
+++ b/tests/test_round_thresholds.py
@@ -0,0 +1,40 @@
+import numpy as np
+from onnx import TensorProto, helper
+
+import finn.core.onnx_exec as oxe
+import finn.transformation.streamline as sl
+from finn.core.datatype import DataType
+from finn.core.modelwrapper import ModelWrapper
+
+
+def test_round_thresholds():
+    v = helper.make_tensor_value_info("v", TensorProto.FLOAT, [1, 4])
+    thresholds = helper.make_tensor_value_info("thresholds", TensorProto.FLOAT, [4, 1])
+    out = helper.make_tensor_value_info("out", TensorProto.FLOAT, [1, 4])
+    node_def = helper.make_node(
+        "MultiThreshold", ["v", "thresholds"], ["out"], domain="finn"
+    )
+    graph_def = helper.make_graph([node_def], "test_model", [v, thresholds], [out])
+    model_def = helper.make_model(graph_def)
+    model = ModelWrapper(model_def)
+    threshold_val = np.asarray([[-1.1], [0.7], [2.3], [5.1]], dtype=np.float32)
+    model.set_initializer("thresholds", threshold_val)
+    model.set_tensor_datatype("v", DataType.INT8)
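+    # evaluate at the floor, round and ceiling of each threshold value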
+    inp_dict_f = {"v": np.floor(threshold_val).T}
+    inp_dict_n = {"v": np.round(threshold_val).T}
+    inp_dict_c = {"v": np.ceil(threshold_val).T}
+    orig_f = oxe.execute_onnx(model, inp_dict_f)["out"]
+    orig_n = oxe.execute_onnx(model, inp_dict_n)["out"]
+    orig_c = oxe.execute_onnx(model, inp_dict_c)["out"]
+    assert model.get_tensor_datatype("thresholds") == DataType.FLOAT32
+    new_model = model.transform_repeated(sl.round_thresholds)
+    # rounded up thresholds should have same dtype as input
+    assert new_model.get_tensor_datatype("thresholds") == DataType.INT8
+    new_f = oxe.execute_onnx(new_model, inp_dict_f)["out"]
+    new_n = oxe.execute_onnx(new_model, inp_dict_n)["out"]
+    new_c = oxe.execute_onnx(new_model, inp_dict_c)["out"]
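+    # rounding the thresholds should not change the produced outputs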
+    assert np.isclose(orig_f, new_f, atol=1e-3).all()
+    assert np.isclose(orig_n, new_n, atol=1e-3).all()
+    assert np.isclose(orig_c, new_c, atol=1e-3).all()
diff --git a/tests/test_streamline.py b/tests/test_streamline.py
index e3d97a4364ce1c7dd46b9e4ed7e66d7076751fd0..1a93c6ff128b2169da4802d7367608f499e89beb 100644
--- a/tests/test_streamline.py
+++ b/tests/test_streamline.py
@@ -12,6 +12,7 @@ import finn.core.onnx_exec as oxe
 import finn.transformation.batchnorm_to_affine as ba
 import finn.transformation.fold_constants as fc
 import finn.transformation.general as tg
+import finn.transformation.infer_datatypes as di
 import finn.transformation.infer_shapes as si
 import finn.transformation.streamline as sl
 from finn.core.modelwrapper import ModelWrapper
@@ -53,12 +54,15 @@ def test_streamline_lfc_w1a1():
         sl.factor_out_mul_sign_magnitude,
         sl.absorb_mul_into_multi_threshold,
         sl.absorb_1bit_mul_into_matmul,
+        sl.round_thresholds,
     ]
     trn_ind = 0
     for trn in transforms:
         model = model.transform_repeated(trn)
         model = model.transform_single(tg.give_unique_node_names)
         model = model.transform_single(tg.give_readable_tensor_names)
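+        # keep tensor datatypes up to date so round_thresholds sees integer inputs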
+        model = model.transform_repeated(di.infer_datatypes)
         produced_ctx = oxe.execute_onnx(model, input_dict, True)
         produced = produced_ctx[model.graph.output[0].name]
         # model.save("%d-%s.onnx" % (trn_ind, trn.__name__))