From 5a5fbb794248a1439a9224cfa9753a809abb08fc Mon Sep 17 00:00:00 2001
From: Tobi-Alonso <tobi.alonso@gmail.com>
Date: Tue, 7 Jul 2020 19:16:55 +0100
Subject: [PATCH] [Test] Modify test_absorb_mul_into_topk now that insertTopk
 also absorbs scalar Mul

---
 .../test_absorb_mul_into_topk.py              | 25 +++++++++++++------
 1 file changed, 18 insertions(+), 7 deletions(-)

diff --git a/tests/transformation/test_absorb_mul_into_topk.py b/tests/transformation/test_absorb_mul_into_topk.py
index 1394220f7..6cc5c9847 100644
--- a/tests/transformation/test_absorb_mul_into_topk.py
+++ b/tests/transformation/test_absorb_mul_into_topk.py
@@ -34,7 +34,6 @@ from finn.core.modelwrapper import ModelWrapper
 from finn.transformation.infer_shapes import InferShapes
 from finn.transformation.infer_datatypes import InferDataTypes
 from finn.transformation.general import GiveUniqueNodeNames, GiveReadableTensorNames
-from finn.transformation.insert_topk import InsertTopK
 from finn.transformation.streamline.absorb import AbsorbScalarMulIntoTopK
 import finn.core.onnx_exec as oxe
 
@@ -43,25 +42,35 @@ import finn.core.onnx_exec as oxe
 # parameter to indicate if mul parameter is scalar or not
 @pytest.mark.parametrize("scalar", [True, False])
 def test_absorb_mul_into_topk(mul_positive, scalar):
+    K = 5
     if scalar is True:
         shape = [1]
     else:
         shape = [1, 1, 1, 1000]
+
+    out_shape = [1, 1, 1, K]
     inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, [1, 1, 1, 1000])
     a0 = helper.make_tensor_value_info("a0", TensorProto.FLOAT, shape)
-    outp = helper.make_tensor_value_info("outp", TensorProto.FLOAT, [1, 1, 1, 1000])
-
-    mul_node = helper.make_node("Mul", ["inp", "a0"], ["outp"])
+    outp = helper.make_tensor_value_info("outp", TensorProto.INT64, out_shape)
+    k_value = helper.make_tensor_value_info("k_value", TensorProto.INT64, [1])
+    topk_values = helper.make_tensor_value_info(
+        "topk_values", TensorProto.FLOAT, out_shape
+    )
+    mul_node = helper.make_node("Mul", ["inp", "a0"], ["a1"])
+    top_k = helper.make_node(
+        "TopK", ["a1", "k_value"], ["topk_values", "outp"], largest=1, axis=-1, sorted=1
+    )
     mul_graph = helper.make_graph(
-        nodes=[mul_node],
+        nodes=[mul_node, top_k],
         name="mul-graph",
         inputs=[inp],
         outputs=[outp],
-        value_info=[a0],
+        value_info=[a0, k_value, topk_values],
     )
 
     model = helper.make_model(mul_graph, producer_name="mul_model")
     model = ModelWrapper(model)
+
     # initialize values
     if mul_positive is True:
         a0_values = np.random.uniform(low=0.1, high=1, size=tuple(shape)).astype(
@@ -72,7 +81,9 @@ def test_absorb_mul_into_topk(mul_positive, scalar):
             np.float32
         )
     model.set_initializer("a0", a0_values)
-    model = model.transform(InsertTopK())
+
+    k_tensor = np.array([K]).astype(np.int64)
+    model.set_initializer(k_value.name, k_tensor)
     model = model.transform(InferShapes())
     model = model.transform(InferDataTypes())
     model = model.transform(GiveUniqueNodeNames())
-- 
GitLab