From 6fe1f90ee77a2cbc23087a9dd451440ac09b876e Mon Sep 17 00:00:00 2001
From: auphelia <jakobapk@web.de>
Date: Mon, 15 Jun 2020 10:53:03 +0100
Subject: [PATCH] [Util] Update fct to validate quant values (update execution
 context)

---
 src/finn/util/basic.py | 38 ++++++++++++++++++++++----------------
 1 file changed, 22 insertions(+), 16 deletions(-)

diff --git a/src/finn/util/basic.py b/src/finn/util/basic.py
index 08d92b388..c4154aad4 100644
--- a/src/finn/util/basic.py
+++ b/src/finn/util/basic.py
@@ -265,12 +265,16 @@ def calculate_signed_dot_prod_range(dt_a, dt_b, len):
     return (min_prod, max_prod)
 
 
-def update_execution_context(model, node, execution_context):
-    for inp in node.input:
-        dtype = model.get_tensor_datatype(inp)
-        current_values = execution_context[inp]
+def validate_quant_values(model, node_tensors, execution_context, check_values=False):
+    for tensor in node_tensors:
+        dtype = model.get_tensor_datatype(tensor)
+        # if FLOAT32 skip to next tensor
+        if dtype == DataType.FLOAT32:
+            break
+        current_values = execution_context[tensor]
         updated_values = current_values
         has_to_be_rounded = False
+        # TODO: vectorize with numpy
         for value in np.nditer(current_values):
             if not dtype.allowed(value):
                 has_to_be_rounded = True
@@ -280,26 +284,28 @@ def update_execution_context(model, node, execution_context):
             warnings.warn(
                 "The values of tensor {} can't be represented "
                 "with the set finn datatype ({}), they will be rounded to match the "
-                "finn datatype.".format(inp, dtype)
+                "finn datatype.".format(tensor, dtype)
             )
         # check if rounded values are not too far from original values
         max_error = max(np.abs(current_values - updated_values).flatten())
         if max_error <= 1e-4:
-            # check again if values can now be represented with set finn datatype
-            for value in np.nditer(updated_values):
-                if not dtype.allowed(value):
-                    raise Exception(
-                        """Values can't be represented with set
-                            finn datatype ({}) for input {}""".format(
-                            dtype, inp
+            if check_values is True:
+                # check again if values can now be represented with set finn datatype
+                # TODO: vectorize with numpy
+                for value in np.nditer(updated_values):
+                    if not dtype.allowed(value):
+                        raise Exception(
+                            """Values can't be represented with set
+                                finn datatype ({}) for input {}""".format(
+                                dtype, tensor
+                            )
                         )
-                    )
-            execution_context[inp] = updated_values
+            execution_context[tensor] = updated_values
         else:
             raise Exception(
-                """Values can't be rounded to match set finn
+                """Rounding error is too high to match set finn
             datatype ({}) for input {}""".format(
-                    dtype, inp
+                    dtype, tensor
                 )
             )
     return execution_context
-- 
GitLab