diff --git a/src/finn/transformation/insert_topk.py b/src/finn/transformation/insert_topk.py index 84906c7c23f48cae811a5103349886766de15d3e..213d2cedf92c0276e33fcf2b50e6966aeee8c847 100644 --- a/src/finn/transformation/insert_topk.py +++ b/src/finn/transformation/insert_topk.py @@ -32,6 +32,7 @@ from onnx import TensorProto from onnx import helper as oh from finn.transformation import Transformation +from finn.core.datatype import DataType class InsertTopK(Transformation): @@ -87,4 +88,9 @@ class InsertTopK(Transformation): # replace the existing output definition with topk indices model.graph.output.insert(0, topk_indices) model.graph.output.pop(1) + # set quantization annotation for indices + # minimal output dtype for TopK indices depends on num. classes + # assuming UINT32 is large enough for now (FINN currently has no + # DataType.INT64) + model.set_tensor_datatype(topk_indices.name, DataType.UINT32) return (model, True)