From e5dcac88c4fdd26dd64bc4b5eba269c4d7d7682b Mon Sep 17 00:00:00 2001
From: auphelia <jakobapk@web.de>
Date: Wed, 20 May 2020 13:56:26 +0100
Subject: [PATCH] [Test] Add test for QuantReLU where
 per_channel_broadcastable_shape is not None but scaling_per_channel is False

---
 .../brevitas/test_brevitas_relu_act_export.py | 50 ++++---------------
 1 file changed, 9 insertions(+), 41 deletions(-)

diff --git a/tests/brevitas/test_brevitas_relu_act_export.py b/tests/brevitas/test_brevitas_relu_act_export.py
index 25c44f0cc..c5ddad12c 100644
--- a/tests/brevitas/test_brevitas_relu_act_export.py
+++ b/tests/brevitas/test_brevitas_relu_act_export.py
@@ -70,9 +70,10 @@ scaling_impl.learned_value": torch.tensor(
     os.remove(export_onnx_path)
 
 
-@pytest.mark.parametrize("abits", [4])
+@pytest.mark.parametrize("abits", [1, 2, 4, 8])
 @pytest.mark.parametrize("max_val", [1.0, 1.5, 1 - 2 ** (-7)])
-def test_brevitas_act_export_relu_imagenet(abits, max_val):
+@pytest.mark.parametrize("scaling_per_channel", [True, False])
+def test_brevitas_act_export_relu_imagenet(abits, max_val, scaling_per_channel):
     out_channels = 32
     ishape = (1, out_channels, 1, 1)
     min_val = -1.0
@@ -80,53 +81,20 @@ def test_brevitas_act_export_relu_imagenet(abits, max_val):
         bit_width=abits,
         quant_type=QuantType.INT,
         scaling_impl_type=ScalingImplType.PARAMETER,
-        scaling_per_channel=True,
+        scaling_per_channel=scaling_per_channel,
         restrict_scaling_type=RestrictValueType.LOG_FP,
         scaling_min_val=2e-16,
         max_val=6.0,
         return_quant_tensor=True,
         per_channel_broadcastable_shape=(1, out_channels, 1, 1),
     )
+    if scaling_per_channel:
+        rand_tensor = 2 * torch.rand((1, out_channels, 1, 1))
+    else:
+        rand_tensor = torch.tensor(1.2398)
     checkpoint = {
         "act_quant_proxy.fused_activation_quant_proxy.tensor_quant.\
-scaling_impl.learned_value": torch.tensor(
-            [
-                [
-                    [[0.8520]],
-                    [[0.7784]],
-                    [[0.8643]],
-                    [[1.0945]],
-                    [[0.8520]],
-                    [[0.7239]],
-                    [[0.8520]],
-                    [[1.0609]],
-                    [[1.0299]],
-                    [[0.0121]],
-                    [[0.8520]],
-                    [[0.8489]],
-                    [[0.1376]],
-                    [[0.8520]],
-                    [[0.7096]],
-                    [[0.8520]],
-                    [[0.8520]],
-                    [[0.8608]],
-                    [[0.9484]],
-                    [[0.8520]],
-                    [[0.7237]],
-                    [[0.5425]],
-                    [[0.5774]],
-                    [[0.5975]],
-                    [[0.6685]],
-                    [[0.4472]],
-                    [[0.8520]],
-                    [[0.7879]],
-                    [[0.8520]],
-                    [[1.0322]],
-                    [[0.4550]],
-                    [[1.0612]],
-                ]
-            ]
-        ).type(
+scaling_impl.learned_value": rand_tensor.type(
             torch.FloatTensor
         )
     }
-- 
GitLab
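
For reference, below is a minimal standalone sketch (not part of the patch) of the
configuration the new parametrization covers: per_channel_broadcastable_shape is set
while scaling_per_channel is False, so the learned scaling value is a single scalar
rather than a per-channel tensor. The constructor kwargs and the checkpoint key are
taken from the diff above; the import paths and the direct load_state_dict call are
assumptions based on the old Brevitas API used by this test file.

    import torch
    from brevitas.nn import QuantReLU
    from brevitas.core.quant import QuantType
    from brevitas.core.scaling import ScalingImplType
    from brevitas.core.restrict_val import RestrictValueType

    out_channels = 32

    # Same kwargs as in the test: per_channel_broadcastable_shape is given
    # even though scaling_per_channel is False.
    b_act = QuantReLU(
        bit_width=4,
        quant_type=QuantType.INT,
        scaling_impl_type=ScalingImplType.PARAMETER,
        scaling_per_channel=False,
        restrict_scaling_type=RestrictValueType.LOG_FP,
        scaling_min_val=2e-16,
        max_val=6.0,
        return_quant_tensor=True,
        per_channel_broadcastable_shape=(1, out_channels, 1, 1),
    )

    # Per-tensor scaling: the learned scale is a single value, matching the
    # "else" branch introduced by the patch.
    checkpoint = {
        "act_quant_proxy.fused_activation_quant_proxy.tensor_quant."
        "scaling_impl.learned_value": torch.tensor(1.2398).type(torch.FloatTensor)
    }
    b_act.load_state_dict(checkpoint)

    # Forward pass on a random input of the shape used in the test.
    out = b_act(torch.rand((1, out_channels, 1, 1)))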