Skip to content
Snippets Groups Projects
Commit 6fe1f90e authored by auphelia's avatar auphelia
Browse files

[Util] Update fct to validate quant values (update execution context)

parent 1bdab22f
No related branches found
No related tags found
No related merge requests found
...@@ -265,12 +265,16 @@ def calculate_signed_dot_prod_range(dt_a, dt_b, len): ...@@ -265,12 +265,16 @@ def calculate_signed_dot_prod_range(dt_a, dt_b, len):
return (min_prod, max_prod) return (min_prod, max_prod)
def update_execution_context(model, node, execution_context): def validate_quant_values(model, node_tensors, execution_context, check_values=False):
for inp in node.input: for tensor in node_tensors:
dtype = model.get_tensor_datatype(inp) dtype = model.get_tensor_datatype(tensor)
current_values = execution_context[inp] # if FLOAT32 skip to next tensor
if dtype == DataType.FLOAT32:
break
current_values = execution_context[tensor]
updated_values = current_values updated_values = current_values
has_to_be_rounded = False has_to_be_rounded = False
# TODO: vectorize with numpy
for value in np.nditer(current_values): for value in np.nditer(current_values):
if not dtype.allowed(value): if not dtype.allowed(value):
has_to_be_rounded = True has_to_be_rounded = True
...@@ -280,26 +284,28 @@ def update_execution_context(model, node, execution_context): ...@@ -280,26 +284,28 @@ def update_execution_context(model, node, execution_context):
warnings.warn( warnings.warn(
"The values of tensor {} can't be represented " "The values of tensor {} can't be represented "
"with the set finn datatype ({}), they will be rounded to match the " "with the set finn datatype ({}), they will be rounded to match the "
"finn datatype.".format(inp, dtype) "finn datatype.".format(tensor, dtype)
) )
# check if rounded values are not too far from original values # check if rounded values are not too far from original values
max_error = max(np.abs(current_values - updated_values).flatten()) max_error = max(np.abs(current_values - updated_values).flatten())
if max_error <= 1e-4: if max_error <= 1e-4:
# check again if values can now be represented with set finn datatype if check_values is True:
for value in np.nditer(updated_values): # check again if values can now be represented with set finn datatype
if not dtype.allowed(value): # TODO: vectorize with numpy
raise Exception( for value in np.nditer(updated_values):
"""Values can't be represented with set if not dtype.allowed(value):
finn datatype ({}) for input {}""".format( raise Exception(
dtype, inp """Values can't be represented with set
finn datatype ({}) for input {}""".format(
dtype, tensor
)
) )
) execution_context[tensor] = updated_values
execution_context[inp] = updated_values
else: else:
raise Exception( raise Exception(
"""Values can't be rounded to match set finn """Rounding error is too high to match set finn
datatype ({}) for input {}""".format( datatype ({}) for input {}""".format(
dtype, inp dtype, tensor
) )
) )
return execution_context return execution_context
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment