Unverified commit a6c29deb authored by Yaman Umuroglu, committed by GitHub
Merge pull request #168 from Xilinx/feature/mv_transp_past_scal_mul

Feature/mv transp past scal mul
parents 2b21a455 fbd885ac
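
In brief, this merge changes the streamline reorder transformations (the finn.transformation.streamline.reorder module, per the imports in the new unit test below) in two ways: hard assertions that Mul/Add/MatMul parameters have constant initializers are relaxed into warnings that skip the node instead of aborting, and a new MoveTransposePastScalarMul transformation is added, together with a unit test.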
@@ -32,6 +32,7 @@ from onnx import helper as oh
 from finn.transformation import Transformation
 from finn.transformation.infer_shapes import InferShapes
+from finn.transformation.infer_data_layouts import InferDataLayouts
 from finn.core.datatype import DataType
 from finn.core.onnx_exec import execute_node
 from finn.util.basic import get_by_name
@@ -68,8 +69,11 @@ class MoveAddPastMul(Transformation):
                     add_weight_name = n.input[1]
                     A = model.get_initializer(mul_weight_name)
                     B = model.get_initializer(add_weight_name)
-                    assert A is not None, "Initializer for mul weights is not set."
-                    assert B is not None, "Initializer for add weights is not set."
+                    if (A is None) or (B is None):
+                        warnings.warn(
+                            "Mul or add does not have constant params, skipping"
+                        )
+                        continue
                     start_name = n.input[0]
                     middle_name = n.output[0]
                     end_name = consumer.output[0]
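
For context, MoveAddPastMul relies on the distributive identity (x + B) * A == x * A + (B * A), so the Add weight must be rescaled by A when the nodes are swapped. A quick numpy check of the identity (an illustrative sketch, not part of the commit):

import numpy as np

x = np.random.rand(1, 4).astype(np.float32)
A = np.float32(2.0)  # scalar mul weight
B = np.float32(3.0)  # scalar add weight
# Add-then-Mul equals Mul-then-Add with a rescaled add weight
assert np.allclose((x + B) * A, x * A + (B * A))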
@@ -124,8 +128,9 @@ class MoveScalarMulPastMatMul(Transformation):
                     matmul_weight_name = consumer.input[1]
                     A = model.get_initializer(mul_weight_name)
                     W = model.get_initializer(matmul_weight_name)
-                    assert A is not None, "Initializer for mul weights is not set."
-                    assert W is not None, "Initializer for matmul weights is not set."
+                    if (A is None) or (W is None):
+                        warnings.warn("MatMul or Mul params are not constant, skipping")
+                        continue
                     start_name = n.input[0]
                     middle_name = n.output[0]
                     end_name = consumer.output[0]
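
MoveScalarMulPastMatMul uses the fact that a scalar Mul commutes with MatMul: (a * x) @ W == a * (x @ W). A minimal numpy sketch of this identity (illustrative only):

import numpy as np

a = np.float32(0.5)  # scalar mul weight
x = np.random.rand(1, 4).astype(np.float32)
W = np.random.rand(4, 3).astype(np.float32)
assert np.allclose((a * x) @ W, a * (x @ W))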
@@ -181,8 +186,9 @@ class MoveScalarAddPastMatMul(Transformation):
                     matmul_weight_name = consumer.input[1]
                     A = model.get_initializer(add_weight_name)
                     W = model.get_initializer(matmul_weight_name)
-                    assert A is not None, "Initializer for add weights is not set."
-                    assert W is not None, "Initializer for matmul weights is not set."
+                    if (A is None) or (W is None):
+                        warnings.warn("MatMul or Add params are not constant, skipping")
+                        continue
                     start_name = n.input[0]
                     middle_name = n.output[0]
                     end_name = consumer.output[0]
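
MoveScalarAddPastMatMul relies on a scalar add on the MatMul input turning into a bias after the MatMul: (x + a) @ W == x @ W + a * W.sum(axis=0). A numpy sketch (illustrative only):

import numpy as np

a = np.float32(0.5)  # scalar add weight
x = np.random.rand(1, 4).astype(np.float32)
W = np.random.rand(4, 3).astype(np.float32)
# the moved add becomes a row vector: a times the column sums of W
assert np.allclose((x + a) @ W, x @ W + a * W.sum(axis=0))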
@@ -243,7 +249,9 @@ class MoveScalarAddPastConv(Transformation):
                     conv_in_name = consumer.input[0]
                     conv_in_shape = model.get_tensor_shape(conv_in_name)
                     A = model.get_initializer(add_weight_name)
-                    assert A is not None, "Initializer for add weights is not set."
+                    if A is None:
+                        warnings.warn("Add param is not constant, skipping")
+                        continue
                     start_name = n.input[0]
                     end_name = consumer.output[0]
                     conv_out_shape = model.get_tensor_shape(end_name)
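
MoveScalarAddPastConv exploits that, for an unpadded convolution, a scalar add on the input becomes a per-output-channel bias of a * kernel.sum() after the Conv. A single-channel numpy sketch of that identity (an illustrative sketch assuming 'valid' convolution, i.e. no padding):

import numpy as np

def conv2d_valid(x, k):
    # plain cross-correlation, no padding, stride 1
    kh, kw = k.shape
    oh, ow = x.shape[0] - kh + 1, x.shape[1] - kw + 1
    out = np.empty((oh, ow), dtype=x.dtype)
    for i in range(oh):
        for j in range(ow):
            out[i, j] = (x[i : i + kh, j : j + kw] * k).sum()
    return out

x = np.random.rand(6, 6).astype(np.float32)
k = np.random.rand(3, 3).astype(np.float32)
a = np.float32(0.5)  # scalar add weight
assert np.allclose(conv2d_valid(x + a, k), conv2d_valid(x, k) + a * k.sum(), atol=1e-5)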
@@ -311,7 +319,9 @@ class MoveScalarMulPastConv(Transformation):
                 ):
                     mul_weight_name = n.input[1]
                     A = model.get_initializer(mul_weight_name)
-                    assert A is not None, "Initializer for mul weights is not set."
+                    if A is None:
+                        warnings.warn("Mul param is not constant, skipping")
+                        continue
                     conv_node = consumer
                     mul_node = n
                     start_name = mul_node.input[0]
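
MoveScalarMulPastConv relies on the linearity of convolution: conv(a * x, k) == a * conv(x, k) for scalar a, which holds with or without zero padding. A short sketch using scipy (illustrative only):

import numpy as np
from scipy.signal import correlate2d

a = np.float32(0.5)  # scalar mul weight
x = np.random.rand(6, 6).astype(np.float32)
k = np.random.rand(3, 3).astype(np.float32)
assert np.allclose(correlate2d(a * x, k, mode="valid"),
                   a * correlate2d(x, k, mode="valid"))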
@@ -663,3 +673,66 @@ class MoveMaxPoolPastMultiThreshold(Transformation):
         model = model.transform(InferShapes())
         return (model, graph_modified)
+
+
+class MoveTransposePastScalarMul(Transformation):
+    """Moves a Transpose node past a scalar Mul node"""
+
+    def apply(self, model):
+        graph = model.graph
+        node_ind = 0
+        graph_modified = False
+        for n in graph.node:
+            node_ind += 1
+            if (
+                n.op_type == "Transpose"
+                and not model.is_fork_node(n)
+                and not model.is_join_node(n)
+            ):
+                consumer = model.find_consumer(n.output[0])
+                if (
+                    consumer is not None
+                    and consumer.op_type == "Mul"
+                    and not model.is_join_node(consumer)
+                ):
+                    mul_weight_name = consumer.input[1]
+                    A = model.get_initializer(mul_weight_name)
+                    if A is None:
+                        warnings.warn("Mul param is not constant, skipping")
+                        continue
+                    transp_node = n
+                    mul_node = consumer
+                    start_name = transp_node.input[0]
+                    middle_name = transp_node.output[0]
+                    end_name = mul_node.output[0]
+                    transp_in_shape = model.get_tensor_shape(start_name)
+                    transp_out_shape = model.get_tensor_shape(middle_name)
+                    transp_in_layout = model.get_tensor_layout(start_name)
+                    transp_out_layout = model.get_tensor_layout(middle_name)
+                    if transp_in_layout is None or transp_out_layout is None:
+                        warnings.warn(
+                            """Data layout is not set for tensors.
+                            Transformation can't be applied."""
+                        )
+                        continue
+                    if all(x == 1 for x in A.shape):
+                        # if the mul is scalar, we can simply swap the order of ops
+                        # rewire transpose input to be mul input
+                        mul_node.input[0] = start_name
+                        model.set_tensor_shape(start_name, transp_in_shape)
+                        model.set_tensor_layout(start_name, transp_in_layout)
+                        mul_node.output[0] = middle_name
+                        model.set_tensor_shape(middle_name, transp_in_shape)
+                        model.set_tensor_layout(middle_name, transp_in_layout)
+                        transp_node.input[0] = middle_name
+                        transp_node.output[0] = end_name
+                        model.set_tensor_shape(end_name, transp_out_shape)
+                        model.set_tensor_layout(end_name, transp_out_layout)
+                        graph.node.remove(transp_node)
+                        graph.node.insert(node_ind, transp_node)
+                        graph_modified = True
+        if graph_modified is True:
+            model = model.transform(InferDataLayouts())
+            model = model.transform(InferShapes())
+        return (model, graph_modified)
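
The new transformation is safe because an elementwise multiply by a scalar commutes with any axis permutation: transpose(x * a) == transpose(x) * a. A quick numpy check of this identity (an illustrative sketch, not part of the commit):

import numpy as np

a = np.float32(0.5)  # scalar Mul weight
x = np.random.rand(1, 2, 3, 4).astype(np.float32)
perm = (0, 2, 3, 1)  # e.g. NCHW -> NHWC
assert np.allclose(x.transpose(perm) * a, (x * a).transpose(perm))

This does not hold for a general non-scalar Mul weight, which is why the code checks all(x == 1 for x in A.shape). The unit test added by the commit, shown next, covers both the scalar and non-scalar cases.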
import pytest

import numpy as np
from onnx import TensorProto, helper

from finn.core.modelwrapper import ModelWrapper
import finn.core.data_layout as DataLayout
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.infer_datatypes import InferDataTypes
from finn.transformation.infer_data_layouts import InferDataLayouts
from finn.transformation.general import GiveUniqueNodeNames, GiveReadableTensorNames
from finn.transformation.streamline.reorder import MoveTransposePastScalarMul
import finn.core.onnx_exec as oxe


# permutation of transpose node
@pytest.mark.parametrize("perm", [[0, 2, 3, 1], [0, 1, 3, 2], [3, 2, 0, 1]])
# scalar mul
@pytest.mark.parametrize("scalar", [True, False])
# data layout
@pytest.mark.parametrize("data_layout", [None, DataLayout.NHWC, DataLayout.NCHW])
def test_move_transpose_past_scalar_mul(perm, scalar, data_layout):
    inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, [1, 2, 3, 4])
    # out_size depends on "perm", so derive it from a dummy input
    dummy_in = np.random.uniform(low=0, high=1, size=(1, 2, 3, 4)).astype(np.float32)
    out_size = dummy_in.transpose(tuple(perm)).shape
    if scalar is True:
        a0_size = []
    else:
        a0_size = out_size
    a0 = helper.make_tensor_value_info("a0", TensorProto.FLOAT, a0_size)
    outp = helper.make_tensor_value_info("outp", TensorProto.FLOAT, out_size)
    transp_node = helper.make_node("Transpose", ["inp"], ["transp_out"], perm=perm)
    mul_node = helper.make_node("Mul", ["transp_out", "a0"], ["outp"])

    graph = helper.make_graph(
        nodes=[transp_node, mul_node],
        name="mv-transpose-graph",
        inputs=[inp],
        outputs=[outp],
        value_info=[a0],
    )

    model = helper.make_model(graph, producer_name="mv_transpose_model")
    model = ModelWrapper(model)

    # initialize values
    a0_values = np.random.uniform(low=0, high=1, size=tuple(a0_size)).astype(np.float32)
    model.set_initializer("a0", a0_values)
    if data_layout is not None:
        model.set_tensor_layout("inp", data_layout)
        model = model.transform(InferDataLayouts())

    model = model.transform(InferShapes())
    model = model.transform(InferDataTypes())
    model = model.transform(GiveUniqueNodeNames())
    model = model.transform(GiveReadableTensorNames())

    # compare execution before and after transformation
    inp_values = np.random.uniform(low=0, high=1, size=(1, 2, 3, 4)).astype(np.float32)
    idict = {model.graph.input[0].name: inp_values}
    model_transformed = model.transform(MoveTransposePastScalarMul())
    assert oxe.compare_execution(model, model_transformed, idict)

    # check if the node order changed
    if scalar is True and data_layout is not None:
        assert model_transformed.graph.node[0] != model.graph.node[0]
        assert model_transformed.graph.node[1] != model.graph.node[1]
        assert model_transformed.graph.node[0].op_type == "Mul"
        assert model_transformed.graph.node[1].op_type == "Transpose"
        mul_input = model_transformed.graph.node[0].input[0]
        mul_output = model_transformed.graph.node[0].output[0]
        assert model_transformed.get_tensor_layout(mul_input) == data_layout
        assert model_transformed.get_tensor_layout(mul_output) == data_layout
    else:
        assert model_transformed.graph.node[0] == model.graph.node[0]
        assert model_transformed.graph.node[1] == model.graph.node[1]
        if data_layout is not None:
            mul_input = model_transformed.graph.node[1].input[0]
            mul_output = model_transformed.graph.node[1].output[0]
            assert model_transformed.get_tensor_layout(mul_input) != data_layout
            assert model_transformed.get_tensor_layout(mul_output) != data_layout
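
For reference, applying the new transformation in a larger flow follows the same pattern as the test (a minimal sketch assuming an existing FINN ModelWrapper instance named model):

from finn.transformation.streamline.reorder import MoveTransposePastScalarMul

# transform() returns a new model in which each Transpose followed by a
# scalar Mul has been swapped to Mul followed by Transpose
model = model.transform(MoveTransposePastScalarMul())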