Commit 36700c21 authored by auphelia

[Streamline] Add MoveFlattenPastTopK transformation

parent cf54f617
@@ -29,8 +29,10 @@
import numpy as np
import warnings
from onnx import helper as oh
from onnx import TensorProto
from finn.transformation import Transformation
import finn.core.data_layout as DataLayout
from finn.transformation.infer_shapes import InferShapes
from finn.core.onnx_exec import execute_node
from finn.util.basic import get_by_name
@@ -597,3 +599,64 @@ class MoveMaxPoolPastMultiThreshold(Transformation):
        model = model.transform(InferShapes())
        return (model, graph_modified)


class MoveFlattenPastTopK(Transformation):
    """Move a Flatten node past a succeeding TopK node, if the "axis" attribute
    of TopK is set to -1 and the data layout before the Flatten is NHWC
    with H=W=1."""

    def apply(self, model):
        graph = model.graph
        node_ind = 0
        graph_modified = False
        for n in graph.node:
            node_ind += 1
            if n.op_type == "Flatten":
                consumer = model.find_consumer(n.output[0])
                if consumer is not None and consumer.op_type == "TopK":
                    axis = get_by_name(consumer.attribute, "axis")
                    if axis is None or axis.i != -1:
                        continue
                    start_name = n.input[0]
                    data_layout = model.get_tensor_layout(start_name)
                    if data_layout != DataLayout.NHWC:
                        warnings.warn(
                            """Transformation can't be applied. The input
                            to Flatten has to have DataLayout.NHWC."""
                        )
                        continue
                    (b, h, w, c) = model.get_tensor_shape(start_name)
                    if h != 1 or w != 1:
                        continue
                    # get parameter k from the shape of the TopK indices output
                    k = model.get_tensor_shape(consumer.output[1])[-1]
                    # swap connections:
                    # a new tensor is needed because the dims change
                    middle_name = model.make_new_valueinfo_name()
                    topk_indices = oh.make_tensor_value_info(
                        middle_name, TensorProto.INT64, [b, h, w, k]
                    )
                    end_name = consumer.output[1]
                    graph.value_info.append(topk_indices)
                    # remove both nodes, to re-insert them in swapped order below
                    graph.node.remove(n)
                    graph.node.remove(consumer)
                    # rewire inputs and outputs: TopK now consumes the Flatten
                    # input directly, Flatten consumes the TopK indices
                    consumer.input[0] = start_name
                    consumer.output[1] = middle_name
                    model.set_tensor_shape(consumer.output[0], (b, h, w, k))
                    n.input[0] = middle_name
                    n.output[0] = end_name
                    # insert them back in, TopK before Flatten
                    graph.node.insert(node_ind - 1, consumer)
                    graph.node.insert(node_ind, n)
                    graph_modified = True
        model = model.transform(InferShapes())
        return (model, graph_modified)
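
For reference, a minimal usage sketch of the new transformation follows. The import paths, the InferDataLayouts pre-pass, and the file names are assumptions based on FINN's conventions, not part of this commit:

# Usage sketch: apply the new transformation to a model containing a
# Flatten -> TopK pair. Import paths and file names below are assumptions.
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.infer_data_layouts import InferDataLayouts
from finn.transformation.streamline.reorder import MoveFlattenPastTopK

model = ModelWrapper("model_with_flatten_topk.onnx")  # hypothetical file
# annotate tensor layouts first, so the NHWC check inside apply() can pass
model = model.transform(InferDataLayouts())
# reorder: Flatten -> TopK becomes TopK -> Flatten on the indices path
model = model.transform(MoveFlattenPastTopK())
model.save("model_reordered.onnx")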