diff --git a/tests/brevitas/test_brevitas_debug.py b/tests/brevitas/test_brevitas_debug.py index 7d2a9b6e5e01acfc218b6d8b6d0a8a0e73d7897d..84fcf4a3827c9f5576a8292ef4719b3b0fb1dfe0 100644 --- a/tests/brevitas/test_brevitas_debug.py +++ b/tests/brevitas/test_brevitas_debug.py @@ -41,6 +41,7 @@ from finn.transformation.fold_constants import FoldConstants from finn.transformation.general import RemoveStaticGraphInputs from finn.transformation.infer_shapes import InferShapes from finn.util.test import get_test_model_trained +from finn.transformation.double_to_single_float import DoubleToSingleFloat def test_brevitas_debug(): @@ -49,6 +50,7 @@ def test_brevitas_debug(): dbg_hook = bo.enable_debug(fc) bo.export_finn_onnx(fc, (1, 1, 28, 28), finn_onnx) model = ModelWrapper(finn_onnx) + model = model.transform(DoubleToSingleFloat()) model = model.transform(InferShapes()) model = model.transform(FoldConstants()) model = model.transform(RemoveStaticGraphInputs()) @@ -75,5 +77,5 @@ def test_brevitas_debug(): for dbg_name in names_common: tensor_pytorch = dbg_hook.values[dbg_name].detach().numpy() tensor_finn = output_dict[dbg_name] - assert np.isclose(tensor_finn, tensor_pytorch).all() + assert np.isclose(tensor_finn, tensor_pytorch, atol=1e-5).all() os.remove(finn_onnx)