diff --git a/tests/brevitas/test_brevitas_act_export.py b/tests/brevitas/test_brevitas_non_scaled_QuantHardTanh_export.py
similarity index 94%
rename from tests/brevitas/test_brevitas_act_export.py
rename to tests/brevitas/test_brevitas_non_scaled_QuantHardTanh_export.py
index e1cfc9db98a9c9d746b6f66ee071ddfb85cc5dbb..b66348a9902802bc65b2a35e8bc3e311cc81e0bc 100644
--- a/tests/brevitas/test_brevitas_act_export.py
+++ b/tests/brevitas/test_brevitas_non_scaled_QuantHardTanh_export.py
@@ -1,15 +1,16 @@
 import os
+import onnx  # noqa
 import numpy as np
 import torch
 import brevitas.onnx as bo
 from brevitas.nn import QuantHardTanh
 from brevitas.core.restrict_val import RestrictValueType
 from brevitas.core.quant import QuantType
 from brevitas.core.scaling import ScalingImplType
 import pytest
 from finn.core.modelwrapper import ModelWrapper
 import finn.core.onnx_exec as oxe
 from finn.transformation.infer_shapes import InferShapes
 
 export_onnx_path = "test_act.onnx"
 
@@ -17,7 +18,8 @@ export_onnx_path = "test_act.onnx"
 @pytest.mark.parametrize("abits", [1, 2, 4, 8])
 @pytest.mark.parametrize("narrow_range", [False, True])
 @pytest.mark.parametrize("max_val", [1.0, 1 - 2 ** (-7)])
-def test_brevitas_act_export(abits, narrow_range, max_val):
+def test_brevitas_act_export_qhardtanh_nonscaled(abits, narrow_range, max_val):
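+    """Export test for QuantHardTanh without activation scaling."""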
     def get_quant_type(bit_width):
         if bit_width is None:
             return QuantType.FP
diff --git a/tests/brevitas/test_brevitas_relu_act_export.py b/tests/brevitas/test_brevitas_relu_act_export.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9d8f2d812bc7bea1a2fd2598a7711099ad421e6
--- /dev/null
+++ b/tests/brevitas/test_brevitas_relu_act_export.py
@@ -0,0 +1,71 @@
+import os
+import onnx  # noqa
+import numpy as np
+import torch
+import brevitas.onnx as bo
+from brevitas.nn import QuantReLU
+from brevitas.core.quant import QuantType
+from brevitas.core.restrict_val import RestrictValueType
+from brevitas.core.scaling import ScalingImplType
+import pytest
+from finn.core.modelwrapper import ModelWrapper
+import finn.core.onnx_exec as oxe
+from finn.transformation.infer_shapes import InferShapes
+
+export_onnx_path = "test_act.onnx"
+
+
+@pytest.mark.parametrize("abits", [1, 2, 4, 8])
+@pytest.mark.parametrize("max_val", [1.0, 1.5, 1 - 2 ** (-7)])
+@pytest.mark.parametrize(
+    "scaling_impl_type", [ScalingImplType.CONST, ScalingImplType.PARAMETER]
+)
+def test_brevitas_act_export_relu(abits, max_val, scaling_impl_type):
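+    """Export a QuantReLU to FINN-ONNX and check that executing the
+    exported graph reproduces the Brevitas forward pass."""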
+    min_val = -1.0
+    ishape = (1, 15)
+    b_act = QuantReLU(
+        bit_width=abits,
+        max_val=max_val,
+        scaling_impl_type=scaling_impl_type,
+        restrict_scaling_type=RestrictValueType.LOG_FP,
+        quant_type=QuantType.INT,
+    )
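+    # PARAMETER scaling learns the scale during training, so preload a
+    # fixed value to make export and reference computation deterministic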
+    if scaling_impl_type == ScalingImplType.PARAMETER:
+        checkpoint = {
+            "act_quant_proxy.fused_activation_quant_proxy.tensor_quant."
+            "scaling_impl.learned_value": torch.tensor(0.49).type(torch.FloatTensor)
+        }
+        b_act.load_state_dict(checkpoint)
+
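+    # export to FINN-ONNX and execute the exported graph on a random input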
+    bo.export_finn_onnx(b_act, ishape, export_onnx_path)
+    model = ModelWrapper(export_onnx_path)
+    model = model.transform(InferShapes())
+    inp_tensor = np.random.uniform(low=min_val, high=max_val, size=ishape).astype(
+        np.float32
+    )
+    idict = {model.graph.input[0].name: inp_tensor}
+    odict = oxe.execute_onnx(model, idict, True)
+    produced = odict[model.graph.output[0].name]
+    inp_tensor = torch.from_numpy(inp_tensor).float()
+    b_act.eval()
+    expected = b_act.forward(inp_tensor).detach().numpy()
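+    # if outputs diverge, print debug info before the assert below fires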
+    if not np.isclose(produced, expected, atol=1e-3).all():
+        print(abits, max_val, scaling_impl_type)
+        print("scale: ", b_act.quant_act_scale().type(torch.FloatTensor).detach())
+        if abits < 5:
+            print(
+                "thres:",
+                ", ".join(["{:8.4f}".format(x) for x in b_act.export_thres[0]]),
+            )
+        print("input:", ", ".join(["{:8.4f}".format(x) for x in inp_tensor[0]]))
+        print("prod :", ", ".join(["{:8.4f}".format(x) for x in produced[0]]))
+        print("expec:", ", ".join(["{:8.4f}".format(x) for x in expected[0]]))
+
+    assert np.isclose(produced, expected, atol=1e-3).all()
+    os.remove(export_onnx_path)
diff --git a/tests/brevitas/test_brevitas_scaled_QHardTanh_export.py b/tests/brevitas/test_brevitas_scaled_QHardTanh_export.py
new file mode 100644
index 0000000000000000000000000000000000000000..d499f1517341477eca9915245da9ad12c346c5a9
--- /dev/null
+++ b/tests/brevitas/test_brevitas_scaled_QHardTanh_export.py
@@ -0,0 +1,95 @@
+import os
+import onnx  # noqa
+import numpy as np
+import torch
+import brevitas.onnx as bo
+from brevitas.nn import QuantHardTanh
+from brevitas.core.restrict_val import RestrictValueType
+from brevitas.core.quant import QuantType
+from brevitas.core.scaling import ScalingImplType
+import pytest
+from finn.core.modelwrapper import ModelWrapper
+import finn.core.onnx_exec as oxe
+from finn.transformation.infer_shapes import InferShapes
+
+export_onnx_path = "test_act.onnx"
+
+
+@pytest.mark.parametrize("abits", [2, 4, 8])
+@pytest.mark.parametrize("narrow_range", [False, True])
+@pytest.mark.parametrize("min_val", [-1.0, -(1 - 2 ** (-7)), -2])
+@pytest.mark.parametrize("max_val", [1.0, 1 - 2 ** (-7), 2])
+@pytest.mark.parametrize(
+    "scaling_impl_type", [ScalingImplType.CONST, ScalingImplType.PARAMETER]
+)
+def test_brevitas_act_export_qhardtanh_scaled(
+    abits, narrow_range, min_val, max_val, scaling_impl_type
+):
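+    """Export a scaled QuantHardTanh to FINN-ONNX and check that executing
+    the exported graph reproduces the Brevitas forward pass."""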
+    def get_quant_type(bit_width):
+        if bit_width is None:
+            return QuantType.FP
+        elif bit_width == 1:
+            return QuantType.BINARY
+        else:
+            return QuantType.INT
+
+    act_quant_type = get_quant_type(abits)
+    ishape = (1, 15)
+    b_act = QuantHardTanh(
+        bit_width=abits,
+        quant_type=act_quant_type,
+        max_val=max_val,
+        min_val=min_val,
+        restrict_scaling_type=RestrictValueType.LOG_FP,
+        scaling_impl_type=scaling_impl_type,
+        narrow_range=narrow_range,
+    )
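+    # when the scale is a learned PARAMETER, preload a known value so
+    # the exported scale matches the reference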
+    if scaling_impl_type == ScalingImplType.PARAMETER:
+        checkpoint = {
+            "act_quant_proxy.fused_activation_quant_proxy.tensor_quant."
+            "scaling_impl.learned_value": torch.tensor(0.49).type(torch.FloatTensor)
+        }
+        b_act.load_state_dict(checkpoint)
+
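+    # round-trip: export to FINN-ONNX, then execute the graph with onnx_exec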
+    bo.export_finn_onnx(b_act, ishape, export_onnx_path)
+    model = ModelWrapper(export_onnx_path)
+    model = model.transform(InferShapes())
+    inp_tensor = np.random.uniform(low=min_val, high=max_val, size=ishape).astype(
+        np.float32
+    )
+    idict = {model.graph.input[0].name: inp_tensor}
+    odict = oxe.execute_onnx(model, idict, True)
+    produced = odict[model.graph.output[0].name]
+    inp_tensor = torch.from_numpy(inp_tensor).float()
+    b_act.eval()
+    expected = b_act.forward(inp_tensor).detach().numpy()
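+    # print scales and thresholds to help debug a mismatch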
+    if not np.isclose(produced, expected, atol=1e-3).all():
+        print(
+            "abits: ",
+            abits,
+            " | narrow_range: ",
+            narrow_range,
+            " | min_val: ",
+            min_val,
+            " | max_val: ",
+            max_val,
+        )
+        print("layer scale: ", b_act.quant_act_scale().type(torch.FloatTensor).detach())
+        print("export scale: ", b_act.export_act_scale)
+        if abits < 5:
+            print(
+                "thres:",
+                ", ".join(["{:8.4f}".format(x) for x in b_act.export_thres[0]]),
+            )
+        print("input:", ", ".join(["{:8.4f}".format(x) for x in inp_tensor[0]]))
+        print("prod :", ", ".join(["{:8.4f}".format(x) for x in produced[0]]))
+        print("expec:", ", ".join(["{:8.4f}".format(x) for x in expected[0]]))
+
+    assert np.isclose(produced, expected, atol=1e-3).all()
+    os.remove(export_onnx_path)