Skip to content
Snippets Groups Projects
Commit f724b641 authored by auphelia's avatar auphelia
Browse files

[Streamline] Check for data layout is None in MoveTransposePastScalarMul

parent 78a11506
No related branches found
No related tags found
No related merge requests found
......@@ -643,6 +643,12 @@ class MoveTransposePastScalarMul(Transformation):
transp_out_shape = model.get_tensor_shape(middle_name)
transp_in_layout = model.get_tensor_layout(start_name)
transp_out_layout = model.get_tensor_layout(middle_name)
if transp_in_layout is None or transp_out_layout is None:
warnings.warn(
"""Datalayout is not set for tensors.
Transformation can't be applied."""
)
continue
if all(x == 1 for x in A.shape):
# if the mul is scalar, we can simply swap the order of ops
# rewire transpose input to be mul input
......@@ -659,6 +665,8 @@ class MoveTransposePastScalarMul(Transformation):
graph.node.remove(transp_node)
graph.node.insert(node_ind, transp_node)
graph_modified = True
model = model.transform(InferDataLayouts())
model = model.transform(InferShapes())
if graph_modified is True:
model = model.transform(InferDataLayouts())
model = model.transform(InferShapes())
return (model, graph_modified)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment