From c9dca488f7cd3579f13a259bd3a7f74d25715972 Mon Sep 17 00:00:00 2001 From: Tobi-Alonso <tobi.alonso@gmail.com> Date: Tue, 7 Jul 2020 17:51:01 +0100 Subject: [PATCH] [Transformation] InsertTopK absorbs linear scalar --- src/finn/transformation/insert_topk.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/finn/transformation/insert_topk.py b/src/finn/transformation/insert_topk.py index 213d2cedf..70012c1bc 100644 --- a/src/finn/transformation/insert_topk.py +++ b/src/finn/transformation/insert_topk.py @@ -46,6 +46,14 @@ class InsertTopK(Transformation): self.largest = largest self.sorted = sorted + def is_scalar_linear(self, model, node): + # if is linear + test = (node.op_type == "Mul") or (node.op_type == "Add") + if test: + init = model.get_initializer(node.input[1]) + test = test and (init is not None) and all(x == 1 for x in init.shape) + return test + def apply(self, model): # get name of output tensor graph_out_name = model.graph.output[0].name @@ -55,6 +63,17 @@ class InsertTopK(Transformation): if final_node.op_type == "TopK": return (model, False) else: + # remove any scalar linear transformations at graph output + # because TopK is invariant to them + while self.is_scalar_linear(model, final_node): + # remove the predecessor + final_node_input = model.get_tensor_valueinfo(final_node.input[0]) + model.graph.output.insert(0, final_node_input) + model.graph.output.pop(1) + model.graph.node.remove(final_node) + graph_out_name = model.graph.output[0].name + final_node = model.find_producer(graph_out_name) + out_shape = model.get_tensor_shape(graph_out_name) out_dtype = model.get_tensor_datatype(graph_out_name) # adjust shape -- GitLab