From 44892ae4806f0db84a2332affec5573fdcacfcb4 Mon Sep 17 00:00:00 2001
From: Tobi-Alonso <tobi.alonso@gmail.com>
Date: Wed, 13 May 2020 21:30:19 +0100
Subject: [PATCH] Scaled and non-scaled QuantHardTanh activation tests
 separated. Non-scaled test includes cases with scale parameter == 1. In
 scaled test, bipolar activation is excluded as it's not currently supported

---
 ...revitas_non_scaled_QuantHardTanh_export.py | 44 +++++++++++++++++++
 ... test_brevitas_scaled_QHardTanh_export.py} |  2 +-
 2 files changed, 45 insertions(+), 1 deletion(-)
 create mode 100644 tests/brevitas/test_brevitas_non_scaled_QuantHardTanh_export.py
 rename tests/brevitas/{test_brevitas_act_export.py => test_brevitas_scaled_QHardTanh_export.py} (98%)

diff --git a/tests/brevitas/test_brevitas_non_scaled_QuantHardTanh_export.py b/tests/brevitas/test_brevitas_non_scaled_QuantHardTanh_export.py
new file mode 100644
index 000000000..a50e8eae7
--- /dev/null
+++ b/tests/brevitas/test_brevitas_non_scaled_QuantHardTanh_export.py
@@ -0,0 +1,44 @@
+import onnx  # noqa
+import numpy as np
+import torch
+import brevitas.onnx as bo
+from brevitas.nn import QuantHardTanh
+from brevitas.core.restrict_val import RestrictValueType
+from brevitas.core.scaling import ScalingImplType
+from models.common import get_quant_type
+import pytest
+from finn.core.modelwrapper import ModelWrapper
+import finn.core.onnx_exec as oxe
+from finn.transformation.infer_shapes import InferShapes
+
+export_onnx_path = "test_act.onnx"
+
+
+@pytest.mark.parametrize("abits", [1, 2, 4, 8])
+@pytest.mark.parametrize("narrow_range", [False, True])
+@pytest.mark.parametrize("max_val", [1.0, 1 - 2 ** (-7)])
+def test_brevitas_act_export(abits, narrow_range, max_val):
+    act_quant_type = get_quant_type(abits)
+    min_val = -1.0
+    ishape = (1, 10)
+    b_act = QuantHardTanh(
+        bit_width=abits,
+        quant_type=act_quant_type,
+        max_val=max_val,
+        min_val=min_val,
+        restrict_scaling_type=RestrictValueType.LOG_FP,
+        scaling_impl_type=ScalingImplType.CONST,
+        narrow_range=narrow_range,
+    )
+    bo.export_finn_onnx(b_act, ishape, export_onnx_path)
+    model = ModelWrapper(export_onnx_path)
+    model = model.transform(InferShapes())
+    inp_tensor = np.random.uniform(low=min_val, high=max_val, size=ishape).astype(
+        np.float32
+    )
+    idict = {model.graph.input[0].name: inp_tensor}
+    odict = oxe.execute_onnx(model, idict, True)
+    produced = odict[model.graph.output[0].name]
+    inp_tensor = torch.from_numpy(inp_tensor).float()
+    expected = b_act.forward(inp_tensor).detach().numpy()
+    assert np.isclose(produced, expected, atol=1e-3).all()
diff --git a/tests/brevitas/test_brevitas_act_export.py b/tests/brevitas/test_brevitas_scaled_QHardTanh_export.py
similarity index 98%
rename from tests/brevitas/test_brevitas_act_export.py
rename to tests/brevitas/test_brevitas_scaled_QHardTanh_export.py
index 77ad1fb1c..02d7e27c8 100644
--- a/tests/brevitas/test_brevitas_act_export.py
+++ b/tests/brevitas/test_brevitas_scaled_QHardTanh_export.py
@@ -14,7 +14,7 @@ from finn.transformation.infer_shapes import InferShapes
 export_onnx_path = "test_act.onnx"
 
 
-@pytest.mark.parametrize("abits", [1, 2, 4, 8])
+@pytest.mark.parametrize("abits", [2, 4, 8])
 @pytest.mark.parametrize("narrow_range", [False, True])
 @pytest.mark.parametrize("min_val", [-1.0, -(1 - 2 ** (-7)), -2])
 @pytest.mark.parametrize("max_val", [1.0, 1 - 2 ** (-7), 2])
-- 
GitLab