diff --git a/src/finn/core/datatype.py b/src/finn/core/datatype.py
index 9323d0c42fc37c7dc95e21cf6305dc40dd605b5c..42a366aafcc002a433d0e03c03ef6a9bed6adede 100644
--- a/src/finn/core/datatype.py
+++ b/src/finn/core/datatype.py
@@ -24,27 +24,29 @@
 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
-from enum import Enum
+from enum import Enum, auto
 
 import numpy as np
 
 
 class DataType(Enum):
-    FLOAT32 = 0
-    BINARY = 1
-    BIPOLAR = 2
-    UINT2 = 3
-    UINT3 = 4
-    UINT4 = 5
-    UINT8 = 6
-    UINT16 = 7
-    UINT32 = 8
-    INT2 = 9
-    INT3 = 10
-    INT4 = 11
-    INT8 = 12
-    INT16 = 13
-    INT32 = 14
+    # important to maintain the ordering here: unsigned before signed, fewer
+    # bits before more. get_smallest_possible() depends on this ordering.
+    BINARY = auto()
+    UINT2 = auto()
+    UINT3 = auto()
+    UINT4 = auto()
+    UINT8 = auto()
+    UINT16 = auto()
+    UINT32 = auto()
+    BIPOLAR = auto()
+    INT2 = auto()
+    INT3 = auto()
+    INT4 = auto()
+    INT8 = auto()
+    INT16 = auto()
+    INT32 = auto()
+    FLOAT32 = auto()
 
     def bitwidth(self):
         """Returns the number of bits required for this DataType."""
@@ -109,3 +111,23 @@ class DataType(Enum):
             return value in [-1, +1]
         else:
             raise Exception("Unrecognized data type: %s" % self.name)
+
+    @staticmethod
+    def get_smallest_possible(value):
+        """Return the smallest (fewest bits) possible DataType that can
+        represent value. Prefers unsigned integers where possible."""
+        if not int(value) == value:
+            return DataType["FLOAT32"]
+        for k in DataType.__members__:
+            dt = DataType[k]
+            if (dt.min() <= value) and (value <= dt.max()):
+                return dt
+
+    def signed(self):
+        """Return whether this DataType can represent negative numbers."""
+        return self.min() < 0
+
+    def is_integer(self):
+        """Return whether this DataType represents integer values only."""
+        # only FLOAT32 is noninteger for now
+        return self != DataType.FLOAT32
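As a quick sanity check of the new helpers (a usage sketch outside the diff, mirroring the tests further down), the member ordering above is exactly what makes get_smallest_possible() prefer unsigned and narrow types:

from finn.core.datatype import DataType

# unsigned members come first, so non-negative values resolve to BINARY/UINTx
assert DataType.get_smallest_possible(1) == DataType.BINARY
assert DataType.get_smallest_possible(7) == DataType.UINT3
# negative values fall through to the signed members...
assert DataType.get_smallest_possible(-3) == DataType.INT3
# ...and non-integer values to FLOAT32, the last member
assert DataType.get_smallest_possible(0.5) == DataType.FLOAT32
assert DataType.INT4.signed() and DataType.INT4.is_integer()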
diff --git a/src/finn/transformation/infer_datatypes.py b/src/finn/transformation/infer_datatypes.py
new file mode 100644
index 0000000000000000000000000000000000000000..19d947b57d045d4d7f2523f0f392adaba5bb367a
--- /dev/null
+++ b/src/finn/transformation/infer_datatypes.py
@@ -0,0 +1,47 @@
+from finn.core.datatype import DataType
+
+
+def _infer_node_datatype(model, node):
+    """Infer output datatype(s) for a particular node. Returns True if any
+    changes were made."""
+    idtypes = list(map(lambda x: model.get_tensor_datatype(x), node.input))
+    odtypes = list(map(lambda x: model.get_tensor_datatype(x), node.output))
+    if node.op_type == "MultiThreshold":
+        # the number of thresholds decides the number of output bits
+        n_thres = model.get_tensor_shape(node.input[1])[1]
+        odtype = DataType.get_smallest_possible(n_thres)
+        model.set_tensor_datatype(node.output[0], odtype)
+    elif node.op_type == "Sign":
+        # always produces bipolar outputs
+        model.set_tensor_datatype(node.output[0], DataType.BIPOLAR)
+    elif node.op_type == "MatMul":
+        if len(list(filter(lambda x: x == DataType.FLOAT32, idtypes))) != 0:
+            # node has at least one float input, output is also float
+            model.set_tensor_datatype(node.output[0], DataType.FLOAT32)
+        else:
+            # TODO compute minimum / maximum result to minimize bitwidth
+            # use (u)int32 accumulators for now
+            has_signed_inp = len(list(filter(lambda x: x.signed(), idtypes))) != 0
+            if has_signed_inp:
+                odtype = DataType.INT32
+            else:
+                odtype = DataType.UINT32
+            model.set_tensor_datatype(node.output[0], odtype)
+    else:
+        # unknown, assume node produces float32 outputs
+        for o in node.output:
+            model.set_tensor_datatype(o, DataType.FLOAT32)
+    # compare old and new output dtypes to see if anything changed
+    new_odtypes = list(map(lambda x: model.get_tensor_datatype(x), node.output))
+    graph_modified = new_odtypes != odtypes
+    return graph_modified
+
+
+def infer_datatypes(model):
+    """Infer FINN DataType info for all intermediate/output tensors based on
+    inputs and node type."""
+    graph = model.graph
+    graph_modified = False
+    for node in graph.node:
+        graph_modified |= _infer_node_datatype(model, node)
+    return (model, graph_modified)
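Returning (model, graph_modified) lets the caller iterate this pass to a fixed point, so datatypes propagate across chains of nodes regardless of traversal order. A minimal sketch of such a driver, assuming ModelWrapper.transform_repeated behaves roughly like this (hypothetical helper, not part of the diff):

def apply_until_fixed_point(model, transform):
    # re-apply the pass until it reports no further changes
    modified = True
    while modified:
        (model, modified) = transform(model)
    return model

# e.g. model = apply_until_fixed_point(model, infer_datatypes)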
diff --git a/src/finn/transformation/streamline.py b/src/finn/transformation/streamline.py
index 309177e548be490a7de05aad6006b838b207f3bc..5053b23c3e09bc5f32807cdf16a997c4647dd165 100644
--- a/src/finn/transformation/streamline.py
+++ b/src/finn/transformation/streamline.py
@@ -363,3 +363,22 @@ def absorb_1bit_mul_into_matmul(model):
                     graph.node.remove(consumer)
                     graph_modified = True
     return (model, graph_modified)
+
+
+def round_thresholds(model):
+    """For MultiThreshold nodes operating on integer inputs, round up
+    thresholds values to the nearest integer."""
+    graph = model.graph
+    graph_modified = False
+    for n in graph.node:
+        if n.op_type == "MultiThreshold":
+            idtype = model.get_tensor_datatype(n.input[0])
+            T = model.get_initializer(n.input[1])
+            Tnew = np.ceil(T)
+            if idtype.is_integer() and (T != Tnew).any():
+                # round the thresholds up to the next integer
+                model.set_initializer(n.input[1], Tnew)
+                # use same datatype as inputs for thresholds
+                model.set_tensor_datatype(n.input[1], idtype)
+                graph_modified = True
+    return (model, graph_modified)
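Rounding up is safe here because thresholds act as >= comparisons on integer inputs: for any integer x and real threshold t, x >= t holds exactly when x >= ceil(t). A small standalone check of that identity with plain numpy:

import numpy as np

x = np.arange(-4, 5)  # representative integer inputs
for t in [-1.1, 0.7, 2.3, 5.0]:
    # the comparison outcome never changes when t is replaced by ceil(t)
    assert ((x >= t) == (x >= np.ceil(t))).all()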
diff --git a/tests/test_datatypes.py b/tests/test_datatypes.py
index c2929fa156a593eb7d44f63cafc55da05b7c3dae..559314117f03f505d8ddb4afb88ba440f4dcd966 100644
--- a/tests/test_datatypes.py
+++ b/tests/test_datatypes.py
@@ -1,35 +1,46 @@
 # -*- coding: utf-8 -*-
 
-import finn.core.datatype as dt
+from finn.core.datatype import DataType
 
 
 def test_datatypes():
-    assert dt.DataType.BIPOLAR.allowed(-1)
-    assert dt.DataType.BIPOLAR.allowed(0) is False
-    assert dt.DataType.BINARY.allowed(-1) is False
-    assert dt.DataType.BINARY.allowed(1)
-    assert dt.DataType.UINT2.allowed(2)
-    assert dt.DataType.UINT2.allowed(10) is False
-    assert dt.DataType.UINT3.allowed(5)
-    assert dt.DataType.UINT3.allowed(-7) is False
-    assert dt.DataType.UINT4.allowed(15)
-    assert dt.DataType.UINT4.allowed(150) is False
-    assert dt.DataType.UINT8.allowed(150)
-    assert dt.DataType.UINT8.allowed(777) is False
-    assert dt.DataType.UINT16.allowed(14500)
-    assert dt.DataType.UINT16.allowed(-1) is False
-    assert dt.DataType.UINT32.allowed(2 ** 10)
-    assert dt.DataType.UINT32.allowed(-1) is False
-    assert dt.DataType.INT2.allowed(-1)
-    assert dt.DataType.INT2.allowed(-10) is False
-    assert dt.DataType.INT3.allowed(5) is False
-    assert dt.DataType.INT3.allowed(-2)
-    assert dt.DataType.INT4.allowed(15) is False
-    assert dt.DataType.INT4.allowed(-5)
-    assert dt.DataType.INT8.allowed(150) is False
-    assert dt.DataType.INT8.allowed(-127)
-    assert dt.DataType.INT16.allowed(-1.04) is False
-    assert dt.DataType.INT16.allowed(-7777)
-    assert dt.DataType.INT32.allowed(7.77) is False
-    assert dt.DataType.INT32.allowed(-5)
-    assert dt.DataType.INT32.allowed(5)
+    assert DataType.BIPOLAR.allowed(-1)
+    assert DataType.BIPOLAR.allowed(0) is False
+    assert DataType.BINARY.allowed(-1) is False
+    assert DataType.BINARY.allowed(1)
+    assert DataType.UINT2.allowed(2)
+    assert DataType.UINT2.allowed(10) is False
+    assert DataType.UINT3.allowed(5)
+    assert DataType.UINT3.allowed(-7) is False
+    assert DataType.UINT4.allowed(15)
+    assert DataType.UINT4.allowed(150) is False
+    assert DataType.UINT8.allowed(150)
+    assert DataType.UINT8.allowed(777) is False
+    assert DataType.UINT16.allowed(14500)
+    assert DataType.UINT16.allowed(-1) is False
+    assert DataType.UINT32.allowed(2 ** 10)
+    assert DataType.UINT32.allowed(-1) is False
+    assert DataType.INT2.allowed(-1)
+    assert DataType.INT2.allowed(-10) is False
+    assert DataType.INT3.allowed(5) is False
+    assert DataType.INT3.allowed(-2)
+    assert DataType.INT4.allowed(15) is False
+    assert DataType.INT4.allowed(-5)
+    assert DataType.INT8.allowed(150) is False
+    assert DataType.INT8.allowed(-127)
+    assert DataType.INT16.allowed(-1.04) is False
+    assert DataType.INT16.allowed(-7777)
+    assert DataType.INT32.allowed(7.77) is False
+    assert DataType.INT32.allowed(-5)
+    assert DataType.INT32.allowed(5)
+    assert DataType.BINARY.signed() is False
+    assert DataType.FLOAT32.signed()
+    assert DataType.BIPOLAR.signed()
+
+
+def test_smallest_possible():
+    assert DataType.get_smallest_possible(1) == DataType.BINARY
+    assert DataType.get_smallest_possible(1.1) == DataType.FLOAT32
+    assert DataType.get_smallest_possible(-1) == DataType.BIPOLAR
+    assert DataType.get_smallest_possible(-3) == DataType.INT3
+    assert DataType.get_smallest_possible(-3.2) == DataType.FLOAT32
diff --git a/tests/test_infer_datatypes.py b/tests/test_infer_datatypes.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd1c99a1da5f5bdb93f0b8b089fddb20e3c078ed
--- /dev/null
+++ b/tests/test_infer_datatypes.py
@@ -0,0 +1,40 @@
+import os
+
+import brevitas.onnx as bo
+import torch
+from models.LFC import LFC
+
+import finn.transformation.fold_constants as fc
+import finn.transformation.general as tg
+import finn.transformation.infer_datatypes as idt
+import finn.transformation.infer_shapes as si
+from finn.core.datatype import DataType
+from finn.core.modelwrapper import ModelWrapper
+
+export_onnx_path = "test_output_lfc.onnx"
+# TODO get from config instead, hardcoded to Docker path for now
+trained_lfc_checkpoint = (
+    "/workspace/brevitas_cnv_lfc/pretrained_models/LFC_1W1A/checkpoints/best.tar"
+)
+
+
+def test_infer_datatypes():
+    lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1)
+    checkpoint = torch.load(trained_lfc_checkpoint, map_location="cpu")
+    lfc.load_state_dict(checkpoint["state_dict"])
+    bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
+    model = ModelWrapper(export_onnx_path)
+    model = model.transform_single(si.infer_shapes)
+    model = model.transform_repeated(fc.fold_constants)
+    model = model.transform_single(tg.give_unique_node_names)
+    model = model.transform_single(tg.give_readable_tensor_names)
+    model = model.transform_repeated(idt.infer_datatypes)
+    assert model.get_tensor_datatype("MatMul_0_out0") == DataType.INT32
+    assert model.get_tensor_datatype("MatMul_1_out0") == DataType.INT32
+    assert model.get_tensor_datatype("MatMul_2_out0") == DataType.INT32
+    assert model.get_tensor_datatype("MatMul_3_out0") == DataType.INT32
+    assert model.get_tensor_datatype("Sign_0_out0") == DataType.BIPOLAR
+    assert model.get_tensor_datatype("Sign_1_out0") == DataType.BIPOLAR
+    assert model.get_tensor_datatype("Sign_2_out0") == DataType.BIPOLAR
+    assert model.get_tensor_datatype("Sign_3_out0") == DataType.BIPOLAR
+    os.remove(export_onnx_path)
diff --git a/tests/test_round_thresholds.py b/tests/test_round_thresholds.py
new file mode 100644
index 0000000000000000000000000000000000000000..28add7313201f1e30ecaf98d841f762a3f859835
--- /dev/null
+++ b/tests/test_round_thresholds.py
@@ -0,0 +1,38 @@
+import numpy as np
+from onnx import TensorProto, helper
+
+import finn.core.onnx_exec as oxe
+import finn.transformation.streamline as sl
+from finn.core.datatype import DataType
+from finn.core.modelwrapper import ModelWrapper
+
+
+def test_round_thresholds():
+    v = helper.make_tensor_value_info("v", TensorProto.FLOAT, [1, 4])
+    thresholds = helper.make_tensor_value_info("thresholds", TensorProto.FLOAT, [4, 1])
+    out = helper.make_tensor_value_info("out", TensorProto.FLOAT, [1, 4])
+    node_def = helper.make_node(
+        "MultiThreshold", ["v", "thresholds"], ["out"], domain="finn"
+    )
+    graph_def = helper.make_graph([node_def], "test_model", [v, thresholds], [out])
+    model_def = helper.make_model(graph_def)
+    model = ModelWrapper(model_def)
+    threshold_val = np.asarray([[-1.1], [0.7], [2.3], [5.1]], dtype=np.float32)
+    model.set_initializer("thresholds", threshold_val)
+    model.set_tensor_datatype("v", DataType.INT8)
+    inp_dict_f = {"v": np.floor(threshold_val).T}
+    inp_dict_n = {"v": np.round(threshold_val).T}
+    inp_dict_c = {"v": np.ceil(threshold_val).T}
+    orig_f = oxe.execute_onnx(model, inp_dict_f)["out"]
+    orig_n = oxe.execute_onnx(model, inp_dict_n)["out"]
+    orig_c = oxe.execute_onnx(model, inp_dict_c)["out"]
+    assert model.get_tensor_datatype("thresholds") == DataType.FLOAT32
+    new_model = model.transform_repeated(sl.round_thresholds)
+    # rounded up thresholds should have same dtype as input
+    assert new_model.get_tensor_datatype("thresholds") == DataType.INT8
+    new_f = oxe.execute_onnx(new_model, inp_dict_f)["out"]
+    new_n = oxe.execute_onnx(new_model, inp_dict_n)["out"]
+    new_c = oxe.execute_onnx(new_model, inp_dict_c)["out"]
+    assert np.isclose(orig_f, new_f, atol=1e-3).all()
+    assert np.isclose(orig_n, new_n, atol=1e-3).all()
+    assert np.isclose(orig_c, new_c, atol=1e-3).all()
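For reference, the semantics this test relies on: FINN's MultiThreshold maps each input value to the number of thresholds it meets or exceeds. A hedged numpy sketch for the (1, C) input / (C, T) threshold shapes used above (simplified; the real op also handles NCHW tensors and optional scale/bias attributes):

import numpy as np

def multithreshold_sketch(v, thresholds):
    # v: (1, C), thresholds: (C, T); count thresholds reached per channel
    out = np.zeros_like(v)
    for c in range(v.shape[1]):
        out[0, c] = np.sum(v[0, c] >= thresholds[c])
    return out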
diff --git a/tests/test_streamline.py b/tests/test_streamline.py
index e3d97a4364ce1c7dd46b9e4ed7e66d7076751fd0..1a93c6ff128b2169da4802d7367608f499e89beb 100644
--- a/tests/test_streamline.py
+++ b/tests/test_streamline.py
@@ -12,6 +12,7 @@ import finn.core.onnx_exec as oxe
 import finn.transformation.batchnorm_to_affine as ba
 import finn.transformation.fold_constants as fc
 import finn.transformation.general as tg
+import finn.transformation.infer_datatypes as di
 import finn.transformation.infer_shapes as si
 import finn.transformation.streamline as sl
 from finn.core.modelwrapper import ModelWrapper
@@ -53,12 +54,14 @@ def test_streamline_lfc_w1a1():
         sl.factor_out_mul_sign_magnitude,
         sl.absorb_mul_into_multi_threshold,
         sl.absorb_1bit_mul_into_matmul,
+        sl.round_thresholds,
     ]
     trn_ind = 0
     for trn in transforms:
         model = model.transform_repeated(trn)
         model = model.transform_single(tg.give_unique_node_names)
         model = model.transform_single(tg.give_readable_tensor_names)
+        model = model.transform_repeated(di.infer_datatypes)
         produced_ctx = oxe.execute_onnx(model, input_dict, True)
         produced = produced_ctx[model.graph.output[0].name]
         # model.save("%d-%s.onnx" % (trn_ind, trn.__name__))