diff --git a/tests/brevitas/test_brevitas_relu_act_export.py b/tests/brevitas/test_brevitas_relu_act_export.py index c9d8f2d812bc7bea1a2fd2598a7711099ad421e6..eb56abc28ed0cdd104f2826b24668ec0cd166313 100644 --- a/tests/brevitas/test_brevitas_relu_act_export.py +++ b/tests/brevitas/test_brevitas_relu_act_export.py @@ -23,6 +23,7 @@ export_onnx_path = "test_act.onnx" def test_brevitas_act_export_relu(abits, max_val, scaling_impl_type): min_val = -1.0 ishape = (1, 15) + b_act = QuantReLU( bit_width=abits, max_val=max_val, @@ -67,3 +68,66 @@ scaling_impl.learned_value": torch.tensor( assert np.isclose(produced, expected, atol=1e-3).all() os.remove(export_onnx_path) + + +@pytest.mark.parametrize("abits", [4]) +@pytest.mark.parametrize("out_channels", [32]) +def test_brevitas_act_export_relu_imagenet(abits, out_channels): + ishape = (32, 15) + + b_act = QuantReLU( + bit_width=abits, + quant_type=QuantType.INT, + scaling_impl_type=ScalingImplType.PARAMETER, + scaling_per_channel=True, + restrict_scaling_type=RestrictValueType.LOG_FP, + scaling_min_val=2e-16, + max_val=6.0, + return_quant_tensor=False, + per_channel_broadcastable_shape=(1, out_channels, 1, 1), + ) + checkpoint = { + "act_quant_proxy.fused_activation_quant_proxy.tensor_quant.\ +scaling_impl.learned_value": torch.tensor( + [ + [ + [[0.8520]], + [[0.7784]], + [[0.8643]], + [[1.0945]], + [[0.8520]], + [[0.7239]], + [[0.8520]], + [[1.0609]], + [[1.0299]], + [[0.0121]], + [[0.8520]], + [[0.8489]], + [[0.1376]], + [[0.8520]], + [[0.7096]], + [[0.8520]], + [[0.8520]], + [[0.8608]], + [[0.9484]], + [[0.8520]], + [[0.7237]], + [[0.5425]], + [[0.5774]], + [[0.5975]], + [[0.6685]], + [[0.4472]], + [[0.8520]], + [[0.7879]], + [[0.8520]], + [[1.0322]], + [[0.4550]], + [[1.0612]], + ] + ] + ).type( + torch.FloatTensor + ) + } + b_act.load_state_dict(checkpoint) + bo.export_finn_onnx(b_act, ishape, export_onnx_path)