From fd6bcd45829e0fe4f9e698e4bf5b069193e42d79 Mon Sep 17 00:00:00 2001 From: auphelia <jakobapk@web.de> Date: Fri, 26 Jun 2020 15:11:04 +0100 Subject: [PATCH] [Streamline] Add propagation of tensor data layouts in MoveTransposePastScalarMul --- src/finn/transformation/streamline/reorder.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py index 62301cee5..748c7420a 100644 --- a/src/finn/transformation/streamline/reorder.py +++ b/src/finn/transformation/streamline/reorder.py @@ -32,6 +32,7 @@ from onnx import helper as oh from finn.transformation import Transformation from finn.transformation.infer_shapes import InferShapes +from finn.transformation.infer_data_layouts import InferDataLayouts from finn.core.onnx_exec import execute_node from finn.util.basic import get_by_name from finn.custom_op.registry import getCustomOp @@ -68,7 +69,9 @@ class MoveAddPastMul(Transformation): A = model.get_initializer(mul_weight_name) B = model.get_initializer(add_weight_name) if (A is None) or (B is None): - warnings.warn("Mul or add does not have constant params, skipping") + warnings.warn( + "Mul or add does not have constant params, skipping" + ) continue start_name = n.input[0] middle_name = n.output[0] @@ -638,18 +641,24 @@ class MoveTransposePastScalarMul(Transformation): end_name = mul_node.output[0] transp_in_shape = model.get_tensor_shape(start_name) transp_out_shape = model.get_tensor_shape(middle_name) + transp_in_layout = model.get_tensor_layout(start_name) + transp_out_layout = model.get_tensor_layout(middle_name) if all(x == 1 for x in A.shape): # if the mul is scalar, we can simply swap the order of ops # rewire transpose input to be mul input mul_node.input[0] = start_name model.set_tensor_shape(start_name, transp_in_shape) + model.set_tensor_layout(start_name, transp_in_layout) mul_node.output[0] = middle_name model.set_tensor_shape(middle_name, transp_in_shape) + model.set_tensor_layout(middle_name, transp_in_layout) transp_node.input[0] = middle_name transp_node.output[0] = end_name model.set_tensor_shape(end_name, transp_out_shape) + model.set_tensor_layout(end_name, transp_out_layout) graph.node.remove(transp_node) graph.node.insert(node_ind, transp_node) graph_modified = True + model = model.transform(InferDataLayouts()) model = model.transform(InferShapes()) return (model, graph_modified) -- GitLab