diff --git a/docs/finn/internals.rst b/docs/finn/internals.rst
index 7a4bc687eeb827320991f7d3f1ef8cc35e97f3da..010cdece978cde078c3df4c64177fa1c5455aa0a 100644
--- a/docs/finn/internals.rst
+++ b/docs/finn/internals.rst
@@ -16,6 +16,9 @@ Custom Quantization Annotations
 
 ONNX does not support datatypes smaller than 8-bit integers, whereas in FINN we are interested in smaller integers down to ternary and bipolar. To make this work, FINN uses the quantization_annotation field in ONNX to annotate tensors with their FINN DataType (:py:mod:`finn.core.datatype.DataType`) information. However, all tensors are expected to use single-precision floating point (float32) storage in FINN. This means we store even a 1-bit value as floating point for the purposes of representation. The FINN compiler flow is responsible for eventually producing a packed representation for the target hardware, where the 1-bit is actually stored as 1-bit.
 
+Note that FINN uses floating point tensors as a carrier data type to represent integers. Floating point arithmetic can introduce rounding errors, e.g. (int_num * float_scale) / float_scale is not always equal to int_num.
+When using the custom ONNX execution flow, FINN will attempt to sanitize any rounding errors for integer tensors. See (:py:mod:`finn.util.basic.sanitize_quant_values`) for more information.
+
 Custom Operations/Nodes
 =======================
 
diff --git a/src/finn/core/onnx_exec.py b/src/finn/core/onnx_exec.py
index c2f68a35076418e0cf2edb578bdb8d548772fc78..218df22e07537034b377abc077aa7902bc0c4cfc 100644
--- a/src/finn/core/onnx_exec.py
+++ b/src/finn/core/onnx_exec.py
@@ -39,6 +39,7 @@ from finn.core.remote_exec import remote_exec
 from finn.core.rtlsim_exec import rtlsim_exec
 from finn.custom_op.registry import getCustomOp
 import finn.analysis.topology as ta
+from finn.util.basic import sanitize_quant_values
 
 
 def execute_node(node, context, graph):
@@ -102,10 +103,7 @@ def execute_node(node, context, graph):
                     raise Exception(
                         """Output shapes disagree after node execution:
                         found %s vs expected %s"""
-                        % (
-                            str(output_list[list_ind].shape),
-                            str(context[outp].shape),
-                        )
+                        % (str(output_list[list_ind].shape), str(context[outp].shape))
                     )
                 context[outp] = output_list[list_ind]
 
@@ -162,7 +160,14 @@ def execute_onnx(model, input_dict, return_full_exec_context=False):
         # we can simply walk down the list since the ONNX spec guarantees that it is
         # topologically sorted
         for node in graph.node:
+            # call util function match input values to quantization annotation
+            execution_context = sanitize_quant_values(
+                model, node.input, execution_context
+            )
             execute_node(node, execution_context, graph)
+            execution_context = sanitize_quant_values(
+                model, node.output, execution_context
+            )
     elif model_exec_mode == "remote_pynq":
         # use remote exec metadata built into model to execute on a remote PYNQ
         remote_exec(model, execution_context)
diff --git a/src/finn/util/basic.py b/src/finn/util/basic.py
index eb3d46bcd66e3dc307a679e6b8dfbb9913398d36..585d6b90a59ca9f3dac56a6133de705c2f56f025 100644
--- a/src/finn/util/basic.py
+++ b/src/finn/util/basic.py
@@ -31,6 +31,7 @@ import random
 import string
 import subprocess
 import tempfile
+import warnings
 
 import numpy as np
 
@@ -264,6 +265,69 @@ def calculate_signed_dot_prod_range(dt_a, dt_b, len):
     return (min_prod, max_prod)
 
 
