diff --git a/src/finn/util/basic.py b/src/finn/util/basic.py index 585d6b90a59ca9f3dac56a6133de705c2f56f025..809b34157ee4b7890a4155bd09f33dcc85c6ceec 100644 --- a/src/finn/util/basic.py +++ b/src/finn/util/basic.py @@ -106,6 +106,14 @@ def get_finn_root(): ) +def get_execution_error_thresh(): + "Return the max error that is allowed for rounding in FINN execution." + try: + return float(os.environ["ERROR_THRESH"]) + except KeyError: + return 1e-2 + + def make_build_dir(prefix=""): """Creates a temporary folder with given prefix to be used as a build dir. Use this function instead of tempfile.mkdtemp to ensure any generated files @@ -305,7 +313,7 @@ def sanitize_quant_values(model, node_tensors, execution_context, check_values=F ) # check if rounded values are not too far from original values max_error = max(np.abs(current_values - updated_values).flatten()) - if max_error <= 1e-4: + if max_error <= get_execution_error_thresh(): if check_values is True: # check again if values can now be represented with set finn datatype # TODO: vectorize with numpy diff --git a/tests/brevitas/test_brevitas_avg_pool_export.py b/tests/brevitas/test_brevitas_avg_pool_export.py index 34b154c4e99def3f27041cb1cb4766be5c0f222e..24854a2153df9af78feb8352ca119e831a9ac9eb 100644 --- a/tests/brevitas/test_brevitas_avg_pool_export.py +++ b/tests/brevitas/test_brevitas_avg_pool_export.py @@ -55,10 +55,9 @@ def test_brevitas_avg_pool_export( prefix = "INT" else: prefix = "UINT" - dt_name = prefix + str(input_bit_width) + dt_name = prefix + str(input_bit_width // 2) dtype = DataType[dt_name] model = model.transform(InferShapes()) - model.set_tensor_datatype(model.graph.input[0].name, dtype) model = model.transform(InferDataTypes()) # execution with input tensor using integers and scale = 1 @@ -91,6 +90,7 @@ def test_brevitas_avg_pool_export( bo.export_finn_onnx(b_avgpool, ishape, export_onnx_path, input_t=input_quant_tensor) model = ModelWrapper(export_onnx_path) model = model.transform(InferShapes()) + model = model.transform(InferDataTypes()) b_avgpool.eval() expected = b_avgpool.forward(input_quant_tensor).tensor.detach().numpy() # finn execution