From 99035686c461cbb71a5b75154fcfda61118e340f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Lucian=20Petric=C4=83?= <lucian.petrica@upb.ro>
Date: Thu, 30 Apr 2020 13:42:14 +0100
Subject: [PATCH] Switched to tensor intializer instead of constant node

---
 src/finn/transformation/insert_topk.py | 13 ++-----------
 1 file changed, 2 insertions(+), 11 deletions(-)

diff --git a/src/finn/transformation/insert_topk.py b/src/finn/transformation/insert_topk.py
index 0812cd7a2..a6cd659e0 100644
--- a/src/finn/transformation/insert_topk.py
+++ b/src/finn/transformation/insert_topk.py
@@ -59,10 +59,7 @@ class InsertTopK(Transformation):
             #adjust shape
             out_shape[self.axis] = self.k
             # make new buffer
-            k_tensor = oh.make_tensor(name='k_tensor',
-                        data_type=TensorProto.INT64,
-                        dims=(1,),
-                        vals=np.array([self.k]).astype(np.int64))
+            k_tensor = np.array([self.k]).astype(np.int64)
             k_value = oh.make_tensor_value_info(
                 model.make_new_valueinfo_name(), TensorProto.INT64, [1]
             )
@@ -77,12 +74,7 @@ class InsertTopK(Transformation):
             model.graph.value_info.append(topk_values)
             model.set_tensor_datatype(topk_values.name, out_dtype)
             #create and append topk node
-            k_node = oh.make_node(
-                'Constant',
-                inputs=[],
-                outputs=[k_value.name],
-                value=k_tensor
-            )
+            model.set_initializer(k_value.name, k_tensor)
             topk_node = oh.make_node(
                 'TopK',
                 inputs=[graph_out_name, k_value.name],
@@ -91,7 +83,6 @@ class InsertTopK(Transformation):
                 largest=self.largest,
                 sorted=self.sorted
             )
-            model.graph.node.append(k_node)
             model.graph.node.append(topk_node)
             #replace the existing output definition with topk indices
             model.graph.output.insert(0,topk_indices)
-- 
GitLab