diff --git a/src/finn/transformation/insert_topk.py b/src/finn/transformation/insert_topk.py index 70012c1bc5602f08310148955920b252aff999a9..3ef6ef1b13c9b59e8a346c83daddaab4fdf6a859 100644 --- a/src/finn/transformation/insert_topk.py +++ b/src/finn/transformation/insert_topk.py @@ -52,6 +52,7 @@ class InsertTopK(Transformation): if test: init = model.get_initializer(node.input[1]) test = test and (init is not None) and all(x == 1 for x in init.shape) + test = test and init > 0 return test def apply(self, model): @@ -73,6 +74,8 @@ class InsertTopK(Transformation): model.graph.node.remove(final_node) graph_out_name = model.graph.output[0].name final_node = model.find_producer(graph_out_name) + if final_node is None: + break out_shape = model.get_tensor_shape(graph_out_name) out_dtype = model.get_tensor_datatype(graph_out_name)