From c641ed7abcdc258d6769745f30c001ab2de1efdd Mon Sep 17 00:00:00 2001 From: auphelia <jakobapk@web.de> Date: Wed, 20 May 2020 11:29:20 +0100 Subject: [PATCH] [Test] Add comparison of pytorch execution an finn onnx execution of new quant relu node --- .../brevitas/test_brevitas_relu_act_export.py | 36 ++++++++++++++++--- 1 file changed, 31 insertions(+), 5 deletions(-) diff --git a/tests/brevitas/test_brevitas_relu_act_export.py b/tests/brevitas/test_brevitas_relu_act_export.py index eb56abc28..25c44f0cc 100644 --- a/tests/brevitas/test_brevitas_relu_act_export.py +++ b/tests/brevitas/test_brevitas_relu_act_export.py @@ -71,10 +71,11 @@ scaling_impl.learned_value": torch.tensor( @pytest.mark.parametrize("abits", [4]) -@pytest.mark.parametrize("out_channels", [32]) -def test_brevitas_act_export_relu_imagenet(abits, out_channels): - ishape = (32, 15) - +@pytest.mark.parametrize("max_val", [1.0, 1.5, 1 - 2 ** (-7)]) +def test_brevitas_act_export_relu_imagenet(abits, max_val): + out_channels = 32 + ishape = (1, out_channels, 1, 1) + min_val = -1.0 b_act = QuantReLU( bit_width=abits, quant_type=QuantType.INT, @@ -83,7 +84,7 @@ def test_brevitas_act_export_relu_imagenet(abits, out_channels): restrict_scaling_type=RestrictValueType.LOG_FP, scaling_min_val=2e-16, max_val=6.0, - return_quant_tensor=False, + return_quant_tensor=True, per_channel_broadcastable_shape=(1, out_channels, 1, 1), ) checkpoint = { @@ -131,3 +132,28 @@ scaling_impl.learned_value": torch.tensor( } b_act.load_state_dict(checkpoint) bo.export_finn_onnx(b_act, ishape, export_onnx_path) + model = ModelWrapper(export_onnx_path) + model = model.transform(InferShapes()) + inp_tensor = np.random.uniform(low=min_val, high=max_val, size=ishape).astype( + np.float32 + ) + idict = {model.graph.input[0].name: inp_tensor} + odict = oxe.execute_onnx(model, idict, True) + produced = odict[model.graph.output[0].name] + inp_tensor = torch.from_numpy(inp_tensor).float() + b_act.eval() + expected = b_act.forward(inp_tensor).tensor.detach().numpy() + if not np.isclose(produced, expected, atol=1e-3).all(): + print(abits, max_val) + print("scale: ", b_act.quant_act_scale().type(torch.FloatTensor).detach()) + if abits < 5: + print( + "thres:", + ", ".join(["{:8.4f}".format(x) for x in b_act.export_thres[0]]), + ) + print("input:", ", ".join(["{:8.4f}".format(x) for x in inp_tensor[0]])) + print("prod :", ", ".join(["{:8.4f}".format(x) for x in produced[0]])) + print("expec:", ", ".join(["{:8.4f}".format(x) for x in expected[0]])) + + assert np.isclose(produced, expected, atol=1e-3).all() + os.remove(export_onnx_path) -- GitLab