From 73b4930cc246490e43111e07c436b32cd677c203 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Lucian=20Petric=C4=83?= <lucian.petrica@upb.ro>
Date: Wed, 29 Apr 2020 15:38:39 +0100
Subject: [PATCH] Draft code for transformation to insert a top-k node at the
 output of the graph

---
 src/finn/transformation/insert_topk.py | 101 +++++++++++++++++++++++++
 1 file changed, 101 insertions(+)
 create mode 100644 src/finn/transformation/insert_topk.py

diff --git a/src/finn/transformation/insert_topk.py b/src/finn/transformation/insert_topk.py
new file mode 100644
index 000000000..30b0fff7a
--- /dev/null
+++ b/src/finn/transformation/insert_topk.py
@@ -0,0 +1,101 @@
+# Copyright (c) 2020, Xilinx
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of FINN nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import numpy as np
+
+from onnx import TensorProto
+from onnx import helper as oh
+
+from finn.custom_op.registry import getCustomOp
+from finn.transformation import Transformation
+
+
+class InsertTopK(Transformation):
+    """Add TopK node at the network output."""
+
+    def __init__(self, k=5, axis=-1, largest=1, sorted=1):
+        super().__init__()
+        self.k = k
+        self.axis = axis
+        self.largest = largest
+        self.sorted = sorted
+
+    def apply(self, model):
+        # get name of output tensor
+        graph_out_name = model.graph.output[0].name
+        # find final node 
+        final_node = model.find_producer(graph_out_name)
+        # if a top-select op is already present, do nothing
+        if final_node.op_type == "TopK":
+            return (model, False)
+        else:
+            out_shape = model.get_tensor_shape(graph_out_name)
+            out_dtype = model.get_tensor_datatype(graph_out_name)
+            #adjust shape
+            out_shape[self.axis] = self.k
+            import pdb; pdb.set_trace()
+            # make new buffer
+            k_tensor = oh.make_tensor(name='k_tensor',
+                        data_type=TensorProto.INT64,
+                        dims=(1,),
+                        vals=np.array([self.k]).astype(np.int64))
+            k_value = oh.make_tensor_value_info(
+                model.make_new_valueinfo_name(), TensorProto.INT64, [1]
+            )
+            topk_values = oh.make_tensor_value_info(
+                model.make_new_valueinfo_name(), TensorProto.FLOAT, out_shape
+            )
+            topk_indices = oh.make_tensor_value_info(
+                model.make_new_valueinfo_name(), TensorProto.INT64, out_shape
+            )
+            model.graph.value_info.append(k_value)
+            model.set_tensor_datatype(k_value.name, out_dtype)#TODO set to int64
+            model.graph.value_info.append(topk_values)
+            model.set_tensor_datatype(topk_values.name, out_dtype)
+            model.graph.value_info.append(topk_indices)
+            model.set_tensor_datatype(topk_indices.name, out_dtype)
+            #create and append topk node
+            k_node = oh.make_node(
+                'Constant',
+                inputs=[],
+                outputs=[k_value.name],
+                value=k_tensor
+            )
+            topk_node = oh.make_node(
+                'TopK',
+                inputs=[graph_out_name, k_value.name],
+                outputs=[topk_values.name, topk_indices.name],
+                axis=self.axis,
+
+            )
+            model.graph.node.append(k_node)
+            model.graph.node.append(topk_node)
+            model.graph.output[0].name = topk_values.name
+            print(topk_indices.name,topk_values.name)
+            return (model, True)
+
-- 
GitLab