diff --git a/src/finn/transformation/insert_topk.py b/src/finn/transformation/insert_topk.py index a6cd659e07d6796a18291fd918a33a7c7bcf0ad9..84906c7c23f48cae811a5103349886766de15d3e 100644 --- a/src/finn/transformation/insert_topk.py +++ b/src/finn/transformation/insert_topk.py @@ -31,12 +31,12 @@ import numpy as np from onnx import TensorProto from onnx import helper as oh -from finn.custom_op.registry import getCustomOp from finn.transformation import Transformation class InsertTopK(Transformation): - """Add TopK node at the network output.""" + """Add TopK node at the network output and replace the graph output with + the TopK indices.""" def __init__(self, k=5, axis=-1, largest=1, sorted=1): super().__init__() @@ -48,7 +48,7 @@ class InsertTopK(Transformation): def apply(self, model): # get name of output tensor graph_out_name = model.graph.output[0].name - # find final node + # find final node final_node = model.find_producer(graph_out_name) # if a top-select op is already present, do nothing if final_node.op_type == "TopK": @@ -56,7 +56,7 @@ class InsertTopK(Transformation): else: out_shape = model.get_tensor_shape(graph_out_name) out_dtype = model.get_tensor_datatype(graph_out_name) - #adjust shape + # adjust shape out_shape[self.axis] = self.k # make new buffer k_tensor = np.array([self.k]).astype(np.int64) @@ -70,22 +70,21 @@ class InsertTopK(Transformation): model.make_new_valueinfo_name(), TensorProto.INT64, out_shape ) model.graph.value_info.append(k_value) - model.set_tensor_datatype(k_value.name, out_dtype)#TODO set to int64 + model.set_tensor_datatype(k_value.name, out_dtype) # TODO set to int64 model.graph.value_info.append(topk_values) model.set_tensor_datatype(topk_values.name, out_dtype) - #create and append topk node + # create and append topk node model.set_initializer(k_value.name, k_tensor) topk_node = oh.make_node( - 'TopK', + "TopK", inputs=[graph_out_name, k_value.name], outputs=[topk_values.name, topk_indices.name], axis=self.axis, largest=self.largest, - sorted=self.sorted + sorted=self.sorted, ) model.graph.node.append(topk_node) - #replace the existing output definition with topk indices - model.graph.output.insert(0,topk_indices) + # replace the existing output definition with topk indices + model.graph.output.insert(0, topk_indices) model.graph.output.pop(1) return (model, True) -