diff --git a/tests/brevitas/test_brevitas_relu_act_export.py b/tests/brevitas/test_brevitas_relu_act_export.py
index c9d8f2d812bc7bea1a2fd2598a7711099ad421e6..eb56abc28ed0cdd104f2826b24668ec0cd166313 100644
--- a/tests/brevitas/test_brevitas_relu_act_export.py
+++ b/tests/brevitas/test_brevitas_relu_act_export.py
@@ -23,6 +23,7 @@ export_onnx_path = "test_act.onnx"
 def test_brevitas_act_export_relu(abits, max_val, scaling_impl_type):
     min_val = -1.0
     ishape = (1, 15)
+
     b_act = QuantReLU(
         bit_width=abits,
         max_val=max_val,
@@ -67,3 +68,66 @@ scaling_impl.learned_value": torch.tensor(
 
     assert np.isclose(produced, expected, atol=1e-3).all()
     os.remove(export_onnx_path)
+
+
+@pytest.mark.parametrize("abits", [4])
+@pytest.mark.parametrize("out_channels", [32])
+def test_brevitas_act_export_relu_imagenet(abits, out_channels):
+    ishape = (32, 15)
+
+    b_act = QuantReLU(
+        bit_width=abits,
+        quant_type=QuantType.INT,
+        scaling_impl_type=ScalingImplType.PARAMETER,
+        scaling_per_channel=True,
+        restrict_scaling_type=RestrictValueType.LOG_FP,
+        scaling_min_val=2e-16,
+        max_val=6.0,
+        return_quant_tensor=False,
+        per_channel_broadcastable_shape=(1, out_channels, 1, 1),
+    )
+    checkpoint = {
+        "act_quant_proxy.fused_activation_quant_proxy.tensor_quant.\
+scaling_impl.learned_value": torch.tensor(
+            [
+                [
+                    [[0.8520]],
+                    [[0.7784]],
+                    [[0.8643]],
+                    [[1.0945]],
+                    [[0.8520]],
+                    [[0.7239]],
+                    [[0.8520]],
+                    [[1.0609]],
+                    [[1.0299]],
+                    [[0.0121]],
+                    [[0.8520]],
+                    [[0.8489]],
+                    [[0.1376]],
+                    [[0.8520]],
+                    [[0.7096]],
+                    [[0.8520]],
+                    [[0.8520]],
+                    [[0.8608]],
+                    [[0.9484]],
+                    [[0.8520]],
+                    [[0.7237]],
+                    [[0.5425]],
+                    [[0.5774]],
+                    [[0.5975]],
+                    [[0.6685]],
+                    [[0.4472]],
+                    [[0.8520]],
+                    [[0.7879]],
+                    [[0.8520]],
+                    [[1.0322]],
+                    [[0.4550]],
+                    [[1.0612]],
+                ]
+            ]
+        ).type(
+            torch.FloatTensor
+        )
+    }
+    b_act.load_state_dict(checkpoint)
+    bo.export_finn_onnx(b_act, ishape, export_onnx_path)