diff --git a/src/finn/transformation/insert_topk.py b/src/finn/transformation/insert_topk.py
index 70012c1bc5602f08310148955920b252aff999a9..3ef6ef1b13c9b59e8a346c83daddaab4fdf6a859 100644
--- a/src/finn/transformation/insert_topk.py
+++ b/src/finn/transformation/insert_topk.py
@@ -52,6 +52,7 @@ class InsertTopK(Transformation):
         if test:
             init = model.get_initializer(node.input[1])
             test = test and (init is not None) and all(x == 1 for x in init.shape)
+            test = test and init > 0
         return test
 
     def apply(self, model):
@@ -73,6 +74,8 @@ class InsertTopK(Transformation):
                 model.graph.node.remove(final_node)
                 graph_out_name = model.graph.output[0].name
                 final_node = model.find_producer(graph_out_name)
+                if final_node is None:
+                    break
 
             out_shape = model.get_tensor_shape(graph_out_name)
             out_dtype = model.get_tensor_datatype(graph_out_name)