From e5dcac88c4fdd26dd64bc4b5eba269c4d7d7682b Mon Sep 17 00:00:00 2001 From: auphelia <jakobapk@web.de> Date: Wed, 20 May 2020 13:56:26 +0100 Subject: [PATCH] [Test] Add test for QuantRelu: per_channel_broadcastable_shape is not None but scaling_per_channel is False --- .../brevitas/test_brevitas_relu_act_export.py | 50 ++++--------------- 1 file changed, 9 insertions(+), 41 deletions(-) diff --git a/tests/brevitas/test_brevitas_relu_act_export.py b/tests/brevitas/test_brevitas_relu_act_export.py index 25c44f0cc..c5ddad12c 100644 --- a/tests/brevitas/test_brevitas_relu_act_export.py +++ b/tests/brevitas/test_brevitas_relu_act_export.py @@ -70,9 +70,10 @@ scaling_impl.learned_value": torch.tensor( os.remove(export_onnx_path) -@pytest.mark.parametrize("abits", [4]) +@pytest.mark.parametrize("abits", [1, 2, 4, 8]) @pytest.mark.parametrize("max_val", [1.0, 1.5, 1 - 2 ** (-7)]) -def test_brevitas_act_export_relu_imagenet(abits, max_val): +@pytest.mark.parametrize("scaling_per_channel", [True, False]) +def test_brevitas_act_export_relu_imagenet(abits, max_val, scaling_per_channel): out_channels = 32 ishape = (1, out_channels, 1, 1) min_val = -1.0 @@ -80,53 +81,20 @@ def test_brevitas_act_export_relu_imagenet(abits, max_val): bit_width=abits, quant_type=QuantType.INT, scaling_impl_type=ScalingImplType.PARAMETER, - scaling_per_channel=True, + scaling_per_channel=scaling_per_channel, restrict_scaling_type=RestrictValueType.LOG_FP, scaling_min_val=2e-16, max_val=6.0, return_quant_tensor=True, per_channel_broadcastable_shape=(1, out_channels, 1, 1), ) + if scaling_per_channel is True: + rand_tensor = (2) * torch.rand((1, out_channels, 1, 1)) + else: + rand_tensor = torch.tensor(1.2398) checkpoint = { "act_quant_proxy.fused_activation_quant_proxy.tensor_quant.\ -scaling_impl.learned_value": torch.tensor( - [ - [ - [[0.8520]], - [[0.7784]], - [[0.8643]], - [[1.0945]], - [[0.8520]], - [[0.7239]], - [[0.8520]], - [[1.0609]], - [[1.0299]], - [[0.0121]], - [[0.8520]], - [[0.8489]], - [[0.1376]], - [[0.8520]], - [[0.7096]], - [[0.8520]], - [[0.8520]], - [[0.8608]], - [[0.9484]], - [[0.8520]], - [[0.7237]], - [[0.5425]], - [[0.5774]], - [[0.5975]], - [[0.6685]], - [[0.4472]], - [[0.8520]], - [[0.7879]], - [[0.8520]], - [[1.0322]], - [[0.4550]], - [[1.0612]], - ] - ] - ).type( +scaling_impl.learned_value": rand_tensor.type( torch.FloatTensor ) } -- GitLab