Skip to content
Snippets Groups Projects
Commit fd6bcd45 authored by auphelia's avatar auphelia
Browse files

[Streamline] Add propagation of tensor data layouts in MoveTransposePastScalarMul

parent 185022cb
No related branches found
No related tags found
No related merge requests found
......@@ -32,6 +32,7 @@ from onnx import helper as oh
from finn.transformation import Transformation
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.infer_data_layouts import InferDataLayouts
from finn.core.onnx_exec import execute_node
from finn.util.basic import get_by_name
from finn.custom_op.registry import getCustomOp
......@@ -68,7 +69,9 @@ class MoveAddPastMul(Transformation):
A = model.get_initializer(mul_weight_name)
B = model.get_initializer(add_weight_name)
if (A is None) or (B is None):
warnings.warn("Mul or add does not have constant params, skipping")
warnings.warn(
"Mul or add does not have constant params, skipping"
)
continue
start_name = n.input[0]
middle_name = n.output[0]
......@@ -638,18 +641,24 @@ class MoveTransposePastScalarMul(Transformation):
end_name = mul_node.output[0]
transp_in_shape = model.get_tensor_shape(start_name)
transp_out_shape = model.get_tensor_shape(middle_name)
transp_in_layout = model.get_tensor_layout(start_name)
transp_out_layout = model.get_tensor_layout(middle_name)
if all(x == 1 for x in A.shape):
# if the mul is scalar, we can simply swap the order of ops
# rewire transpose input to be mul input
mul_node.input[0] = start_name
model.set_tensor_shape(start_name, transp_in_shape)
model.set_tensor_layout(start_name, transp_in_layout)
mul_node.output[0] = middle_name
model.set_tensor_shape(middle_name, transp_in_shape)
model.set_tensor_layout(middle_name, transp_in_layout)
transp_node.input[0] = middle_name
transp_node.output[0] = end_name
model.set_tensor_shape(end_name, transp_out_shape)
model.set_tensor_layout(end_name, transp_out_layout)
graph.node.remove(transp_node)
graph.node.insert(node_ind, transp_node)
graph_modified = True
model = model.transform(InferDataLayouts())
model = model.transform(InferShapes())
return (model, graph_modified)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment