From 95664952c380182fa2eb909e0a194b6e7fb67765 Mon Sep 17 00:00:00 2001
From: Felix Jentzsch <45395194+fpjentzsch@users.noreply.github.com>
Date: Tue, 23 Feb 2021 10:23:49 +0100
Subject: [PATCH] Add MoveMulPastMaxPool transformation to streamlining (#285)

* Add MoveMulPastMaxPool transform to streamlining

* Add second call to transformation within streamlining
---
 .../transformation/streamline/__init__.py     |  3 +
 src/finn/transformation/streamline/reorder.py | 96 ++++++++++++++++---
 .../streamline/test_move_mul_past_maxpool.py  | 91 ++++++++++++++++++
 3 files changed, 179 insertions(+), 11 deletions(-)
 create mode 100755 tests/transformation/streamline/test_move_mul_past_maxpool.py

diff --git a/src/finn/transformation/streamline/__init__.py b/src/finn/transformation/streamline/__init__.py
index e78b798ff..876f8892d 100644
--- a/src/finn/transformation/streamline/__init__.py
+++ b/src/finn/transformation/streamline/__init__.py
@@ -59,6 +59,7 @@ from finn.transformation.streamline.reorder import (
     MoveScalarAddPastMatMul,
     MoveAddPastConv,
     MoveScalarMulPastConv,
+    MoveMulPastMaxPool,
 )
 
 from finn.transformation.streamline.round_thresholds import RoundAndClipThresholds
@@ -76,6 +77,7 @@ class Streamline(Transformation):
             ConvertDivToMul(),
             BatchNormToAffine(),
             ConvertSignToThres(),
+            MoveMulPastMaxPool(),
             AbsorbSignBiasIntoMultiThreshold(),
             MoveAddPastMul(),
             MoveScalarAddPastMatMul(),
@@ -85,6 +87,7 @@ class Streamline(Transformation):
             MoveAddPastMul(),
             CollapseRepeatedAdd(),
             CollapseRepeatedMul(),
+            MoveMulPastMaxPool(),
             AbsorbAddIntoMultiThreshold(),
             FactorOutMulSignMagnitude(),
             AbsorbMulIntoMultiThreshold(),
diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py
index 08a011713..b23f9f149 100644
--- a/src/finn/transformation/streamline/reorder.py
+++ b/src/finn/transformation/streamline/reorder.py
@@ -425,12 +425,86 @@ class MoveMulPastDWConv(Transformation):
         return (model, graph_modified)
 
 
+class MoveMulPastMaxPool(Transformation):
+    """Move non-negative scalar or channelwise mul operations past max pool operations.
+    We want to have muls next to each other such that they can be collapsed into a
+    single mul."""
+
+    def apply(self, model):
+        graph = model.graph
+        node_ind = 0
+        graph_modified = False
+        for n in graph.node:
+            node_ind += 1
+            if (
+                n.op_type == "Mul"
+                and not model.is_fork_node(n)
+                and not model.is_join_node(n)
+            ):
+                consumer = model.find_consumer(n.output[0])
+                if (
+                    consumer is not None
+                    and consumer.op_type == "MaxPool"
+                    and not model.is_join_node(consumer)
+                ):
+                    mul_weight_name = n.input[1]
+                    A = model.get_initializer(mul_weight_name)
+                    if A is None:
+                        warnings.warn(
+                            """Mul weight tensor is not set. If it is a constant,
+                            please use set_initializer to set the tensor."""
+                        )
+                        continue
+                    maxpool_node = consumer
+                    mul_node = n
+                    start_name = mul_node.input[0]
+                    maxpool_in_name = maxpool_node.input[0]
+                    maxpool_in_shape = model.get_tensor_shape(maxpool_in_name)
+                    ifm_ch = maxpool_in_shape[1]
+                    maxpool_out_name = maxpool_node.output[0]
+                    maxpool_out_shape = model.get_tensor_shape(maxpool_out_name)
+
+                    # do not support non-2D MaxPool
+                    kernel_shape = list(
+                        get_by_name(maxpool_node.attribute, "kernel_shape").ints
+                    )
+                    if len(kernel_shape) != 2:
+                        continue
+
+                    # do not move negative multiplication factor(s)
+                    if (A < 0).any():
+                        continue
+
+                    if all(x == 1 for x in A.shape) or A.shape == (1, ifm_ch, 1, 1):
+                        # if the mul is scalar or channelwise,
+                        # we can simply swap the order of ops
+                        # rewire mul input to be maxpool input
+                        maxpool_node.input[0] = start_name
+                        model.set_tensor_shape(start_name, maxpool_in_shape)
+                        model.set_tensor_datatype(start_name, DataType.FLOAT32)
+                        # use old maxpool input tensor as maxpool output
+                        maxpool_node.output[0] = maxpool_in_name
+                        model.set_tensor_shape(maxpool_in_name, maxpool_out_shape)
+                        model.set_tensor_datatype(maxpool_in_name, DataType.FLOAT32)
+                        # use new maxpool output as new mul node input
+                        mul_node.input[0] = maxpool_in_name
+                        # use old maxpool output as new mul node output
+                        mul_node.output[0] = maxpool_out_name
+                        model.set_tensor_datatype(maxpool_out_name, DataType.FLOAT32)
+                        # move mul node past maxpool node
+                        graph.node.remove(mul_node)
+                        graph.node.insert(node_ind, mul_node)
+                        graph_modified = True
+        model = model.transform(InferShapes())
+        return (model, graph_modified)
+
+
 class MoveLinearPastEltwiseAdd(Transformation):
     """Move linear operations (mul, add) past elementwise add operations where possible.
-       Specifically,matches and transforms the following patterns:
-       (x*C) + (y*C) -> (x + y) * C
-       (x+A) + (y+B) -> (x + y) + (A + B)
-       where x and y are dynamic inputs, A, B, C are constant tensors (in general).
+    Specifically,matches and transforms the following patterns:
+    (x*C) + (y*C) -> (x + y) * C
+    (x+A) + (y+B) -> (x + y) + (A + B)
+    where x and y are dynamic inputs, A, B, C are constant tensors (in general).
     """
 
     def move_node(self, graph, n, prod0, prod1, node_ind):
@@ -504,12 +578,12 @@ class MoveLinearPastEltwiseAdd(Transformation):
 
 class MoveScalarLinearPastInvariants(Transformation):
     """Move scalar linear operations (mul, add) past functions which are invariant
-       to them. Specifically, matches and transforms the following patterns:
-       f(x*C) -> f(x) * C
-       f(x+C) -> f(x) + C
-       where x is a dynamic input, C is a constant tensor.
-       Known f which obey this property are: Reshape, Flatten, Transpose,
-       GlobalAveragePool
+    to them. Specifically, matches and transforms the following patterns:
+    f(x*C) -> f(x) * C
+    f(x+C) -> f(x) + C
+    where x is a dynamic input, C is a constant tensor.
+    Known f which obey this property are: Reshape, Flatten, Transpose,
+    GlobalAveragePool
     """
 
     def apply(self, model):
@@ -604,7 +678,7 @@ class MakeMaxPoolNHWC(Transformation):
 
 class MoveOpPastFork(Transformation):
     """Move node operations past graph forks. Used when a node before a fork
-       can be merged with nodes in the branches
+    can be merged with nodes in the branches
     """
 
     def __init__(self, op_name_list):
diff --git a/tests/transformation/streamline/test_move_mul_past_maxpool.py b/tests/transformation/streamline/test_move_mul_past_maxpool.py
new file mode 100755
index 000000000..f61284102
--- /dev/null
+++ b/tests/transformation/streamline/test_move_mul_past_maxpool.py
@@ -0,0 +1,91 @@
+import numpy as np
+import pytest
+
+from onnx import helper, TensorProto
+from finn.custom_op.general.maxpoolnhwc import compute_pool_output_dim
+import finn.core.onnx_exec as oxe
+from finn.core.datatype import DataType
+from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.infer_datatypes import InferDataTypes
+from finn.transformation.infer_shapes import InferShapes
+from finn.util.basic import gen_finn_dt_tensor
+from finn.transformation.streamline.reorder import MoveMulPastMaxPool
+
+
+# input dimension
+@pytest.mark.parametrize("ifm_dim", [4, 7])
+# input channels
+@pytest.mark.parametrize("ifm_ch", [1, 3])
+# kernel size
+@pytest.mark.parametrize("k", [2, 3])
+# stride
+@pytest.mark.parametrize("stride", [1, 2])
+# padding
+@pytest.mark.parametrize("pad", [0, 1])
+# channelwise or scalar mul
+@pytest.mark.parametrize("cw", [0, 1])
+# negative mul
+@pytest.mark.parametrize("negative", [0, 1])
+def test_move_mul_past_maxpool(ifm_dim, ifm_ch, k, stride, pad, cw, negative):
+    if cw == 1:
+        mul_shape = [1, ifm_ch, 1, 1]
+    else:
+        mul_shape = [1, 1, 1, 1]
+
+    ofm_ch = ifm_ch
+    ofm_dim = compute_pool_output_dim(ifm_dim, k, stride, pad)
+
+    # set up onnx model
+    inp = helper.make_tensor_value_info(
+        "inp", TensorProto.FLOAT, [1, ifm_ch, ifm_dim, ifm_dim]
+    )
+    mul = helper.make_tensor_value_info("mul", TensorProto.FLOAT, mul_shape)
+    outp = helper.make_tensor_value_info(
+        "outp", TensorProto.FLOAT, [1, ofm_ch, ofm_dim, ofm_dim]
+    )
+
+    Mul_node = helper.make_node("Mul", ["inp", "mul"], ["mul_out"])
+
+    Maxpool_node = helper.make_node(
+        "MaxPool",
+        ["mul_out"],
+        ["outp"],
+        kernel_shape=[k, k],
+        pads=[pad, pad, pad, pad],
+        strides=[stride, stride],
+    )
+
+    graph = helper.make_graph(
+        nodes=[Mul_node, Maxpool_node],
+        name="mulpastmaxpool_graph",
+        inputs=[inp],
+        outputs=[outp],
+        value_info=[mul],
+    )
+
+    model = helper.make_model(graph, producer_name="mulpastmaxpool-model")
+    model = ModelWrapper(model)
+    inp_values = gen_finn_dt_tensor(DataType.INT2, [1, ifm_ch, ifm_dim, ifm_dim])
+    mul_values = np.random.random_sample(mul_shape).astype(np.float32)
+    if negative == 1:
+        mul_values = mul_values * (-1)
+    model.set_initializer("mul", mul_values)
+    model = model.transform(InferShapes())
+    model = model.transform(InferDataTypes())
+    idict = {"inp": inp_values}
+    odict = oxe.execute_onnx(model, idict, True)
+    out_before = odict["outp"]
+
+    # perform transformation
+    model_transformed = model.transform(MoveMulPastMaxPool())
+    odict = oxe.execute_onnx(model_transformed, idict, True)
+    out_after = odict["outp"]
+
+    assert (out_before == out_after).all()
+
+    if negative == 1:
+        assert model.graph.node[0].op_type == model_transformed.graph.node[0].op_type
+        assert model.graph.node[1].op_type == model_transformed.graph.node[1].op_type
+    else:
+        assert model.graph.node[0].op_type == model_transformed.graph.node[1].op_type
+        assert model.graph.node[1].op_type == model_transformed.graph.node[0].op_type
-- 
GitLab