Skip to content
Snippets Groups Projects
Commit 6cbc7b04 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Debug] move to new Brevitas debug infra, remove hook from FINN

parent ec4a56e9
No related branches found
No related tags found
No related merge requests found
......@@ -28,7 +28,6 @@
import torch
from torch.nn import Module, Sequential
from brevitas.quant_tensor import QuantTensor
class Normalize(Module):
......@@ -65,17 +64,3 @@ class NormalizePreProc(Module):
def forward(self, x):
return self.features(x)
class BrevitasDebugHook:
def __init__(self):
self.outputs = {}
def __call__(self, module, module_in, module_out):
tensor = module_out
if isinstance(module_out, QuantTensor):
tensor = module_out[0]
self.outputs[module.export_debug_name] = tensor.detach().numpy()
def clear(self):
self.outputs = {}
......@@ -41,14 +41,12 @@ from finn.transformation.fold_constants import FoldConstants
from finn.transformation.general import RemoveStaticGraphInputs
from finn.transformation.infer_shapes import InferShapes
from finn.util.test import get_test_model_trained
from finn.util.pytorch import BrevitasDebugHook
def test_brevitas_debug():
finn_onnx = "test_brevitas_debug.onnx"
fc = get_test_model_trained("TFC", 2, 2)
dbg_hook = BrevitasDebugHook()
bo.enable_debug(fc, dbg_hook)
dbg_hook = bo.enable_debug(fc)
bo.export_finn_onnx(fc, (1, 1, 28, 28), finn_onnx)
model = ModelWrapper(finn_onnx)
model = model.transform(InferShapes())
......@@ -70,10 +68,12 @@ def test_brevitas_debug():
expected = fc.forward(input_tensor).detach().numpy()
assert np.isclose(produced, expected, atol=1e-3).all()
# check all tensors at debug markers
names_brevitas = set(dbg_hook.outputs.keys())
names_brevitas = set(dbg_hook.values.keys())
names_finn = set(output_dict.keys())
names_common = names_brevitas.intersection(names_finn)
assert len(names_common) == 8
assert len(names_common) == 16
for dbg_name in names_common:
assert (dbg_hook.outputs[dbg_name] == output_dict[dbg_name]).all()
tensor_pytorch = dbg_hook.values[dbg_name].detach().numpy()
tensor_finn = output_dict[dbg_name]
assert np.isclose(tensor_finn, tensor_pytorch).all()
os.remove(finn_onnx)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment