From 6cbc7b04d5916bd66bae3c7365ad7c9af7798005 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Thu, 27 Aug 2020 16:01:11 +0200
Subject: [PATCH] [Debug] move to new Brevitas debug infra, remove hook from
 FINN

---
 src/finn/util/pytorch.py              | 15 ---------------
 tests/brevitas/test_brevitas_debug.py | 12 ++++++------
 2 files changed, 6 insertions(+), 21 deletions(-)

diff --git a/src/finn/util/pytorch.py b/src/finn/util/pytorch.py
index 8332757ca..f174c2460 100644
--- a/src/finn/util/pytorch.py
+++ b/src/finn/util/pytorch.py
@@ -28,7 +28,6 @@
 import torch
 
 from torch.nn import Module, Sequential
-from brevitas.quant_tensor import QuantTensor
 
 
 class Normalize(Module):
@@ -65,17 +64,3 @@ class NormalizePreProc(Module):
 
     def forward(self, x):
         return self.features(x)
-
-
-class BrevitasDebugHook:
-    def __init__(self):
-        self.outputs = {}
-
-    def __call__(self, module, module_in, module_out):
-        tensor = module_out
-        if isinstance(module_out, QuantTensor):
-            tensor = module_out[0]
-        self.outputs[module.export_debug_name] = tensor.detach().numpy()
-
-    def clear(self):
-        self.outputs = {}
diff --git a/tests/brevitas/test_brevitas_debug.py b/tests/brevitas/test_brevitas_debug.py
index cb7bb5a16..7d2a9b6e5 100644
--- a/tests/brevitas/test_brevitas_debug.py
+++ b/tests/brevitas/test_brevitas_debug.py
@@ -41,14 +41,12 @@ from finn.transformation.fold_constants import FoldConstants
 from finn.transformation.general import RemoveStaticGraphInputs
 from finn.transformation.infer_shapes import InferShapes
 from finn.util.test import get_test_model_trained
-from finn.util.pytorch import BrevitasDebugHook
 
 
 def test_brevitas_debug():
     finn_onnx = "test_brevitas_debug.onnx"
     fc = get_test_model_trained("TFC", 2, 2)
-    dbg_hook = BrevitasDebugHook()
-    bo.enable_debug(fc, dbg_hook)
+    dbg_hook = bo.enable_debug(fc)
     bo.export_finn_onnx(fc, (1, 1, 28, 28), finn_onnx)
     model = ModelWrapper(finn_onnx)
     model = model.transform(InferShapes())
@@ -70,10 +68,12 @@ def test_brevitas_debug():
     expected = fc.forward(input_tensor).detach().numpy()
     assert np.isclose(produced, expected, atol=1e-3).all()
     # check all tensors at debug markers
-    names_brevitas = set(dbg_hook.outputs.keys())
+    names_brevitas = set(dbg_hook.values.keys())
     names_finn = set(output_dict.keys())
     names_common = names_brevitas.intersection(names_finn)
-    assert len(names_common) == 8
+    assert len(names_common) == 16
     for dbg_name in names_common:
-        assert (dbg_hook.outputs[dbg_name] == output_dict[dbg_name]).all()
+        tensor_pytorch = dbg_hook.values[dbg_name].detach().numpy()
+        tensor_finn = output_dict[dbg_name]
+        assert np.isclose(tensor_finn, tensor_pytorch).all()
     os.remove(finn_onnx)
-- 
GitLab