From c9dca488f7cd3579f13a259bd3a7f74d25715972 Mon Sep 17 00:00:00 2001
From: Tobi-Alonso <tobi.alonso@gmail.com>
Date: Tue, 7 Jul 2020 17:51:01 +0100
Subject: [PATCH] [Transformation] InsertTopK absorbs linear scalar

---
 src/finn/transformation/insert_topk.py | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/src/finn/transformation/insert_topk.py b/src/finn/transformation/insert_topk.py
index 213d2cedf..70012c1bc 100644
--- a/src/finn/transformation/insert_topk.py
+++ b/src/finn/transformation/insert_topk.py
@@ -46,6 +46,14 @@ class InsertTopK(Transformation):
         self.largest = largest
         self.sorted = sorted
 
+    def is_scalar_linear(self, model, node):
+        # if is linear
+        test = (node.op_type == "Mul") or (node.op_type == "Add")
+        if test:
+            init = model.get_initializer(node.input[1])
+            test = test and (init is not None) and all(x == 1 for x in init.shape)
+        return test
+
     def apply(self, model):
         # get name of output tensor
         graph_out_name = model.graph.output[0].name
@@ -55,6 +63,17 @@ class InsertTopK(Transformation):
         if final_node.op_type == "TopK":
             return (model, False)
         else:
+            # remove any scalar linear transformations at graph output
+            # because TopK is invariant to them
+            while self.is_scalar_linear(model, final_node):
+                # remove the predecessor
+                final_node_input = model.get_tensor_valueinfo(final_node.input[0])
+                model.graph.output.insert(0, final_node_input)
+                model.graph.output.pop(1)
+                model.graph.node.remove(final_node)
+                graph_out_name = model.graph.output[0].name
+                final_node = model.find_producer(graph_out_name)
+
             out_shape = model.get_tensor_shape(graph_out_name)
             out_dtype = model.get_tensor_datatype(graph_out_name)
             # adjust shape
-- 
GitLab