From 45726278bc5dd88acc953d7c56fc332108b56d04 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <maltanar@gmail.com> Date: Fri, 11 Sep 2020 10:21:19 +0200 Subject: [PATCH] [Test] fix debug test w Double2SingleFloat, higher atol --- tests/brevitas/test_brevitas_debug.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/brevitas/test_brevitas_debug.py b/tests/brevitas/test_brevitas_debug.py index 7d2a9b6e5..84fcf4a38 100644 --- a/tests/brevitas/test_brevitas_debug.py +++ b/tests/brevitas/test_brevitas_debug.py @@ -41,6 +41,7 @@ from finn.transformation.fold_constants import FoldConstants from finn.transformation.general import RemoveStaticGraphInputs from finn.transformation.infer_shapes import InferShapes from finn.util.test import get_test_model_trained +from finn.transformation.double_to_single_float import DoubleToSingleFloat def test_brevitas_debug(): @@ -49,6 +50,7 @@ def test_brevitas_debug(): dbg_hook = bo.enable_debug(fc) bo.export_finn_onnx(fc, (1, 1, 28, 28), finn_onnx) model = ModelWrapper(finn_onnx) + model = model.transform(DoubleToSingleFloat()) model = model.transform(InferShapes()) model = model.transform(FoldConstants()) model = model.transform(RemoveStaticGraphInputs()) @@ -75,5 +77,5 @@ def test_brevitas_debug(): for dbg_name in names_common: tensor_pytorch = dbg_hook.values[dbg_name].detach().numpy() tensor_finn = output_dict[dbg_name] - assert np.isclose(tensor_finn, tensor_pytorch).all() + assert np.isclose(tensor_finn, tensor_pytorch, atol=1e-5).all() os.remove(finn_onnx) -- GitLab