From 45726278bc5dd88acc953d7c56fc332108b56d04 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Fri, 11 Sep 2020 10:21:19 +0200
Subject: [PATCH] [Test] fix debug test w Double2SingleFloat, higher atol

---
 tests/brevitas/test_brevitas_debug.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/tests/brevitas/test_brevitas_debug.py b/tests/brevitas/test_brevitas_debug.py
index 7d2a9b6e5..84fcf4a38 100644
--- a/tests/brevitas/test_brevitas_debug.py
+++ b/tests/brevitas/test_brevitas_debug.py
@@ -41,6 +41,7 @@ from finn.transformation.fold_constants import FoldConstants
 from finn.transformation.general import RemoveStaticGraphInputs
 from finn.transformation.infer_shapes import InferShapes
 from finn.util.test import get_test_model_trained
+from finn.transformation.double_to_single_float import DoubleToSingleFloat
 
 
 def test_brevitas_debug():
@@ -49,6 +50,7 @@ def test_brevitas_debug():
     dbg_hook = bo.enable_debug(fc)
     bo.export_finn_onnx(fc, (1, 1, 28, 28), finn_onnx)
     model = ModelWrapper(finn_onnx)
+    model = model.transform(DoubleToSingleFloat())
     model = model.transform(InferShapes())
     model = model.transform(FoldConstants())
     model = model.transform(RemoveStaticGraphInputs())
@@ -75,5 +77,5 @@ def test_brevitas_debug():
     for dbg_name in names_common:
         tensor_pytorch = dbg_hook.values[dbg_name].detach().numpy()
         tensor_finn = output_dict[dbg_name]
-        assert np.isclose(tensor_finn, tensor_pytorch).all()
+        assert np.isclose(tensor_finn, tensor_pytorch, atol=1e-5).all()
     os.remove(finn_onnx)
-- 
GitLab