From 6fe1f90ee77a2cbc23087a9dd451440ac09b876e Mon Sep 17 00:00:00 2001 From: auphelia <jakobapk@web.de> Date: Mon, 15 Jun 2020 10:53:03 +0100 Subject: [PATCH] [Util] Update fct to validate quant values (update execution context) --- src/finn/util/basic.py | 38 ++++++++++++++++++++++---------------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/src/finn/util/basic.py b/src/finn/util/basic.py index 08d92b388..c4154aad4 100644 --- a/src/finn/util/basic.py +++ b/src/finn/util/basic.py @@ -265,12 +265,16 @@ def calculate_signed_dot_prod_range(dt_a, dt_b, len): return (min_prod, max_prod) -def update_execution_context(model, node, execution_context): - for inp in node.input: - dtype = model.get_tensor_datatype(inp) - current_values = execution_context[inp] +def validate_quant_values(model, node_tensors, execution_context, check_values=False): + for tensor in node_tensors: + dtype = model.get_tensor_datatype(tensor) + # if FLOAT32 skip to next tensor + if dtype == DataType.FLOAT32: + break + current_values = execution_context[tensor] updated_values = current_values has_to_be_rounded = False + # TODO: vectorize with numpy for value in np.nditer(current_values): if not dtype.allowed(value): has_to_be_rounded = True @@ -280,26 +284,28 @@ def update_execution_context(model, node, execution_context): warnings.warn( "The values of tensor {} can't be represented " "with the set finn datatype ({}), they will be rounded to match the " - "finn datatype.".format(inp, dtype) + "finn datatype.".format(tensor, dtype) ) # check if rounded values are not too far from original values max_error = max(np.abs(current_values - updated_values).flatten()) if max_error <= 1e-4: - # check again if values can now be represented with set finn datatype - for value in np.nditer(updated_values): - if not dtype.allowed(value): - raise Exception( - """Values can't be represented with set - finn datatype ({}) for input {}""".format( - dtype, inp + if check_values is True: + # check again if values can now be represented with set finn datatype + # TODO: vectorize with numpy + for value in np.nditer(updated_values): + if not dtype.allowed(value): + raise Exception( + """Values can't be represented with set + finn datatype ({}) for input {}""".format( + dtype, tensor + ) ) - ) - execution_context[inp] = updated_values + execution_context[tensor] = updated_values else: raise Exception( - """Values can't be rounded to match set finn + """Rounding error is too high to match set finn datatype ({}) for input {}""".format( - dtype, inp + dtype, tensor ) ) return execution_context -- GitLab