+def sanitize_quant_values(model, node_tensors, execution_context, check_values=False):
+    """ Sanitize given list of tensors in execution_context by rounding values
+    that are supposed to be integers (as indicated by their quantization
+    annotation). Will raise an assertion if the amount of rounding is too large.
+    Returns the sanitized execution context.
+
+    If check_values is specified, an extra DataType.allowed() check will be
+    performed on any rounded tensors.
+
+    Background:
+    FINN uses floating point tensors as a carrier data type to represent
+    integers. Floating point arithmetic can introduce rounding errors, e.g.
+    (int_num * float_scale) / float_scale is not always equal to int_num.
+    We use this function to ensure that the values that are supposed to be
+    integers are indeed integers.
+    """
+
+    for tensor in node_tensors:
+        dtype = model.get_tensor_datatype(tensor)
+        # floats don't need sanitization, skip to next
+        # introduces less quicker runtime
+        if dtype == DataType.FLOAT32:
+            continue
+        current_values = execution_context[tensor]
+        updated_values = current_values
+        has_to_be_rounded = False
+        # TODO: vectorize with numpy
+        for value in np.nditer(current_values):
+            if not dtype.allowed(value):
+                has_to_be_rounded = True
+                break
+        if has_to_be_rounded:
+            updated_values = np.round(current_values)
+            warnings.warn(
+                "The values of tensor {} can't be represented "
+                "with the set FINN datatype ({}), they will be rounded to match the "
+                "FINN datatype.".format(tensor, dtype)
+            )
+        # check if rounded values are not too far from original values
+        max_error = max(np.abs(current_values - updated_values).flatten())
+        if max_error <= 1e-4:
+            if check_values is True:
+                # check again if values can now be represented with set finn datatype
+                # TODO: vectorize with numpy
+                for value in np.nditer(updated_values):
+                    if not dtype.allowed(value):
+                        raise Exception(
+                            """Values can't be represented with set
+                                finn datatype ({}) for input {}""".format(
+                                dtype, tensor
+                            )
+                        )
+            execution_context[tensor] = updated_values
+        else:
+            raise Exception(
+                """Rounding error is too high to match set FINN
+            datatype ({}) for input {}""".format(
+                    dtype, tensor
+                )
+            )
+    return execution_context
+
+
 class CppBuilder:
     """Builds the g++ compiler command to produces the executable of the c++ code
     in code_gen_dir which is passed to the function build() of this class."""
diff --git a/tests/core/test_basic_onnx_exec.py b/tests/core/test_basic_onnx_exec.py
index a7b6da9965aa5912870812a8c1f8d6da2ee0d181..7b0412432cc6360cb9c42d66417bd187ed142563 100644
--- a/tests/core/test_basic_onnx_exec.py
+++ b/tests/core/test_basic_onnx_exec.py
@@ -35,6 +35,8 @@ import onnx.numpy_helper as np_helper
 import finn.core.onnx_exec as oxe
 from finn.core.modelwrapper import ModelWrapper
 from finn.transformation.infer_shapes import InferShapes
+from finn.core.datatype import DataType
+from finn.util.basic import gen_finn_dt_tensor
 
 
 def test_mnist_onnx_download_extract_run():
@@ -53,3 +55,30 @@ def test_mnist_onnx_download_extract_run():
     assert np.isclose(
         np_helper.to_array(output_tensor), output_dict["Plus214_Output_0"], atol=1e-3
     ).all()
+
+
+def test_onnx_exec_internal_rounding():
+    inp0 = onnx.helper.make_tensor_value_info("inp0", onnx.TensorProto.FLOAT, [2, 2])
+    inp1 = onnx.helper.make_tensor_value_info("inp1", onnx.TensorProto.FLOAT, [1])
+    outp = onnx.helper.make_tensor_value_info("outp", onnx.TensorProto.FLOAT, [2, 2])
+    mul_node = onnx.helper.make_node("Mul", inputs=["inp0", "inp1"], outputs=["outp"],)
+    graph = onnx.helper.make_graph(
+        nodes=[mul_node], name="mul_graph", inputs=[inp0, inp1], outputs=[outp]
+    )
+
+    model = onnx.helper.make_model(graph, producer_name="mul-model")
+    model = ModelWrapper(model)
+    idt = DataType.INT2
+    model.set_tensor_datatype("inp0", idt)
+    model.set_tensor_datatype("inp1", idt)
+    model.transform(InferShapes())
+
+    mul_value = np.asarray([-1], dtype=np.float32)
+    inp_int = gen_finn_dt_tensor(idt, [2, 2])
+    scale = np.random.uniform(low=0, high=1, size=(2, 2)).astype(np.float32)
+    inp_rounded = (inp_int * scale) / (scale + 1e-7)
+    input_dict = {"inp0": inp_rounded, "inp1": mul_value}
+    output_dict = oxe.execute_onnx(model, input_dict)
+    produced = output_dict["outp"]
+    expected = np.multiply(inp_int, mul_value)
+    assert (produced == expected).all()