diff --git a/src/finn/core/onnx_exec.py b/src/finn/core/onnx_exec.py index 0bacbc8b04e3ce63d7234008a98b25b798efab47..218df22e07537034b377abc077aa7902bc0c4cfc 100644 --- a/src/finn/core/onnx_exec.py +++ b/src/finn/core/onnx_exec.py @@ -39,7 +39,7 @@ from finn.core.remote_exec import remote_exec from finn.core.rtlsim_exec import rtlsim_exec from finn.custom_op.registry import getCustomOp import finn.analysis.topology as ta -from finn.util.basic import validate_quant_values +from finn.util.basic import sanitize_quant_values def execute_node(node, context, graph): @@ -103,7 +103,7 @@ def execute_node(node, context, graph): raise Exception( """Output shapes disagree after node execution: found %s vs expected %s""" - % (str(output_list[list_ind].shape), str(context[outp].shape),) + % (str(output_list[list_ind].shape), str(context[outp].shape)) ) context[outp] = output_list[list_ind] @@ -161,11 +161,11 @@ def execute_onnx(model, input_dict, return_full_exec_context=False): # topologically sorted for node in graph.node: # call util function match input values to quantization annotation - execution_context = validate_quant_values( + execution_context = sanitize_quant_values( model, node.input, execution_context ) execute_node(node, execution_context, graph) - execution_context = validate_quant_values( + execution_context = sanitize_quant_values( model, node.output, execution_context ) elif model_exec_mode == "remote_pynq": diff --git a/src/finn/util/basic.py b/src/finn/util/basic.py index c4154aad41cc33e3176b24293676d33168bca2bc..585d6b90a59ca9f3dac56a6133de705c2f56f025 100644 --- a/src/finn/util/basic.py +++ b/src/finn/util/basic.py @@ -265,12 +265,29 @@ def calculate_signed_dot_prod_range(dt_a, dt_b, len): return (min_prod, max_prod) -def validate_quant_values(model, node_tensors, execution_context, check_values=False): +def sanitize_quant_values(model, node_tensors, execution_context, check_values=False): + """ Sanitize given list of tensors in execution_context by 
rounding values + that are supposed to be integers (as indicated by their quantization + annotation). Will raise an exception if the amount of rounding is too large. + Returns the sanitized execution context. + + If check_values is specified, an extra DataType.allowed() check will be + performed on any rounded tensors. + + Background: + FINN uses floating point tensors as a carrier data type to represent + integers. Floating point arithmetic can introduce rounding errors, e.g. + (int_num * float_scale) / float_scale is not always equal to int_num. + We use this function to ensure that the values that are supposed to be + integers are indeed integers. + """ + for tensor in node_tensors: dtype = model.get_tensor_datatype(tensor) - # if FLOAT32 skip to next tensor + # floats don't need sanitization, skip to next + # (skipping them early keeps the runtime lower) if dtype == DataType.FLOAT32: - break + continue current_values = execution_context[tensor] updated_values = current_values has_to_be_rounded = False @@ -283,8 +300,8 @@ def validate_quant_values(model, node_tensors, execution_context, check_values=F updated_values = np.round(current_values) warnings.warn( "The values of tensor {} can't be represented " - "with the set finn datatype ({}), they will be rounded to match the " - "finn datatype.".format(tensor, dtype) + "with the set FINN datatype ({}), they will be rounded to match the " + "FINN datatype.".format(tensor, dtype) ) # check if rounded values are not too far from original values max_error = max(np.abs(current_values - updated_values).flatten()) @@ -303,7 +320,7 @@ def validate_quant_values(model, node_tensors, execution_context, check_values=F execution_context[tensor] = updated_values else: raise Exception( - """Rounding error is too high to match set finn + """Rounding error is too high to match set FINN datatype ({}) for input {}""".format( dtype, tensor )