diff --git a/src/finn/util/basic.py b/src/finn/util/basic.py index 08d92b3881911b9feb308c76f0bbde4ada70e115..c4154aad41cc33e3176b24293676d33168bca2bc 100644 --- a/src/finn/util/basic.py +++ b/src/finn/util/basic.py @@ -265,12 +265,16 @@ def calculate_signed_dot_prod_range(dt_a, dt_b, len): return (min_prod, max_prod) -def update_execution_context(model, node, execution_context): - for inp in node.input: - dtype = model.get_tensor_datatype(inp) - current_values = execution_context[inp] +def validate_quant_values(model, node_tensors, execution_context, check_values=False): + for tensor in node_tensors: + dtype = model.get_tensor_datatype(tensor) + # if FLOAT32 skip to next tensor + if dtype == DataType.FLOAT32: + break + current_values = execution_context[tensor] updated_values = current_values has_to_be_rounded = False + # TODO: vectorize with numpy for value in np.nditer(current_values): if not dtype.allowed(value): has_to_be_rounded = True @@ -280,26 +284,28 @@ def update_execution_context(model, node, execution_context): warnings.warn( "The values of tensor {} can't be represented " "with the set finn datatype ({}), they will be rounded to match the " - "finn datatype.".format(inp, dtype) + "finn datatype.".format(tensor, dtype) ) # check if rounded values are not too far from original values max_error = max(np.abs(current_values - updated_values).flatten()) if max_error <= 1e-4: - # check again if values can now be represented with set finn datatype - for value in np.nditer(updated_values): - if not dtype.allowed(value): - raise Exception( - """Values can't be represented with set - finn datatype ({}) for input {}""".format( - dtype, inp + if check_values is True: + # check again if values can now be represented with set finn datatype + # TODO: vectorize with numpy + for value in np.nditer(updated_values): + if not dtype.allowed(value): + raise Exception( + """Values can't be represented with set + finn datatype ({}) for input {}""".format( + dtype, tensor + ) ) - ) - execution_context[inp] = updated_values + execution_context[tensor] = updated_values else: raise Exception( - """Values can't be rounded to match set finn + """Rounding error is too high to match set finn datatype ({}) for input {}""".format( - dtype, inp + dtype, tensor ) ) return execution_context