Unverified Commit 95664952 authored by Felix Jentzsch's avatar Felix Jentzsch Committed by GitHub

Add MoveMulPastMaxPool transformation to streamlining (#285)

* Add MoveMulPastMaxPool transform to streamlining

* Add second call to transformation within streamlining
parent 7776475f
@@ -59,6 +59,7 @@ from finn.transformation.streamline.reorder import (
    MoveScalarAddPastMatMul,
    MoveAddPastConv,
    MoveScalarMulPastConv,
    MoveMulPastMaxPool,
)
from finn.transformation.streamline.round_thresholds import RoundAndClipThresholds
@@ -76,6 +77,7 @@ class Streamline(Transformation):
            ConvertDivToMul(),
            BatchNormToAffine(),
            ConvertSignToThres(),
            MoveMulPastMaxPool(),
            AbsorbSignBiasIntoMultiThreshold(),
            MoveAddPastMul(),
            MoveScalarAddPastMatMul(),
@@ -85,6 +87,7 @@ class Streamline(Transformation):
            MoveAddPastMul(),
            CollapseRepeatedAdd(),
            CollapseRepeatedMul(),
            MoveMulPastMaxPool(),
            AbsorbAddIntoMultiThreshold(),
            FactorOutMulSignMagnitude(),
            AbsorbMulIntoMultiThreshold(),
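With both calls in place, MoveMulPastMaxPool runs early in streamlining and again after the repeated add/mul collapses. A minimal usage sketch (the ONNX file path is hypothetical; Streamline is assumed importable from finn.transformation.streamline, as the hunk headers above suggest):

from finn.core.modelwrapper import ModelWrapper
from finn.transformation.streamline import Streamline

model = ModelWrapper("model.onnx")  # hypothetical input model
model = model.transform(Streamline())
model.save("model_streamlined.onnx")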
@@ -425,12 +425,86 @@ class MoveMulPastDWConv(Transformation):
        return (model, graph_modified)
class MoveMulPastMaxPool(Transformation):
    """Move non-negative scalar or channelwise mul operations past max pool operations.
    We want to have muls next to each other such that they can be collapsed into a
    single mul."""

    def apply(self, model):
        graph = model.graph
        node_ind = 0
        graph_modified = False
        for n in graph.node:
            node_ind += 1
            if (
                n.op_type == "Mul"
                and not model.is_fork_node(n)
                and not model.is_join_node(n)
            ):
                consumer = model.find_consumer(n.output[0])
                if (
                    consumer is not None
                    and consumer.op_type == "MaxPool"
                    and not model.is_join_node(consumer)
                ):
                    mul_weight_name = n.input[1]
                    A = model.get_initializer(mul_weight_name)
                    if A is None:
                        warnings.warn(
                            """Mul weight tensor is not set. If it is a constant,
                            please use set_initializer to set the tensor."""
                        )
                        continue
                    maxpool_node = consumer
                    mul_node = n
                    start_name = mul_node.input[0]
                    maxpool_in_name = maxpool_node.input[0]
                    maxpool_in_shape = model.get_tensor_shape(maxpool_in_name)
                    ifm_ch = maxpool_in_shape[1]
                    maxpool_out_name = maxpool_node.output[0]
                    maxpool_out_shape = model.get_tensor_shape(maxpool_out_name)
                    # do not support non-2D MaxPool
                    kernel_shape = list(
                        get_by_name(maxpool_node.attribute, "kernel_shape").ints
                    )
                    if len(kernel_shape) != 2:
                        continue
                    # do not move negative multiplication factor(s)
                    if (A < 0).any():
                        continue
                    if all(x == 1 for x in A.shape) or A.shape == (1, ifm_ch, 1, 1):
                        # if the mul is scalar or channelwise,
                        # we can simply swap the order of ops
                        # rewire mul input to be maxpool input
                        maxpool_node.input[0] = start_name
                        model.set_tensor_shape(start_name, maxpool_in_shape)
                        model.set_tensor_datatype(start_name, DataType.FLOAT32)
                        # use old maxpool input tensor as maxpool output
                        maxpool_node.output[0] = maxpool_in_name
                        model.set_tensor_shape(maxpool_in_name, maxpool_out_shape)
                        model.set_tensor_datatype(maxpool_in_name, DataType.FLOAT32)
                        # use new maxpool output as new mul node input
                        mul_node.input[0] = maxpool_in_name
                        # use old maxpool output as new mul node output
                        mul_node.output[0] = maxpool_out_name
                        model.set_tensor_datatype(maxpool_out_name, DataType.FLOAT32)
                        # move mul node past maxpool node
                        graph.node.remove(mul_node)
                        graph.node.insert(node_ind, mul_node)
                        graph_modified = True
        model = model.transform(InferShapes())
        return (model, graph_modified)
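The non-negativity guard is what makes the swap sound: for an elementwise factor a >= 0, max(a * x) = a * max(x), whereas a negative factor reverses the ordering and turns the max into a min. A standalone numpy check of this identity (not part of the commit):

import numpy as np

x = np.array([-2.0, 1.0, 3.0])

a = 2.0
assert np.max(a * x) == a * np.max(x)  # holds for any a >= 0

a = -2.0
# for a < 0 the factor flips the ordering: max(a*x) = a * min(x)
assert np.max(a * x) == a * np.min(x)
assert np.max(a * x) != a * np.max(x)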
class MoveLinearPastEltwiseAdd(Transformation):
    """Move linear operations (mul, add) past elementwise add operations where possible.
    Specifically, matches and transforms the following patterns:
    (x*C) + (y*C) -> (x + y) * C
    (x+A) + (y+B) -> (x + y) + (A + B)
    where x and y are dynamic inputs, A, B, C are constant tensors (in general).
    """
    def move_node(self, graph, n, prod0, prod1, node_ind):
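Both rewrites are plain distributivity and associativity over tensors; a quick numpy sanity check (standalone, shapes chosen arbitrarily):

import numpy as np

x = np.random.rand(1, 3, 4, 4).astype(np.float32)
y = np.random.rand(1, 3, 4, 4).astype(np.float32)
A, B, C = (np.random.rand(1, 3, 1, 1).astype(np.float32) for _ in range(3))

# (x*C) + (y*C) -> (x + y) * C
assert np.allclose(x * C + y * C, (x + y) * C)
# (x+A) + (y+B) -> (x + y) + (A + B)
assert np.allclose((x + A) + (y + B), (x + y) + (A + B))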
@@ -504,12 +578,12 @@ class MoveLinearPastEltwiseAdd(Transformation):
class MoveScalarLinearPastInvariants(Transformation):
    """Move scalar linear operations (mul, add) past functions which are invariant
    to them. Specifically, matches and transforms the following patterns:
    f(x*C) -> f(x) * C
    f(x+C) -> f(x) + C
    where x is a dynamic input, C is a constant tensor.
    Known f which obey this property are: Reshape, Flatten, Transpose,
    GlobalAveragePool
    """
    def apply(self, model):
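The invariance is easy to verify for the listed ops: GlobalAveragePool commutes with a scalar mul/add because the mean is linear, and Transpose merely reorders elements. A standalone numpy check (not part of the commit):

import numpy as np

x = np.random.rand(1, 3, 4, 4).astype(np.float32)
c = 0.5  # scalar constant

def gap(t):
    # GlobalAveragePool over the spatial dims
    return t.mean(axis=(2, 3), keepdims=True)

assert np.allclose(gap(x * c), gap(x) * c)  # f(x*C) -> f(x) * C
assert np.allclose(gap(x + c), gap(x) + c)  # f(x+C) -> f(x) + C
# Transpose is likewise invariant
assert np.allclose((x * c).transpose(0, 2, 3, 1), x.transpose(0, 2, 3, 1) * c)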
@@ -604,7 +678,7 @@ class MakeMaxPoolNHWC(Transformation):
class MoveOpPastFork(Transformation):
    """Move node operations past graph forks. Used when a node before a fork
    can be merged with nodes in the branches.
    """
    def __init__(self, op_name_list):
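For MoveOpPastFork, the rewrite clones the op into each fork branch, which preserves semantics because every consumer still sees op(x). A toy numpy illustration (standalone, not part of the commit):

import numpy as np

x = np.random.rand(4).astype(np.float32)
c = 2.0

# before: a single Mul feeds both branches of the fork
y = x * c
before = (y + 1.0, y * 3.0)

# after: the Mul is duplicated into each branch
after = ((x * c) + 1.0, (x * c) * 3.0)

assert all(np.allclose(b, a) for b, a in zip(before, after))

The new test file below builds a small Mul -> MaxPool model and checks both the moved and the not-moved (negative factor) cases: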
import numpy as np
import pytest
from onnx import helper, TensorProto
from finn.custom_op.general.maxpoolnhwc import compute_pool_output_dim
import finn.core.onnx_exec as oxe
from finn.core.datatype import DataType
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.infer_datatypes import InferDataTypes
from finn.transformation.infer_shapes import InferShapes
from finn.util.basic import gen_finn_dt_tensor
from finn.transformation.streamline.reorder import MoveMulPastMaxPool
# input dimension
@pytest.mark.parametrize("ifm_dim", [4, 7])
# input channels
@pytest.mark.parametrize("ifm_ch", [1, 3])
# kernel size
@pytest.mark.parametrize("k", [2, 3])
# stride
@pytest.mark.parametrize("stride", [1, 2])
# padding
@pytest.mark.parametrize("pad", [0, 1])
# channelwise or scalar mul
@pytest.mark.parametrize("cw", [0, 1])
# negative mul
@pytest.mark.parametrize("negative", [0, 1])
def test_move_mul_past_maxpool(ifm_dim, ifm_ch, k, stride, pad, cw, negative):
    if cw == 1:
        mul_shape = [1, ifm_ch, 1, 1]
    else:
        mul_shape = [1, 1, 1, 1]

    ofm_ch = ifm_ch
    ofm_dim = compute_pool_output_dim(ifm_dim, k, stride, pad)

    # set up onnx model
    inp = helper.make_tensor_value_info(
        "inp", TensorProto.FLOAT, [1, ifm_ch, ifm_dim, ifm_dim]
    )
    mul = helper.make_tensor_value_info("mul", TensorProto.FLOAT, mul_shape)
    outp = helper.make_tensor_value_info(
        "outp", TensorProto.FLOAT, [1, ofm_ch, ofm_dim, ofm_dim]
    )
    Mul_node = helper.make_node("Mul", ["inp", "mul"], ["mul_out"])
    Maxpool_node = helper.make_node(
        "MaxPool",
        ["mul_out"],
        ["outp"],
        kernel_shape=[k, k],
        pads=[pad, pad, pad, pad],
        strides=[stride, stride],
    )
    graph = helper.make_graph(
        nodes=[Mul_node, Maxpool_node],
        name="mulpastmaxpool_graph",
        inputs=[inp],
        outputs=[outp],
        value_info=[mul],
    )

    model = helper.make_model(graph, producer_name="mulpastmaxpool-model")
    model = ModelWrapper(model)
    inp_values = gen_finn_dt_tensor(DataType.INT2, [1, ifm_ch, ifm_dim, ifm_dim])
    mul_values = np.random.random_sample(mul_shape).astype(np.float32)
    if negative == 1:
        mul_values = mul_values * (-1)
    model.set_initializer("mul", mul_values)
    model = model.transform(InferShapes())
    model = model.transform(InferDataTypes())
    idict = {"inp": inp_values}
    odict = oxe.execute_onnx(model, idict, True)
    out_before = odict["outp"]

    # perform transformation
    model_transformed = model.transform(MoveMulPastMaxPool())
    odict = oxe.execute_onnx(model_transformed, idict, True)
    out_after = odict["outp"]

    assert (out_before == out_after).all()

    if negative == 1:
        # negative mul factors must not be moved; node order is unchanged
        assert model.graph.node[0].op_type == model_transformed.graph.node[0].op_type
        assert model.graph.node[1].op_type == model_transformed.graph.node[1].op_type
    else:
        # non-negative factors are moved past the MaxPool; node order is swapped
        assert model.graph.node[0].op_type == model_transformed.graph.node[1].op_type
        assert model.graph.node[1].op_type == model_transformed.graph.node[0].op_type
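The expected output size used above follows the usual floor-mode pooling arithmetic; a minimal reference reimplementation (assumed to match compute_pool_output_dim from finn.custom_op.general.maxpoolnhwc, shown here only for illustration):

def pool_output_dim(ifm_dim: int, k: int, stride: int, pad: int) -> int:
    # floor-mode pooling: output = floor((in + 2*pad - k) / stride) + 1
    return (ifm_dim + 2 * pad - k) // stride + 1

assert pool_output_dim(4, 2, 2, 0) == 2
assert pool_output_dim(7, 3, 2, 1) == 4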