Skip to content
Snippets Groups Projects
Commit c9dca488 authored by Tobi-Alonso's avatar Tobi-Alonso
Browse files

[Transformation] InsertTopK absorbs linear scalar

parent ca3f3446
No related branches found
No related tags found
No related merge requests found
......@@ -46,6 +46,14 @@ class InsertTopK(Transformation):
self.largest = largest
self.sorted = sorted
def is_scalar_linear(self, model, node):
# if is linear
test = (node.op_type == "Mul") or (node.op_type == "Add")
if test:
init = model.get_initializer(node.input[1])
test = test and (init is not None) and all(x == 1 for x in init.shape)
return test
def apply(self, model):
# get name of output tensor
graph_out_name = model.graph.output[0].name
......@@ -55,6 +63,17 @@ class InsertTopK(Transformation):
if final_node.op_type == "TopK":
return (model, False)
else:
# remove any scalar linear transformations at graph output
# because TopK is invariant to them
while self.is_scalar_linear(model, final_node):
# remove the predecessor
final_node_input = model.get_tensor_valueinfo(final_node.input[0])
model.graph.output.insert(0, final_node_input)
model.graph.output.pop(1)
model.graph.node.remove(final_node)
graph_out_name = model.graph.output[0].name
final_node = model.find_producer(graph_out_name)
out_shape = model.get_tensor_shape(graph_out_name)
out_dtype = model.get_tensor_datatype(graph_out_name)
# adjust shape
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment