diff --git a/docker/finn_entrypoint.sh b/docker/finn_entrypoint.sh
index 003cd515b4f238fcf4835eaeb545e1de3b5db4f0..6ca0e7164e82b62719d52097e8c6f7960d341ccb 100644
--- a/docker/finn_entrypoint.sh
+++ b/docker/finn_entrypoint.sh
@@ -12,7 +12,7 @@ gecho () {
 
 # checkout the correct dependency repo commits
 # the repos themselves are cloned in the Dockerfile
-BREVITAS_COMMIT=172e423164402a07826877fa9730063bee10a208
+BREVITAS_COMMIT=85e28ac2e6570e91216d042212a7d1a28ec6e394
 CNPY_COMMIT=4e8810b1a8637695171ed346ce68f6984e585ef4
 HLSLIB_COMMIT=cfafe11a93b79ab1af7529d68f08886913a6466e
 PYVERILATOR_COMMIT=c97a5ba41bbc7c419d6f25c74cdf3bdc3393174f
diff --git a/src/finn/util/pytorch.py b/src/finn/util/pytorch.py
index 8332757cab839c8ce2fe7afa2449da5782d1aea3..f174c24601578cf827cb0da770f29889344e62b8 100644
--- a/src/finn/util/pytorch.py
+++ b/src/finn/util/pytorch.py
@@ -28,7 +28,6 @@
 import torch
 
 from torch.nn import Module, Sequential
-from brevitas.quant_tensor import QuantTensor
 
 
 class Normalize(Module):
@@ -65,17 +64,3 @@ class NormalizePreProc(Module):
 
     def forward(self, x):
         return self.features(x)
-
-
-class BrevitasDebugHook:
-    def __init__(self):
-        self.outputs = {}
-
-    def __call__(self, module, module_in, module_out):
-        tensor = module_out
-        if isinstance(module_out, QuantTensor):
-            tensor = module_out[0]
-        self.outputs[module.export_debug_name] = tensor.detach().numpy()
-
-    def clear(self):
-        self.outputs = {}
diff --git a/tests/brevitas/test_brevitas_avg_pool_export.py b/tests/brevitas/test_brevitas_avg_pool_export.py
index e78812b21a03baa6963f1f0efaefdb4c73e4d0db..9112ae7ef026c637e7e5a6375f8e9d17b918d63b 100644
--- a/tests/brevitas/test_brevitas_avg_pool_export.py
+++ b/tests/brevitas/test_brevitas_avg_pool_export.py
@@ -45,7 +45,7 @@ def test_brevitas_avg_pool_export(
     scale = np.ones((1, channels, 1, 1))
     output_scale = torch.from_numpy(scale).float()
     input_quant_tensor = pack_quant_tensor(
-        tensor=input_tensor, scale=output_scale, bit_width=ibw_tensor
+        tensor=input_tensor, scale=output_scale, bit_width=ibw_tensor, signed=signed
     )
     bo.export_finn_onnx(b_avgpool, ishape, export_onnx_path, input_t=input_quant_tensor)
     model = ModelWrapper(export_onnx_path)
@@ -65,7 +65,7 @@ def test_brevitas_avg_pool_export(
     inp = gen_finn_dt_tensor(dtype, ishape)
     input_tensor = torch.from_numpy(inp).float()
     input_quant_tensor = pack_quant_tensor(
-        tensor=input_tensor, scale=output_scale, bit_width=ibw_tensor
+        tensor=input_tensor, scale=output_scale, bit_width=ibw_tensor, signed=signed
     )
     b_avgpool.eval()
     expected = b_avgpool.forward(input_quant_tensor).tensor.detach().numpy()
@@ -84,7 +84,7 @@ def test_brevitas_avg_pool_export(
     input_tensor = torch.from_numpy(inp_tensor).float()
     input_scale = torch.from_numpy(scale).float()
     input_quant_tensor = pack_quant_tensor(
-        tensor=input_tensor, scale=input_scale, bit_width=ibw_tensor
+        tensor=input_tensor, scale=input_scale, bit_width=ibw_tensor, signed=signed
     )
     # export again to set the scale values correctly
     bo.export_finn_onnx(b_avgpool, ishape, export_onnx_path, input_t=input_quant_tensor)
diff --git a/tests/brevitas/test_brevitas_debug.py b/tests/brevitas/test_brevitas_debug.py
index cb7bb5a16a76e37275ab267c7bf90a4409a8769d..7d2a9b6e5e01acfc218b6d8b6d0a8a0e73d7897d 100644
--- a/tests/brevitas/test_brevitas_debug.py
+++ b/tests/brevitas/test_brevitas_debug.py
@@ -41,14 +41,12 @@ from finn.transformation.fold_constants import FoldConstants
 from finn.transformation.general import RemoveStaticGraphInputs
 from finn.transformation.infer_shapes import InferShapes
 from finn.util.test import get_test_model_trained
-from finn.util.pytorch import BrevitasDebugHook
 
 
 def test_brevitas_debug():
     finn_onnx = "test_brevitas_debug.onnx"
     fc = get_test_model_trained("TFC", 2, 2)
-    dbg_hook = BrevitasDebugHook()
-    bo.enable_debug(fc, dbg_hook)
+    dbg_hook = bo.enable_debug(fc)
     bo.export_finn_onnx(fc, (1, 1, 28, 28), finn_onnx)
     model = ModelWrapper(finn_onnx)
     model = model.transform(InferShapes())
@@ -70,10 +68,12 @@ def test_brevitas_debug():
     expected = fc.forward(input_tensor).detach().numpy()
     assert np.isclose(produced, expected, atol=1e-3).all()
     # check all tensors at debug markers
-    names_brevitas = set(dbg_hook.outputs.keys())
+    names_brevitas = set(dbg_hook.values.keys())
     names_finn = set(output_dict.keys())
     names_common = names_brevitas.intersection(names_finn)
-    assert len(names_common) == 8
+    assert len(names_common) == 16
     for dbg_name in names_common:
-        assert (dbg_hook.outputs[dbg_name] == output_dict[dbg_name]).all()
+        tensor_pytorch = dbg_hook.values[dbg_name].detach().numpy()
+        tensor_finn = output_dict[dbg_name]
+        assert np.isclose(tensor_finn, tensor_pytorch).all()
     os.remove(finn_onnx)
diff --git a/tests/brevitas/test_brevitas_relu_act_export.py b/tests/brevitas/test_brevitas_relu_act_export.py
index 77974dacb51aa8746ce33f9a490becd35390db5a..fa114585d31fca629aa759e386aa3fbd04280a2e 100644
--- a/tests/brevitas/test_brevitas_relu_act_export.py
+++ b/tests/brevitas/test_brevitas_relu_act_export.py
@@ -15,7 +15,7 @@ from finn.transformation.infer_shapes import InferShapes
 export_onnx_path = "test_brevitas_relu_act_export.onnx"
 
 
-@pytest.mark.parametrize("abits", [1, 2, 4, 8])
+@pytest.mark.parametrize("abits", [2, 4, 8])
 @pytest.mark.parametrize("max_val", [1.0, 1.5, 1 - 2 ** (-7)])
 @pytest.mark.parametrize(
     "scaling_impl_type", [ScalingImplType.CONST, ScalingImplType.PARAMETER]
@@ -70,7 +70,7 @@ scaling_impl.learned_value": torch.tensor(
     os.remove(export_onnx_path)
 
 
-@pytest.mark.parametrize("abits", [1, 2, 4, 8])
+@pytest.mark.parametrize("abits", [2, 4, 8])
 @pytest.mark.parametrize("max_val", [1.0, 1.5, 1 - 2 ** (-7)])
 @pytest.mark.parametrize("scaling_per_channel", [True, False])
 def test_brevitas_act_export_relu_imagenet(abits, max_val, scaling_per_channel):