From 2564cd42ea857388461e11c689e5e9791ec25a5a Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <yamanu@xilinx.com>
Date: Fri, 1 Nov 2019 15:03:23 +0000
Subject: [PATCH] [Test] add test_sign_to_thres

---
 tests/test_sign_to_thres.py | 33 +++++++++++++++++++++++++++++++++
 1 file changed, 33 insertions(+)
 create mode 100644 tests/test_sign_to_thres.py

diff --git a/tests/test_sign_to_thres.py b/tests/test_sign_to_thres.py
new file mode 100644
index 000000000..4d7f447f8
--- /dev/null
+++ b/tests/test_sign_to_thres.py
@@ -0,0 +1,33 @@
+import numpy as np
+from onnx import TensorProto, helper
+
+import finn.core.onnx_exec as oxe
+import finn.transformation.infer_shapes as si
+import finn.transformation.streamline as sl
+from finn.core.modelwrapper import ModelWrapper
+
+
+def test_sign_to_thres():
+    out0 = helper.make_tensor_value_info("out0", TensorProto.FLOAT, [6, 3, 2, 2])
+    graph_def = helper.make_graph(
+        nodes=[
+            helper.make_node("Sign", ["v"], ["out0"]),
+            helper.make_node("Relu", ["out0"], ["out1"]),
+        ],
+        name="test-model",
+        inputs=[helper.make_tensor_value_info("v", TensorProto.FLOAT, [6, 3, 2, 2])],
+        value_info=[out0],
+        outputs=[
+            helper.make_tensor_value_info("out1", TensorProto.FLOAT, [6, 3, 2, 2])
+        ],
+    )
+    model_def = helper.make_model(graph_def, producer_name="finn-test")
+    model = ModelWrapper(model_def)
+    model = model.transform_single(si.infer_shapes)
+    input_dict = {}
+    input_dict["v"] = np.random.randn(*[6, 3, 2, 2]).astype(np.float32)
+    expected = oxe.execute_onnx(model, input_dict)["out1"]
+    model = model.transform_single(sl.convert_sign_to_thres)
+    assert model.graph.node[0].op_type == "MultiThreshold"
+    produced = oxe.execute_onnx(model, input_dict)["out1"]
+    assert np.isclose(expected, produced, atol=1e-3).all()
-- 
GitLab