From 95664952c380182fa2eb909e0a194b6e7fb67765 Mon Sep 17 00:00:00 2001
From: Felix Jentzsch <45395194+fpjentzsch@users.noreply.github.com>
Date: Tue, 23 Feb 2021 10:23:49 +0100
Subject: [PATCH] Add MoveMulPastMaxPool transformation to streamlining (#285)

* Add MoveMulPastMaxPool transform to streamlining

* Add second call to transformation within streamlining
---
 .../transformation/streamline/__init__.py     |  3 +
 src/finn/transformation/streamline/reorder.py | 96 ++++++++++++++++---
 .../streamline/test_move_mul_past_maxpool.py  | 91 ++++++++++++++++++
 3 files changed, 179 insertions(+), 11 deletions(-)
 create mode 100755 tests/transformation/streamline/test_move_mul_past_maxpool.py

diff --git a/src/finn/transformation/streamline/__init__.py b/src/finn/transformation/streamline/__init__.py
index e78b798ff..876f8892d 100644
--- a/src/finn/transformation/streamline/__init__.py
+++ b/src/finn/transformation/streamline/__init__.py
@@ -59,6 +59,7 @@ from finn.transformation.streamline.reorder import (
     MoveScalarAddPastMatMul,
     MoveAddPastConv,
     MoveScalarMulPastConv,
+    MoveMulPastMaxPool,
 )
 
 from finn.transformation.streamline.round_thresholds import RoundAndClipThresholds
@@ -76,6 +77,7 @@ class Streamline(Transformation):
             ConvertDivToMul(),
             BatchNormToAffine(),
             ConvertSignToThres(),
+            MoveMulPastMaxPool(),
             AbsorbSignBiasIntoMultiThreshold(),
             MoveAddPastMul(),
             MoveScalarAddPastMatMul(),
@@ -85,6 +87,7 @@ class Streamline(Transformation):
             MoveAddPastMul(),
             CollapseRepeatedAdd(),
             CollapseRepeatedMul(),
+            MoveMulPastMaxPool(),
             AbsorbAddIntoMultiThreshold(),
             FactorOutMulSignMagnitude(),
             AbsorbMulIntoMultiThreshold(),
diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py
index 08a011713..b23f9f149 100644
--- a/src/finn/transformation/streamline/reorder.py
+++ b/src/finn/transformation/streamline/reorder.py
@@ -425,12 +425,86 @@ class MoveMulPastDWConv(Transformation):
         return (model, graph_modified)
 
 
+class MoveMulPastMaxPool(Transformation):
+    """Move non-negative scalar or channelwise mul operations past max pool operations.
+    We want to have muls next to each other such that they can be collapsed into a
+    single mul."""
+
+    def apply(self, model):
+        graph = model.graph
+        node_ind = 0
+        graph_modified = False
+        for n in graph.node:
+            node_ind += 1
+            if (
+                n.op_type == "Mul"
+                and not model.is_fork_node(n)
+                and not model.is_join_node(n)
+            ):
+                consumer = model.find_consumer(n.output[0])
+                if (
+                    consumer is not None
+                    and consumer.op_type == "MaxPool"
+                    and not model.is_join_node(consumer)
+                ):
+                    mul_weight_name = n.input[1]
+                    A = model.get_initializer(mul_weight_name)
+                    if A is None:
+                        warnings.warn(
+                            """Mul weight tensor is not set. If it is a constant,
+                                please use set_initializer to set the tensor."""
+                        )
+                        continue
+                    maxpool_node = consumer
+                    mul_node = n
+                    start_name = mul_node.input[0]
+                    maxpool_in_name = maxpool_node.input[0]
+                    maxpool_in_shape = model.get_tensor_shape(maxpool_in_name)
+                    ifm_ch = maxpool_in_shape[1]
+                    maxpool_out_name = maxpool_node.output[0]
+                    maxpool_out_shape = model.get_tensor_shape(maxpool_out_name)
+
+                    # do not support non-2D MaxPool
+                    kernel_shape = list(
+                        get_by_name(maxpool_node.attribute, "kernel_shape").ints
+                    )
+                    if len(kernel_shape) != 2:
+                        continue
+
+                    # do not move negative multiplication factor(s)
+                    if (A < 0).any():
+                        continue
+
+                    if all(x == 1 for x in A.shape) or A.shape == (1, ifm_ch, 1, 1):
+                        # if the mul is scalar or channelwise,
+                        # we can simply swap the order of ops
+                        # rewire mul input to be maxpool input
+                        maxpool_node.input[0] = start_name
+                        model.set_tensor_shape(start_name, maxpool_in_shape)
+                        model.set_tensor_datatype(start_name, DataType.FLOAT32)
+                        # use old maxpool input tensor as maxpool output
+                        maxpool_node.output[0] = maxpool_in_name
+                        model.set_tensor_shape(maxpool_in_name, maxpool_out_shape)
+                        model.set_tensor_datatype(maxpool_in_name, DataType.FLOAT32)
+                        # use new maxpool output as new mul node input
+                        mul_node.input[0] = maxpool_in_name
+                        # use old maxpool output as new mul node output
+                        mul_node.output[0] = maxpool_out_name
+                        model.set_tensor_datatype(maxpool_out_name, DataType.FLOAT32)
+                        # move mul node past maxpool node
+                        graph.node.remove(mul_node)
+                        graph.node.insert(node_ind, mul_node)
+                        graph_modified = True
+        model = model.transform(InferShapes())
+        return (model, graph_modified)
+
+
 class MoveLinearPastEltwiseAdd(Transformation):
     """Move linear operations (mul, add) past elementwise add operations where possible.
-       Specifically,matches and transforms the following patterns:
-       (x*C) + (y*C) -> (x + y) * C
-       (x+A) + (y+B) -> (x + y) + (A + B)
-       where x and y are dynamic inputs, A, B, C are constant tensors (in general).
+    Specifically,matches and transforms the following patterns:
+    (x*C) + (y*C) -> (x + y) * C
+    (x+A) + (y+B) -> (x + y) + (A + B)
+    where x and y are dynamic inputs, A, B, C are constant tensors (in general).
     """
 
     def move_node(self, graph, n, prod0, prod1, node_ind):
@@ -504,12 +578,12 @@ class MoveLinearPastEltwiseAdd(Transformation):
 
 class MoveScalarLinearPastInvariants(Transformation):
     """Move scalar linear operations (mul, add) past functions which are invariant
-       to them. Specifically, matches and transforms the following patterns:
-       f(x*C) -> f(x) * C
-       f(x+C) -> f(x) + C
-       where x is a dynamic input, C is a constant tensor.
-       Known f which obey this property are: Reshape, Flatten, Transpose,
-       GlobalAveragePool
+    to them. Specifically, matches and transforms the following patterns:
+    f(x*C) -> f(x) * C
+    f(x+C) -> f(x) + C
+    where x is a dynamic input, C is a constant tensor.
+    Known f which obey this property are: Reshape, Flatten, Transpose,
+    GlobalAveragePool
     """
 
     def apply(self, model):
@@ -604,7 +678,7 @@ class MakeMaxPoolNHWC(Transformation):
 
 class MoveOpPastFork(Transformation):
     """Move node operations past graph forks. Used when a node before a fork
-     can be merged with nodes in the branches
+    can be merged with nodes in the branches
     """
 
     def __init__(self, op_name_list):
diff --git a/tests/transformation/streamline/test_move_mul_past_maxpool.py b/tests/transformation/streamline/test_move_mul_past_maxpool.py
new file mode 100755
index 000000000..f61284102
--- /dev/null
+++ b/tests/transformation/streamline/test_move_mul_past_maxpool.py
@@ -0,0 +1,91 @@
+import numpy as np
+import pytest
+
+from onnx import helper, TensorProto
+from finn.custom_op.general.maxpoolnhwc import compute_pool_output_dim
+import finn.core.onnx_exec as oxe
+from finn.core.datatype import DataType
+from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.infer_datatypes import InferDataTypes
+from finn.transformation.infer_shapes import InferShapes
+from finn.util.basic import gen_finn_dt_tensor
+from finn.transformation.streamline.reorder import MoveMulPastMaxPool
+
+
+# input dimension
+@pytest.mark.parametrize("ifm_dim", [4, 7])
+# input channels
+@pytest.mark.parametrize("ifm_ch", [1, 3])
+# kernel size
+@pytest.mark.parametrize("k", [2, 3])
+# stride
+@pytest.mark.parametrize("stride", [1, 2])
+# padding
+@pytest.mark.parametrize("pad", [0, 1])
+# channelwise or scalar mul
+@pytest.mark.parametrize("cw", [0, 1])
+# negative mul
+@pytest.mark.parametrize("negative", [0, 1])
+def test_move_mul_past_maxpool(ifm_dim, ifm_ch, k, stride, pad, cw, negative):
+    if cw == 1:
+        mul_shape = [1, ifm_ch, 1, 1]
+    else:
+        mul_shape = [1, 1, 1, 1]
+
+    ofm_ch = ifm_ch
+    ofm_dim = compute_pool_output_dim(ifm_dim, k, stride, pad)
+
+    # set up onnx model
+    inp = helper.make_tensor_value_info(
+        "inp", TensorProto.FLOAT, [1, ifm_ch, ifm_dim, ifm_dim]
+    )
+    mul = helper.make_tensor_value_info("mul", TensorProto.FLOAT, mul_shape)
+    outp = helper.make_tensor_value_info(
+        "outp", TensorProto.FLOAT, [1, ofm_ch, ofm_dim, ofm_dim]
+    )
+
+    Mul_node = helper.make_node("Mul", ["inp", "mul"], ["mul_out"])
+
+    Maxpool_node = helper.make_node(
+        "MaxPool",
+        ["mul_out"],
+        ["outp"],
+        kernel_shape=[k, k],
+        pads=[pad, pad, pad, pad],
+        strides=[stride, stride],
+    )
+
+    graph = helper.make_graph(
+        nodes=[Mul_node, Maxpool_node],
+        name="mulpastmaxpool_graph",
+        inputs=[inp],
+        outputs=[outp],
+        value_info=[mul],
+    )
+
+    model = helper.make_model(graph, producer_name="mulpastmaxpool-model")
+    model = ModelWrapper(model)
+    inp_values = gen_finn_dt_tensor(DataType.INT2, [1, ifm_ch, ifm_dim, ifm_dim])
+    mul_values = np.random.random_sample(mul_shape).astype(np.float32)
+    if negative == 1:
+        mul_values = mul_values * (-1)
+    model.set_initializer("mul", mul_values)
+    model = model.transform(InferShapes())
+    model = model.transform(InferDataTypes())
+    idict = {"inp": inp_values}
+    odict = oxe.execute_onnx(model, idict, True)
+    out_before = odict["outp"]
+
+    # perform transformation
+    model_transformed = model.transform(MoveMulPastMaxPool())
+    odict = oxe.execute_onnx(model_transformed, idict, True)
+    out_after = odict["outp"]
+
+    assert (out_before == out_after).all()
+
+    if negative == 1:
+        assert model.graph.node[0].op_type == model_transformed.graph.node[0].op_type
+        assert model.graph.node[1].op_type == model_transformed.graph.node[1].op_type
+    else:
+        assert model.graph.node[0].op_type == model_transformed.graph.node[1].op_type
+        assert model.graph.node[1].op_type == model_transformed.graph.node[0].op_type
-- 
GitLab