Skip to content
Snippets Groups Projects
Commit 6fe1f90e authored by auphelia's avatar auphelia
Browse files

[Util] Update fct to validate quant values (update execution context)

parent 1bdab22f
No related branches found
No related tags found
No related merge requests found
......@@ -265,12 +265,16 @@ def calculate_signed_dot_prod_range(dt_a, dt_b, len):
return (min_prod, max_prod)
def update_execution_context(model, node, execution_context):
for inp in node.input:
dtype = model.get_tensor_datatype(inp)
current_values = execution_context[inp]
def validate_quant_values(model, node_tensors, execution_context, check_values=False):
for tensor in node_tensors:
dtype = model.get_tensor_datatype(tensor)
# if FLOAT32 skip to next tensor
if dtype == DataType.FLOAT32:
break
current_values = execution_context[tensor]
updated_values = current_values
has_to_be_rounded = False
# TODO: vectorize with numpy
for value in np.nditer(current_values):
if not dtype.allowed(value):
has_to_be_rounded = True
......@@ -280,26 +284,28 @@ def update_execution_context(model, node, execution_context):
warnings.warn(
"The values of tensor {} can't be represented "
"with the set finn datatype ({}), they will be rounded to match the "
"finn datatype.".format(inp, dtype)
"finn datatype.".format(tensor, dtype)
)
# check if rounded values are not too far from original values
max_error = max(np.abs(current_values - updated_values).flatten())
if max_error <= 1e-4:
# check again if values can now be represented with set finn datatype
for value in np.nditer(updated_values):
if not dtype.allowed(value):
raise Exception(
"""Values can't be represented with set
finn datatype ({}) for input {}""".format(
dtype, inp
if check_values is True:
# check again if values can now be represented with set finn datatype
# TODO: vectorize with numpy
for value in np.nditer(updated_values):
if not dtype.allowed(value):
raise Exception(
"""Values can't be represented with set
finn datatype ({}) for input {}""".format(
dtype, tensor
)
)
)
execution_context[inp] = updated_values
execution_context[tensor] = updated_values
else:
raise Exception(
"""Values can't be rounded to match set finn
"""Rounding error is too high to match set finn
datatype ({}) for input {}""".format(
dtype, inp
dtype, tensor
)
)
return execution_context
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment