From 6cbc7b04d5916bd66bae3c7365ad7c9af7798005 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <maltanar@gmail.com> Date: Thu, 27 Aug 2020 16:01:11 +0200 Subject: [PATCH] [Debug] move to new Brevitas debug infra, remove hook from FINN --- src/finn/util/pytorch.py | 15 --------------- tests/brevitas/test_brevitas_debug.py | 12 ++++++------ 2 files changed, 6 insertions(+), 21 deletions(-) diff --git a/src/finn/util/pytorch.py b/src/finn/util/pytorch.py index 8332757ca..f174c2460 100644 --- a/src/finn/util/pytorch.py +++ b/src/finn/util/pytorch.py @@ -28,7 +28,6 @@ import torch from torch.nn import Module, Sequential -from brevitas.quant_tensor import QuantTensor class Normalize(Module): @@ -65,17 +64,3 @@ class NormalizePreProc(Module): def forward(self, x): return self.features(x) - - -class BrevitasDebugHook: - def __init__(self): - self.outputs = {} - - def __call__(self, module, module_in, module_out): - tensor = module_out - if isinstance(module_out, QuantTensor): - tensor = module_out[0] - self.outputs[module.export_debug_name] = tensor.detach().numpy() - - def clear(self): - self.outputs = {} diff --git a/tests/brevitas/test_brevitas_debug.py b/tests/brevitas/test_brevitas_debug.py index cb7bb5a16..7d2a9b6e5 100644 --- a/tests/brevitas/test_brevitas_debug.py +++ b/tests/brevitas/test_brevitas_debug.py @@ -41,14 +41,12 @@ from finn.transformation.fold_constants import FoldConstants from finn.transformation.general import RemoveStaticGraphInputs from finn.transformation.infer_shapes import InferShapes from finn.util.test import get_test_model_trained -from finn.util.pytorch import BrevitasDebugHook def test_brevitas_debug(): finn_onnx = "test_brevitas_debug.onnx" fc = get_test_model_trained("TFC", 2, 2) - dbg_hook = BrevitasDebugHook() - bo.enable_debug(fc, dbg_hook) + dbg_hook = bo.enable_debug(fc) bo.export_finn_onnx(fc, (1, 1, 28, 28), finn_onnx) model = ModelWrapper(finn_onnx) model = model.transform(InferShapes()) @@ -70,10 +68,12 @@ def test_brevitas_debug(): expected = fc.forward(input_tensor).detach().numpy() assert np.isclose(produced, expected, atol=1e-3).all() # check all tensors at debug markers - names_brevitas = set(dbg_hook.outputs.keys()) + names_brevitas = set(dbg_hook.values.keys()) names_finn = set(output_dict.keys()) names_common = names_brevitas.intersection(names_finn) - assert len(names_common) == 8 + assert len(names_common) == 16 for dbg_name in names_common: - assert (dbg_hook.outputs[dbg_name] == output_dict[dbg_name]).all() + tensor_pytorch = dbg_hook.values[dbg_name].detach().numpy() + tensor_finn = output_dict[dbg_name] + assert np.isclose(tensor_finn, tensor_pytorch).all() os.remove(finn_onnx) -- GitLab