Skip to content
Snippets Groups Projects
Commit 6bb39ed7 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Core] rename validate -> sanitize_quant_values, add comments

plus a small bugfix (break -> continue)
parent 9409e444
No related branches found
No related tags found
No related merge requests found
......@@ -39,7 +39,7 @@ from finn.core.remote_exec import remote_exec
from finn.core.rtlsim_exec import rtlsim_exec
from finn.custom_op.registry import getCustomOp
import finn.analysis.topology as ta
from finn.util.basic import validate_quant_values
from finn.util.basic import sanitize_quant_values
def execute_node(node, context, graph):
......@@ -103,7 +103,7 @@ def execute_node(node, context, graph):
raise Exception(
"""Output shapes disagree after node execution:
found %s vs expected %s"""
% (str(output_list[list_ind].shape), str(context[outp].shape),)
% (str(output_list[list_ind].shape), str(context[outp].shape))
)
context[outp] = output_list[list_ind]
......@@ -161,11 +161,11 @@ def execute_onnx(model, input_dict, return_full_exec_context=False):
# topologically sorted
for node in graph.node:
# call util function match input values to quantization annotation
execution_context = validate_quant_values(
execution_context = sanitize_quant_values(
model, node.input, execution_context
)
execute_node(node, execution_context, graph)
execution_context = validate_quant_values(
execution_context = sanitize_quant_values(
model, node.output, execution_context
)
elif model_exec_mode == "remote_pynq":
......
......@@ -265,12 +265,29 @@ def calculate_signed_dot_prod_range(dt_a, dt_b, len):
return (min_prod, max_prod)
def validate_quant_values(model, node_tensors, execution_context, check_values=False):
def sanitize_quant_values(model, node_tensors, execution_context, check_values=False):
""" Sanitize given list of tensors in execution_context by rounding values
that are supposed to be integers (as indicated by their quantization
annotation). Will raise an assertion if the amount of rounding is too large.
Returns the sanitized execution context.
If check_values is specified, an extra DataType.allowed() check will be
performed on any rounded tensors.
Background:
FINN uses floating point tensors as a carrier data type to represent
integers. Floating point arithmetic can introduce rounding errors, e.g.
(int_num * float_scale) / float_scale is not always equal to int_num.
We use this function to ensure that the values that are supposed to be
integers are indeed integers.
"""
for tensor in node_tensors:
dtype = model.get_tensor_datatype(tensor)
# if FLOAT32 skip to next tensor
# floats don't need sanitization, skip to next
# introduces less quicker runtime
if dtype == DataType.FLOAT32:
break
continue
current_values = execution_context[tensor]
updated_values = current_values
has_to_be_rounded = False
......@@ -283,8 +300,8 @@ def validate_quant_values(model, node_tensors, execution_context, check_values=F
updated_values = np.round(current_values)
warnings.warn(
"The values of tensor {} can't be represented "
"with the set finn datatype ({}), they will be rounded to match the "
"finn datatype.".format(tensor, dtype)
"with the set FINN datatype ({}), they will be rounded to match the "
"FINN datatype.".format(tensor, dtype)
)
# check if rounded values are not too far from original values
max_error = max(np.abs(current_values - updated_values).flatten())
......@@ -303,7 +320,7 @@ def validate_quant_values(model, node_tensors, execution_context, check_values=F
execution_context[tensor] = updated_values
else:
raise Exception(
"""Rounding error is too high to match set finn
"""Rounding error is too high to match set FINN
datatype ({}) for input {}""".format(
dtype, tensor
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment