From c641ed7abcdc258d6769745f30c001ab2de1efdd Mon Sep 17 00:00:00 2001
From: auphelia <jakobapk@web.de>
Date: Wed, 20 May 2020 11:29:20 +0100
Subject: [PATCH] [Test] Add comparison of PyTorch execution and FINN ONNX
 execution of the new QuantReLU node

---
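Note: assuming a FINN dev environment with Brevitas set up, the new
comparison can be run on its own with pytest, e.g.:

    pytest -k test_brevitas_act_export_relu_imagenet \
        tests/brevitas/test_brevitas_relu_act_export.py
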
 .../brevitas/test_brevitas_relu_act_export.py | 41 ++++++++++++++++++--
 1 file changed, 36 insertions(+), 5 deletions(-)

diff --git a/tests/brevitas/test_brevitas_relu_act_export.py b/tests/brevitas/test_brevitas_relu_act_export.py
index eb56abc28..25c44f0cc 100644
--- a/tests/brevitas/test_brevitas_relu_act_export.py
+++ b/tests/brevitas/test_brevitas_relu_act_export.py
@@ -71,10 +71,11 @@ scaling_impl.learned_value": torch.tensor(
 
 
 @pytest.mark.parametrize("abits", [4])
-@pytest.mark.parametrize("out_channels", [32])
-def test_brevitas_act_export_relu_imagenet(abits, out_channels):
-    ishape = (32, 15)
-
+@pytest.mark.parametrize("max_val", [1.0, 1.5, 1 - 2 ** (-7)])
+def test_brevitas_act_export_relu_imagenet(abits, max_val):
+    out_channels = 32
+    ishape = (1, out_channels, 1, 1)
+    min_val = -1.0  # negative inputs exercise the ReLU's clipping at zero
     b_act = QuantReLU(
         bit_width=abits,
         quant_type=QuantType.INT,
@@ -83,7 +84,7 @@ def test_brevitas_act_export_relu_imagenet(abits, out_channels):
         restrict_scaling_type=RestrictValueType.LOG_FP,
         scaling_min_val=2e-16,
         max_val=6.0,
-        return_quant_tensor=False,
+        return_quant_tensor=True,  # expose the QuantTensor for the reference run
         per_channel_broadcastable_shape=(1, out_channels, 1, 1),
     )
     checkpoint = {
@@ -131,3 +132,33 @@ scaling_impl.learned_value": torch.tensor(
     }
     b_act.load_state_dict(checkpoint)
     bo.export_finn_onnx(b_act, ishape, export_onnx_path)
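+    # load the exported model into FINN and infer all tensor shapes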
+    model = ModelWrapper(export_onnx_path)
+    model = model.transform(InferShapes())
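+    # random input covering the negative (clipped) and positive range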
+    inp_tensor = np.random.uniform(low=min_val, high=max_val, size=ishape).astype(
+        np.float32
+    )
+    idict = {model.graph.input[0].name: inp_tensor}
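+    # execute the exported model with FINN's ONNX executor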
+    odict = oxe.execute_onnx(model, idict, True)
+    produced = odict[model.graph.output[0].name]
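+    # run the same input through the Brevitas activation as PyTorch reference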
+    inp_tensor = torch.from_numpy(inp_tensor).float()
+    b_act.eval()
+    expected = b_act.forward(inp_tensor).tensor.detach().numpy()
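+    # on mismatch, print scale, thresholds and tensors before asserting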
+    if not np.isclose(produced, expected, atol=1e-3).all():
+        print(abits, max_val)
+        print("scale: ", b_act.quant_act_scale().type(torch.FloatTensor).detach())
+        if abits < 5:
+            print(
+                "thres:",
+                ", ".join(["{:8.4f}".format(x) for x in b_act.export_thres[0]]),
+            )
+        print("input:", ", ".join(["{:8.4f}".format(x) for x in inp_tensor[0]]))
+        print("prod :", ", ".join(["{:8.4f}".format(x) for x in produced[0]]))
+        print("expec:", ", ".join(["{:8.4f}".format(x) for x in expected[0]]))
+
+    assert np.isclose(produced, expected, atol=1e-3).all()
+    os.remove(export_onnx_path)
-- 
GitLab