diff --git a/src/finn/transformation/insert_topk.py b/src/finn/transformation/insert_topk.py index 0812cd7a21d7fef00586fb17ebc663eb6d4fa640..a6cd659e07d6796a18291fd918a33a7c7bcf0ad9 100644 --- a/src/finn/transformation/insert_topk.py +++ b/src/finn/transformation/insert_topk.py @@ -59,10 +59,7 @@ class InsertTopK(Transformation): #adjust shape out_shape[self.axis] = self.k # make new buffer - k_tensor = oh.make_tensor(name='k_tensor', - data_type=TensorProto.INT64, - dims=(1,), - vals=np.array([self.k]).astype(np.int64)) + k_tensor = np.array([self.k]).astype(np.int64) k_value = oh.make_tensor_value_info( model.make_new_valueinfo_name(), TensorProto.INT64, [1] ) @@ -77,12 +74,7 @@ class InsertTopK(Transformation): model.graph.value_info.append(topk_values) model.set_tensor_datatype(topk_values.name, out_dtype) #create and append topk node - k_node = oh.make_node( - 'Constant', - inputs=[], - outputs=[k_value.name], - value=k_tensor - ) + model.set_initializer(k_value.name, k_tensor) topk_node = oh.make_node( 'TopK', inputs=[graph_out_name, k_value.name], @@ -91,7 +83,6 @@ class InsertTopK(Transformation): largest=self.largest, sorted=self.sorted ) - model.graph.node.append(k_node) model.graph.node.append(topk_node) #replace the existing output definition with topk indices model.graph.output.insert(0,topk_indices)