From 36700c2130a1bc33e13e097345ace940fe28cdfe Mon Sep 17 00:00:00 2001
From: auphelia <jakobapk@web.de>
Date: Fri, 26 Jun 2020 11:34:48 +0100
Subject: [PATCH] [Streamline] Add MoveFlattenPastTopK transformation

---
 src/finn/transformation/streamline/reorder.py | 63 +++++++++++++++++++
 1 file changed, 63 insertions(+)

diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py
index b46b82c77..23a2bfa5b 100644
--- a/src/finn/transformation/streamline/reorder.py
+++ b/src/finn/transformation/streamline/reorder.py
@@ -29,8 +29,10 @@
 import numpy as np
 import warnings
 from onnx import helper as oh
+from onnx import TensorProto
 
 from finn.transformation import Transformation
+import finn.core.data_layout as DataLayout
 from finn.transformation.infer_shapes import InferShapes
 from finn.core.onnx_exec import execute_node
 from finn.util.basic import get_by_name
@@ -597,3 +599,64 @@ class MoveMaxPoolPastMultiThreshold(Transformation):
 
         model = model.transform(InferShapes())
         return (model, graph_modified)
+
+
+class MoveFlattenPastTopK(Transformation):
+    """Move a Flatten node past a succeeding TopK node, if the "axis" attribute
+    of TopK is set to -1 and the data layout before the Flatten is NHWC with H=W=1."""
+
+    def apply(self, model):
+        graph = model.graph
+        node_ind = 0
+        graph_modified = False
+        for n in graph.node:
+            node_ind += 1
+            if n.op_type == "Flatten":
+                consumer = model.find_consumer(n.output[0])
+                if consumer is not None and consumer.op_type == "TopK":
+                    axis = get_by_name(consumer.attribute, "axis")
+                    if axis is None or axis.i != -1:
+                        continue
+                    start_name = n.input[0]
+                    data_layout = model.get_tensor_layout(start_name)
+                    if data_layout != DataLayout.NHWC:
+                        warnings.warn(
+                            """Transformation can't be applied. The input
+                            to the Flatten node must have DataLayout.NHWC."""
+                        )
+                        continue
+                    (b, h, w, c) = model.get_tensor_shape(start_name)
+                    if h != 1 or w != 1:
+                        continue
+                    # get the parameter k from the TopK indices output shape
+                    k = model.get_tensor_shape(consumer.output[1])[-1]
+
+                    # swap connections
+                    # a new tensor is needed because the dims change
+                    middle_name = model.make_new_valueinfo_name()
+                    topk_indices = oh.make_tensor_value_info(
+                        middle_name, TensorProto.INT64, [b, h, w, k]
+                    )
+                    end_name = consumer.output[1]
+                    graph.value_info.append(topk_indices)
+
+                    # remove old nodes
+                    graph.node.remove(n)
+                    graph.node.remove(consumer)
+
+                    # rewire: TopK consumes the NHWC tensor, Flatten its indices
+                    consumer.input[0] = start_name
+                    consumer.output[1] = middle_name
+                    model.set_tensor_shape(consumer.output[0], (b, h, w, k))
+
+                    n.input[0] = middle_name
+                    n.output[0] = end_name
+
+                    # insert them back in swapped order
+                    graph.node.insert(node_ind - 1, consumer)
+                    graph.node.insert(node_ind, n)
+
+                    graph_modified = True
+
+        model = model.transform(InferShapes())
+        return (model, graph_modified)
-- 
GitLab
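
Usage note: below is a minimal sketch of applying the new transformation,
not part of the patch itself; the file names are illustrative. Since the
transformation checks the data layout annotation of the Flatten input, the
layouts should be annotated first (e.g. via FINN's InferDataLayouts
transformation), otherwise the NHWC check emits the warning and the graph
is left unchanged.

    from finn.core.modelwrapper import ModelWrapper
    from finn.transformation.infer_shapes import InferShapes
    from finn.transformation.infer_data_layouts import InferDataLayouts
    from finn.transformation.streamline.reorder import MoveFlattenPastTopK

    # "model.onnx" is a placeholder for a model that ends in
    # Flatten -> TopK(axis=-1) with an NHWC Flatten input of shape (N, 1, 1, C)
    model = ModelWrapper("model.onnx")
    # annotate shapes and data layouts so the NHWC check can succeed
    model = model.transform(InferShapes())
    model = model.transform(InferDataLayouts())
    # reorder so TopK operates on the NHWC tensor and Flatten on its indices
    model = model.transform(MoveFlattenPastTopK())
    model.save("model_reordered.onnx")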