From efc1b85caabed6addfa07710475c819771d46e79 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Thu, 30 Apr 2020 14:38:59 +0100
Subject: [PATCH] [Transform] run pre-commit hooks for InsertTopK, add comment

---
 src/finn/transformation/insert_topk.py | 21 ++++++++++-----------
 1 file changed, 10 insertions(+), 11 deletions(-)

diff --git a/src/finn/transformation/insert_topk.py b/src/finn/transformation/insert_topk.py
index a6cd659e0..84906c7c2 100644
--- a/src/finn/transformation/insert_topk.py
+++ b/src/finn/transformation/insert_topk.py
@@ -31,12 +31,12 @@ import numpy as np
 from onnx import TensorProto
 from onnx import helper as oh
 
-from finn.custom_op.registry import getCustomOp
 from finn.transformation import Transformation
 
 
 class InsertTopK(Transformation):
-    """Add TopK node at the network output."""
+    """Add TopK node at the network output and replace the graph output with
+    the TopK indices."""
 
     def __init__(self, k=5, axis=-1, largest=1, sorted=1):
         super().__init__()
@@ -48,7 +48,7 @@ class InsertTopK(Transformation):
     def apply(self, model):
         # get name of output tensor
         graph_out_name = model.graph.output[0].name
-        # find final node 
+        # find final node
         final_node = model.find_producer(graph_out_name)
         # if a top-select op is already present, do nothing
         if final_node.op_type == "TopK":
@@ -56,7 +56,7 @@ class InsertTopK(Transformation):
         else:
             out_shape = model.get_tensor_shape(graph_out_name)
             out_dtype = model.get_tensor_datatype(graph_out_name)
-            #adjust shape
+            # adjust shape
             out_shape[self.axis] = self.k
             # make new buffer
             k_tensor = np.array([self.k]).astype(np.int64)
@@ -70,22 +70,21 @@ class InsertTopK(Transformation):
                 model.make_new_valueinfo_name(), TensorProto.INT64, out_shape
             )
             model.graph.value_info.append(k_value)
-            model.set_tensor_datatype(k_value.name, out_dtype)#TODO set to int64
+            model.set_tensor_datatype(k_value.name, out_dtype)  # TODO set to int64
             model.graph.value_info.append(topk_values)
             model.set_tensor_datatype(topk_values.name, out_dtype)
-            #create and append topk node
+            # create and append topk node
             model.set_initializer(k_value.name, k_tensor)
             topk_node = oh.make_node(
-                'TopK',
+                "TopK",
                 inputs=[graph_out_name, k_value.name],
                 outputs=[topk_values.name, topk_indices.name],
                 axis=self.axis,
                 largest=self.largest,
-                sorted=self.sorted
+                sorted=self.sorted,
             )
             model.graph.node.append(topk_node)
-            #replace the existing output definition with topk indices
-            model.graph.output.insert(0,topk_indices)
+            # replace the existing output definition with topk indices
+            model.graph.output.insert(0, topk_indices)
             model.graph.output.pop(1)
             return (model, True)
-
-- 
GitLab