From 6bb39ed7ee39b5d4e8e0e2261337904edf323827 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Wed, 17 Jun 2020 23:21:36 +0100
Subject: [PATCH] [Core] rename validate -> sanitize_quant_values, add comments

plus a small bugfix (break -> continue)
---
 src/finn/core/onnx_exec.py |  8 ++++----
 src/finn/util/basic.py     | 29 +++++++++++++++++++++++------
 2 files changed, 27 insertions(+), 10 deletions(-)

diff --git a/src/finn/core/onnx_exec.py b/src/finn/core/onnx_exec.py
index 0bacbc8b0..218df22e0 100644
--- a/src/finn/core/onnx_exec.py
+++ b/src/finn/core/onnx_exec.py
@@ -39,7 +39,7 @@ from finn.core.remote_exec import remote_exec
 from finn.core.rtlsim_exec import rtlsim_exec
 from finn.custom_op.registry import getCustomOp
 import finn.analysis.topology as ta
-from finn.util.basic import validate_quant_values
+from finn.util.basic import sanitize_quant_values
 
 
 def execute_node(node, context, graph):
@@ -103,7 +103,7 @@ def execute_node(node, context, graph):
                     raise Exception(
                         """Output shapes disagree after node execution:
                         found %s vs expected %s"""
-                        % (str(output_list[list_ind].shape), str(context[outp].shape),)
+                        % (str(output_list[list_ind].shape), str(context[outp].shape))
                     )
                 context[outp] = output_list[list_ind]
 
@@ -161,11 +161,11 @@ def execute_onnx(model, input_dict, return_full_exec_context=False):
         # topologically sorted
         for node in graph.node:
             # call util function match input values to quantization annotation
-            execution_context = validate_quant_values(
+            execution_context = sanitize_quant_values(
                 model, node.input, execution_context
             )
             execute_node(node, execution_context, graph)
-            execution_context = validate_quant_values(
+            execution_context = sanitize_quant_values(
                 model, node.output, execution_context
             )
     elif model_exec_mode == "remote_pynq":
diff --git a/src/finn/util/basic.py b/src/finn/util/basic.py
index c4154aad4..585d6b90a 100644
--- a/src/finn/util/basic.py
+++ b/src/finn/util/basic.py
@@ -265,12 +265,29 @@ def calculate_signed_dot_prod_range(dt_a, dt_b, len):
     return (min_prod, max_prod)
 
 
-def validate_quant_values(model, node_tensors, execution_context, check_values=False):
+def sanitize_quant_values(model, node_tensors, execution_context, check_values=False):
+    """ Sanitize given list of tensors in execution_context by rounding values
+    that are supposed to be integers (as indicated by their quantization
+    annotation). Will raise an assertion if the amount of rounding is too large.
+    Returns the sanitized execution context.
+
+    If check_values is specified, an extra DataType.allowed() check will be
+    performed on any rounded tensors.
+
+    Background:
+    FINN uses floating point tensors as a carrier data type to represent
+    integers. Floating point arithmetic can introduce rounding errors, e.g.
+    (int_num * float_scale) / float_scale is not always equal to int_num.
+    We use this function to ensure that the values that are supposed to be
+    integers are indeed integers.
+    """
+
     for tensor in node_tensors:
         dtype = model.get_tensor_datatype(tensor)
-        # if FLOAT32 skip to next tensor
+        # floats don't need sanitization, skip to next
+        # introduces less quicker runtime
         if dtype == DataType.FLOAT32:
-            break
+            continue
         current_values = execution_context[tensor]
         updated_values = current_values
         has_to_be_rounded = False
@@ -283,8 +300,8 @@ def validate_quant_values(model, node_tensors, execution_context, check_values=F
             updated_values = np.round(current_values)
             warnings.warn(
                 "The values of tensor {} can't be represented "
-                "with the set finn datatype ({}), they will be rounded to match the "
-                "finn datatype.".format(tensor, dtype)
+                "with the set FINN datatype ({}), they will be rounded to match the "
+                "FINN datatype.".format(tensor, dtype)
             )
         # check if rounded values are not too far from original values
         max_error = max(np.abs(current_values - updated_values).flatten())
@@ -303,7 +320,7 @@ def validate_quant_values(model, node_tensors, execution_context, check_values=F
             execution_context[tensor] = updated_values
         else:
             raise Exception(
-                """Rounding error is too high to match set finn
+                """Rounding error is too high to match set FINN
             datatype ({}) for input {}""".format(
                     dtype, tensor
                 )
-- 
GitLab