From e9aa670f167e361f910834fd37d857225d9bdb0f Mon Sep 17 00:00:00 2001 From: auphelia <jakobapk@web.de> Date: Thu, 25 Jun 2020 10:45:17 +0100 Subject: [PATCH] [Streamline] Add MoveFlattenPastAffine operation --- src/finn/transformation/streamline/reorder.py | 72 +++++++++++++++++++ 1 file changed, 72 insertions(+) diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py index b46b82c77..638c6ad5d 100644 --- a/src/finn/transformation/streamline/reorder.py +++ b/src/finn/transformation/streamline/reorder.py @@ -32,6 +32,7 @@ from onnx import helper as oh from finn.transformation import Transformation from finn.transformation.infer_shapes import InferShapes +from finn.transformation.infer_datatypes import InferDataTypes from finn.core.onnx_exec import execute_node from finn.util.basic import get_by_name from finn.custom_op.registry import getCustomOp @@ -597,3 +598,74 @@ class MoveMaxPoolPastMultiThreshold(Transformation): model = model.transform(InferShapes()) return (model, graph_modified) + + +class MoveFlattenPastAffine(Transformation): + """Moves a node that implements a (1, -1) reshape past a MatMul, Mul or Add node.""" + + def apply(self, model): + graph = model.graph + graph_modified = False + node_ind = 0 + for n in graph.node: + node_ind += 1 + if ( + n.op_type == "Flatten" + and not model.is_fork_node(n) + and not model.is_join_node(n) + ): + consumer = model.find_consumer(n.output[0]) + if ( + consumer is not None + and ( + consumer.op_type == "MatMul" + or consumer.op_type == "Mul" + or consumer.op_type == "Add" + ) + and not model.is_join_node(consumer) + ): + # move flatten past operation and rewire tensors + start_name = n.input[0] + middle_name = n.output[0] + end_name = consumer.output[0] + op_param_name = consumer.input[1] + A = model.get_initializer(op_param_name) + if A is None: + warnings.warn("Param is not constant, skipping") + continue + op_in_dt = model.get_tensor_datatype(consumer.input[0]) + op_out_dt = model.get_tensor_datatype(consumer.output[0]) + start_shape = model.get_tensor_shape(start_name) + dummy_in = np.random.uniform(low=0, high=1, size=(start_shape)) + + if consumer.op_type == "MatMul": + dummy_out = np.matmul(dummy_in, A) + elif consumer.op_type == "Mul": + dummy_out = dummy_in * A + elif consumer.op_type == "Add": + dummy_out = dummy_in + A + + new_op = oh.make_node( + consumer.op_type, + [start_name, op_param_name], + [middle_name], + name=consumer.name, + ) + new_flatten = oh.make_node("Flatten", [middle_name], [end_name]) + graph.node.insert(node_ind, new_op) + graph.node.insert(node_ind + 1, new_flatten) + model.set_tensor_shape(middle_name, dummy_out.shape) + # because a flatten node doesn't change the datatype we need + # only the datatype of the op node + model.set_tensor_datatype(start_name, op_in_dt) + model.set_tensor_datatype(middle_name, op_out_dt) + model.set_tensor_datatype(end_name, op_out_dt) + # remove old nodes + graph.node.remove(n) + graph.node.remove(consumer) + graph_modified = True + + model = model.transform(InferShapes()) + model = model.transform(InferDataTypes()) + + return (model, graph_modified) -- GitLab