diff --git a/docker/finn_entrypoint.sh b/docker/finn_entrypoint.sh
index e34b6ce9cc4488c806da8bcf3cc5cc8e500ae806..65513f3148e0fed2583d02e1eba249bc9a1f2f6e 100644
--- a/docker/finn_entrypoint.sh
+++ b/docker/finn_entrypoint.sh
@@ -13,7 +13,7 @@ gecho () {
 
 # checkout the correct dependency repo commits
 # the repos themselves are cloned in the Dockerfile
-BREVITAS_COMMIT=989cdfdba4700fdd900ba0b25a820591d561c21a
+BREVITAS_COMMIT=7696326e5f279cacffd5b6ac8d9e8d81deec3978
 CNPY_COMMIT=4e8810b1a8637695171ed346ce68f6984e585ef4
 HLSLIB_COMMIT=13e9b0772a27a3a1efc40c878d8e78ed09efb716
 PYVERILATOR_COMMIT=c97a5ba41bbc7c419d6f25c74cdf3bdc3393174f
diff --git a/tests/brevitas/test_brevitas_relu_act_export.py b/tests/brevitas/test_brevitas_relu_act_export.py
index c9d8f2d812bc7bea1a2fd2598a7711099ad421e6..c5ddad12ca3e8d353682fbb20449d44358485f69 100644
--- a/tests/brevitas/test_brevitas_relu_act_export.py
+++ b/tests/brevitas/test_brevitas_relu_act_export.py
@@ -23,6 +23,7 @@ export_onnx_path = "test_act.onnx"
 def test_brevitas_act_export_relu(abits, max_val, scaling_impl_type):
     min_val = -1.0
     ishape = (1, 15)
+
     b_act = QuantReLU(
         bit_width=abits,
         max_val=max_val,
@@ -67,3 +68,77 @@ scaling_impl.learned_value": torch.tensor(
 
     assert np.isclose(produced, expected, atol=1e-3).all()
     os.remove(export_onnx_path)
+
+
+@pytest.mark.parametrize("abits", [1, 2, 4, 8])
+@pytest.mark.parametrize("max_val", [1.0, 1.5, 1 - 2 ** (-7)])
+@pytest.mark.parametrize("scaling_per_channel", [True, False])
+def test_brevitas_act_export_relu_imagenet(abits, max_val, scaling_per_channel):
+    """Check that a QuantReLU exported to FINN-ONNX matches its Brevitas output.
+
+    The activation uses a learned (PARAMETER) scale, per-channel or per-tensor,
+    and is evaluated on a random (1, 32, 1, 1) input. ``max_val`` only bounds
+    that random input; the QuantReLU itself is always built with max_val=6.0.
+    """
+    out_channels = 32
+    ishape = (1, out_channels, 1, 1)
+    min_val = -1.0
+
+    def fmt(values):
+        # Flatten first: iterating a (C, 1, 1) array yields (1, 1) sub-arrays,
+        # and "{:8.4f}".format() raises TypeError on anything but a scalar.
+        return ", ".join("{:8.4f}".format(v) for v in np.asarray(values).ravel())
+
+    b_act = QuantReLU(
+        bit_width=abits,
+        quant_type=QuantType.INT,
+        scaling_impl_type=ScalingImplType.PARAMETER,
+        scaling_per_channel=scaling_per_channel,
+        restrict_scaling_type=RestrictValueType.LOG_FP,
+        scaling_min_val=2e-16,
+        max_val=6.0,
+        return_quant_tensor=True,
+        per_channel_broadcastable_shape=(1, out_channels, 1, 1),
+    )
+    # Overwrite the learned scale through the state dict before exporting.
+    if scaling_per_channel:
+        rand_tensor = 2 * torch.rand((1, out_channels, 1, 1))
+    else:
+        rand_tensor = torch.tensor(1.2398)
+    checkpoint = {
+        "act_quant_proxy.fused_activation_quant_proxy.tensor_quant.\
+scaling_impl.learned_value": rand_tensor.type(
+            torch.FloatTensor
+        )
+    }
+    b_act.load_state_dict(checkpoint)
+    bo.export_finn_onnx(b_act, ishape, export_onnx_path)
+    model = ModelWrapper(export_onnx_path)
+    model = model.transform(InferShapes())
+    inp_np = np.random.uniform(low=min_val, high=max_val, size=ishape).astype(
+        np.float32
+    )
+    idict = {model.graph.input[0].name: inp_np}
+    odict = oxe.execute_onnx(model, idict, True)
+    produced = odict[model.graph.output[0].name]
+    inp_tensor = torch.from_numpy(inp_np).float()
+    b_act.eval()
+    expected = b_act.forward(inp_tensor).tensor.detach().numpy()
+
+    matches = np.isclose(produced, expected, atol=1e-3).all()
+    if not matches:
+        # Dump diagnostics so a mismatch can be debugged from the test log.
+        print(abits, max_val)
+        print("scale: ", b_act.quant_act_scale().type(torch.FloatTensor).detach())
+        if abits < 5:
+            # NOTE(review): assumes export_thres[0] iterates over scalars - confirm
+            print(
+                "thres:",
+                ", ".join(["{:8.4f}".format(x) for x in b_act.export_thres[0]]),
+            )
+        print("input:", fmt(inp_np[0]))
+        print("prod :", fmt(produced[0]))
+        print("expec:", fmt(expected[0]))
+
+    assert matches
+    os.remove(export_onnx_path)