Skip to content
Snippets Groups Projects
Unverified Commit feb0e9fb authored by Yaman Umuroglu's avatar Yaman Umuroglu Committed by GitHub
Browse files

Merge pull request #158 from Xilinx/feature/move_chw_mul_past_dw_conv

Feature/move chw mul past dw conv
parents 74eb0d59 ad33d8d0
No related branches found
No related tags found
No related merge requests found
......@@ -32,6 +32,7 @@ from onnx import helper as oh
from finn.transformation import Transformation
from finn.transformation.infer_shapes import InferShapes
from finn.core.datatype import DataType
from finn.core.onnx_exec import execute_node
from finn.util.basic import get_by_name
from finn.custom_op.registry import getCustomOp
......@@ -338,6 +339,71 @@ class MoveScalarMulPastConv(Transformation):
return (model, graph_modified)
class MoveMulPastDWConv(Transformation):
"""Move channelwise mul operations past depthwise conv operations. We want to have muls
next to each other such that they can be collapsed into a single mul."""
def apply(self, model):
graph = model.graph
node_ind = 0
graph_modified = False
for n in graph.node:
node_ind += 1
if (
n.op_type == "Mul"
and not model.is_fork_node(n)
and not model.is_join_node(n)
):
consumer = model.find_consumer(n.output[0])
if (
consumer is not None
and consumer.op_type == "Conv"
and not model.is_join_node(consumer)
):
mul_weight_name = n.input[1]
A = model.get_initializer(mul_weight_name)
if A is None:
warnings.warn(
"""Mul weight tensor is not set. If it is a constant,
please use set_initializer to set the tensor."""
)
continue
conv_node = consumer
mul_node = n
start_name = mul_node.input[0]
conv_in_name = conv_node.input[0]
conv_in_shape = model.get_tensor_shape(conv_in_name)
ifm_ch = conv_in_shape[1]
group_attribute = get_by_name(consumer.attribute, "group")
if group_attribute is None:
continue
group_attribute = group_attribute.i
conv_out_name = conv_node.output[0]
conv_out_shape = model.get_tensor_shape(conv_out_name)
if A.shape == (1, ifm_ch, 1, 1) and ifm_ch == group_attribute:
# if the mul is channelwise and conv is depthwise,
# we can simply swap the order of ops
# rewire mul input to be conv input
conv_node.input[0] = start_name
model.set_tensor_shape(start_name, conv_in_shape)
model.set_tensor_datatype(start_name, DataType.FLOAT32)
# use old conv input tensor as conv output
conv_node.output[0] = conv_in_name
model.set_tensor_shape(conv_in_name, conv_out_shape)
model.set_tensor_datatype(conv_in_name, DataType.FLOAT32)
# use new conv output as new mul node input
mul_node.input[0] = conv_in_name
# use old conv output as new mul node output
mul_node.output[0] = conv_out_name
model.set_tensor_datatype(conv_out_name, DataType.FLOAT32)
# move mul node past conv node
graph.node.remove(mul_node)
graph.node.insert(node_ind, mul_node)
graph_modified = True
model = model.transform(InferShapes())
return (model, graph_modified)
class MoveLinearPastEltwiseAdd(Transformation):
"""Move linear operations (mul, add) past elementwise add operations where possible.
Specifically,matches and transforms the following patterns:
......
import pytest
from onnx import helper, TensorProto
from finn.custom_op.im2col import compute_conv_output_dim
import finn.core.onnx_exec as oxe
from finn.core.datatype import DataType
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.infer_datatypes import InferDataTypes
from finn.transformation.infer_shapes import InferShapes
from finn.util.basic import gen_finn_dt_tensor
from finn.transformation.streamline.reorder import MoveMulPastDWConv
# input dimension
@pytest.mark.parametrize("ifm_dim", [4, 7])
# input channels
@pytest.mark.parametrize("ifm_ch", [2, 3])
# kernel size
@pytest.mark.parametrize("k", [2, 3])
# stride
@pytest.mark.parametrize("stride", [1, 2])
# padding
@pytest.mark.parametrize("pad_amt", [0, 1])
# depthwise
@pytest.mark.parametrize("dw", [0, 1])
def test_move_mul_past_dw_conv(ifm_dim, ifm_ch, k, stride, pad_amt, dw):
if dw == 1:
ofm_ch = ifm_ch
groups = ifm_ch
W_shape = [ofm_ch, 1, k, k]
else:
ofm_ch = ifm_ch + 2
groups = 1
W_shape = [ofm_ch, ifm_ch, k, k]
ofm_dim = compute_conv_output_dim(ifm_dim, k, stride, pad_amt)
# set up onnx model
inp = helper.make_tensor_value_info(
"inp", TensorProto.FLOAT, [1, ifm_ch, ifm_dim, ifm_dim]
)
mul = helper.make_tensor_value_info("mul", TensorProto.FLOAT, [1, ifm_ch, 1, 1])
W = helper.make_tensor_value_info("W", TensorProto.FLOAT, W_shape)
outp = helper.make_tensor_value_info(
"outp", TensorProto.FLOAT, [1, ofm_ch, ofm_dim, ofm_dim]
)
Mul_node = helper.make_node("Mul", ["inp", "mul"], ["mul_out"])
Conv_node = helper.make_node(
"Conv",
["mul_out", "W"],
["outp"],
group=groups,
kernel_shape=[k, k],
pads=[pad_amt, pad_amt, pad_amt, pad_amt],
strides=[stride, stride],
)
graph = helper.make_graph(
nodes=[Mul_node, Conv_node],
name="mulpastconv_graph",
inputs=[inp],
outputs=[outp],
value_info=[mul, W],
)
model = helper.make_model(graph, producer_name="mulpastconv-model")
model = ModelWrapper(model)
inp_values = gen_finn_dt_tensor(DataType.INT2, [1, ifm_ch, ifm_dim, ifm_dim])
mul_values = gen_finn_dt_tensor(DataType.INT2, [1, ifm_ch, 1, 1])
W_values = gen_finn_dt_tensor(DataType.INT2, W_shape)
model.set_initializer("W", W_values)
model.set_initializer("mul", mul_values)
model = model.transform(InferShapes())
model = model.transform(InferDataTypes())
idict = {"inp": inp_values}
odict = oxe.execute_onnx(model, idict, True)
out_before = odict["outp"]
# move channelwise multiplication past depthwise conv
model_transformed = model.transform(MoveMulPastDWConv())
odict = oxe.execute_onnx(model_transformed, idict, True)
out_after = odict["outp"]
assert (out_before == out_after).all()
if dw == 0:
assert model.graph.node[0].op_type == model_transformed.graph.node[0].op_type
assert model.graph.node[1].op_type == model_transformed.graph.node[1].op_type
else:
assert model.graph.node[0].op_type == model_transformed.graph.node[1].op_type
assert model.graph.node[1].op_type == model_transformed.graph.node[0].op_type
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment