# test_move_mul_past_maxpool.py
import numpy as np
import pytest
from onnx import TensorProto, helper

import finn.core.onnx_exec as oxe
from finn.core.datatype import DataType
from finn.core.modelwrapper import ModelWrapper
from finn.custom_op.general.maxpoolnhwc import compute_pool_output_dim
from finn.transformation.infer_datatypes import InferDataTypes
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.streamline.reorder import MoveMulPastMaxPool
from finn.util.basic import gen_finn_dt_tensor


# input dimension
@pytest.mark.parametrize("ifm_dim", [4, 7])
# input channels
@pytest.mark.parametrize("ifm_ch", [1, 3])
# kernel size
@pytest.mark.parametrize("k", [2, 3])
# stride
@pytest.mark.parametrize("stride", [1, 2])
# padding (same amount on every side of both spatial dims)
@pytest.mark.parametrize("pad", [0, 1])
# channelwise (1) or scalar (0) mul
@pytest.mark.parametrize("cw", [0, 1])
# negative (1) or positive (0) mul
@pytest.mark.parametrize("negative", [0, 1])
def test_move_mul_past_maxpool(ifm_dim, ifm_ch, k, stride, pad, cw, negative):
    """Check that MoveMulPastMaxPool reorders Mul -> MaxPool only when legal.

    Builds the graph ``inp -> Mul(mul) -> MaxPool -> outp``, executes it
    before and after applying ``MoveMulPastMaxPool`` and checks that:

    * both executions produce identical outputs, and
    * the two nodes are swapped for a positive multiplier but left in place
      for a negative one (for c < 0, max(c * x) == c * min(x), so moving
      the Mul past the MaxPool would not be equivalent).
    """
    mul_shape = [1, ifm_ch, 1, 1] if cw == 1 else [1, 1, 1, 1]

    ofm_ch = ifm_ch
    ofm_dim = compute_pool_output_dim(ifm_dim, k, stride, pad)

    # set up onnx model
    inp = helper.make_tensor_value_info(
        "inp", TensorProto.FLOAT, [1, ifm_ch, ifm_dim, ifm_dim]
    )
    mul = helper.make_tensor_value_info("mul", TensorProto.FLOAT, mul_shape)
    outp = helper.make_tensor_value_info(
        "outp", TensorProto.FLOAT, [1, ofm_ch, ofm_dim, ofm_dim]
    )

    mul_node = helper.make_node("Mul", ["inp", "mul"], ["mul_out"])
    maxpool_node = helper.make_node(
        "MaxPool",
        ["mul_out"],
        ["outp"],
        kernel_shape=[k, k],
        pads=[pad, pad, pad, pad],
        strides=[stride, stride],
    )

    graph = helper.make_graph(
        nodes=[mul_node, maxpool_node],
        name="mulpastmaxpool_graph",
        inputs=[inp],
        outputs=[outp],
        value_info=[mul],
    )

    model = helper.make_model(graph, producer_name="mulpastmaxpool-model")
    model = ModelWrapper(model)
    inp_values = gen_finn_dt_tensor(DataType.INT2, [1, ifm_ch, ifm_dim, ifm_dim])
    # Draw magnitudes from [0.1, 1.0) rather than random_sample's [0, 1):
    # an exact 0.0 would turn into -0.0 in the negative case, which is not
    # < 0, so that parametrization would not really exercise a negative mul.
    mul_values = np.random.uniform(low=0.1, high=1.0, size=mul_shape).astype(
        np.float32
    )
    if negative == 1:
        mul_values = -mul_values
    model.set_initializer("mul", mul_values)
    model = model.transform(InferShapes())
    model = model.transform(InferDataTypes())
    idict = {"inp": inp_values}
    odict = oxe.execute_onnx(model, idict, True)
    out_before = odict["outp"]

    # Snapshot the node order now, so the checks below do not depend on
    # whether the transformation leaves `model` itself untouched.
    ops_before = [node.op_type for node in model.graph.node]
    assert ops_before == ["Mul", "MaxPool"]

    # perform transformation
    model_transformed = model.transform(MoveMulPastMaxPool())
    odict = oxe.execute_onnx(model_transformed, idict, True)
    out_after = odict["outp"]

    # np.array_equal also rejects shape mismatches that `==` would broadcast.
    assert np.array_equal(out_before, out_after), "outputs differ after reordering"

    ops_after = [node.op_type for node in model_transformed.graph.node]
    if negative == 1:
        # A negative scale must block the move.
        assert ops_after == ops_before, f"negative mul was moved: {ops_after}"
    else:
        # A positive scale must be moved past the MaxPool.
        assert ops_after == ops_before[::-1], f"mul was not moved: {ops_after}"