From 6bb39ed7ee39b5d4e8e0e2261337904edf323827 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <maltanar@gmail.com> Date: Wed, 17 Jun 2020 23:21:36 +0100 Subject: [PATCH] [Core] rename validate -> sanitize_quant_values, add comments plus a small bugfix (break -> continue) --- src/finn/core/onnx_exec.py | 8 ++++---- src/finn/util/basic.py | 29 +++++++++++++++++++++++------ 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/src/finn/core/onnx_exec.py b/src/finn/core/onnx_exec.py index 0bacbc8b0..218df22e0 100644 --- a/src/finn/core/onnx_exec.py +++ b/src/finn/core/onnx_exec.py @@ -39,7 +39,7 @@ from finn.core.remote_exec import remote_exec from finn.core.rtlsim_exec import rtlsim_exec from finn.custom_op.registry import getCustomOp import finn.analysis.topology as ta -from finn.util.basic import validate_quant_values +from finn.util.basic import sanitize_quant_values def execute_node(node, context, graph): @@ -103,7 +103,7 @@ def execute_node(node, context, graph): raise Exception( """Output shapes disagree after node execution: found %s vs expected %s""" - % (str(output_list[list_ind].shape), str(context[outp].shape),) + % (str(output_list[list_ind].shape), str(context[outp].shape)) ) context[outp] = output_list[list_ind] @@ -161,11 +161,11 @@ def execute_onnx(model, input_dict, return_full_exec_context=False): # topologically sorted for node in graph.node: # call util function match input values to quantization annotation - execution_context = validate_quant_values( + execution_context = sanitize_quant_values( model, node.input, execution_context ) execute_node(node, execution_context, graph) - execution_context = validate_quant_values( + execution_context = sanitize_quant_values( model, node.output, execution_context ) elif model_exec_mode == "remote_pynq": diff --git a/src/finn/util/basic.py b/src/finn/util/basic.py index c4154aad4..585d6b90a 100644 --- a/src/finn/util/basic.py +++ b/src/finn/util/basic.py @@ -265,12 +265,29 @@ def calculate_signed_dot_prod_range(dt_a, dt_b, len): return (min_prod, max_prod) -def validate_quant_values(model, node_tensors, execution_context, check_values=False): +def sanitize_quant_values(model, node_tensors, execution_context, check_values=False): + """ Sanitize given list of tensors in execution_context by rounding values + that are supposed to be integers (as indicated by their quantization + annotation). Will raise an assertion if the amount of rounding is too large. + Returns the sanitized execution context. + + If check_values is specified, an extra DataType.allowed() check will be + performed on any rounded tensors. + + Background: + FINN uses floating point tensors as a carrier data type to represent + integers. Floating point arithmetic can introduce rounding errors, e.g. + (int_num * float_scale) / float_scale is not always equal to int_num. + We use this function to ensure that the values that are supposed to be + integers are indeed integers. + """ + for tensor in node_tensors: dtype = model.get_tensor_datatype(tensor) - # if FLOAT32 skip to next tensor + # floats don't need sanitization, skip to next + # introduces less quicker runtime if dtype == DataType.FLOAT32: - break + continue current_values = execution_context[tensor] updated_values = current_values has_to_be_rounded = False @@ -283,8 +300,8 @@ def validate_quant_values(model, node_tensors, execution_context, check_values=F updated_values = np.round(current_values) warnings.warn( "The values of tensor {} can't be represented " - "with the set finn datatype ({}), they will be rounded to match the " - "finn datatype.".format(tensor, dtype) + "with the set FINN datatype ({}), they will be rounded to match the " + "FINN datatype.".format(tensor, dtype) ) # check if rounded values are not too far from original values max_error = max(np.abs(current_values - updated_values).flatten()) @@ -303,7 +320,7 @@ def validate_quant_values(model, node_tensors, execution_context, check_values=F execution_context[tensor] = updated_values else: raise Exception( - """Rounding error is too high to match set finn + """Rounding error is too high to match set FINN datatype ({}) for input {}""".format( dtype, tensor ) -- GitLab