# test_move_maxpool_past_multithreshold.py
import numpy as np
from onnx import TensorProto, helper

import finn.core.onnx_exec as oxe
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.infer_datatypes import InferDataTypes
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.streamline.reorder import MoveMaxPoolPastMultiThreshold


def get_multithreshold_rand_params(channels, num_of_thres, seed=None):
    """Generate random per-channel thresholds for a MultiThreshold node.

    Row ``c`` is ``(arange(num_of_thres) - bias[c]) * step[c]``, where
    ``step`` is drawn from [0, 2) and ``bias`` from [0, 10). Because every
    step is non-negative, each row is non-decreasing.

    Args:
        channels: number of rows (one per channel).
        num_of_thres: number of threshold values per channel.
        seed: if not None, seeds NumPy's *global* legacy RNG before drawing.
            This also makes any later ``np.random`` draws by the caller
            reproducible.

    Returns:
        float32 array of shape ``(channels, num_of_thres)``.
    """
    if seed is not None:
        # Intentionally seeds the global RNG (not a local Generator) so that
        # subsequent np.random draws in the caller stay deterministic too.
        np.random.seed(seed)
    # Draw order (steps, then bias) is kept fixed so seeded output is stable.
    steps = np.random.rand(channels, 1) * 2
    bias = np.random.rand(channels, 1) * 10
    # Broadcast a (1, num_of_thres) ramp against the (channels, 1) columns.
    # Unlike building a Python list of per-channel ramps, this also yields
    # the correct (0, num_of_thres) shape when channels == 0.
    ramp = np.arange(num_of_thres)[np.newaxis, :]
    thres = ((ramp - bias) * steps).astype(np.float32)
    return thres


def test_move_maxpool_past_multithreshold():
    """Check MoveMaxPoolPastMultiThreshold on MaxPool -> MT -> MaxPool -> MT.

    Builds a graph of two (MaxPool, MultiThreshold) pairs, applies the
    transformation, and verifies that:
      * execution results of the original and transformed models match, and
      * only the second MaxPool ends up after its MultiThreshold, giving
        MaxPool -> MultiThreshold -> MultiThreshold -> MaxPool.
    """
    # generate test vectors of correct shape
    ch = 64
    ifmdim = 16
    # Each MaxPool below (kernel 3, stride 2, pad 1) halves an even spatial
    # dimension, so the two pools together divide it by 4. ifmdim must be a
    # multiple of 4 for this to hold.
    ofmdim = ifmdim // 4
    input_shape = (1, ch, ifmdim, ifmdim)
    output_shape = (1, ch, ofmdim, ofmdim)

    top_in = helper.make_tensor_value_info("top_in", TensorProto.FLOAT, input_shape)
    top_out = helper.make_tensor_value_info("top_out", TensorProto.FLOAT, output_shape)

    maxpool_config = {}
    maxpool_config["pads"] = [1, 1, 1, 1]
    maxpool_config["kernel_shape"] = [3, 3]
    maxpool_config["strides"] = [2, 2]

    value_info = []
    # A single threshold shared by all channels.
    thres1_shape = [1, 1]
    value_info += [
        helper.make_tensor_value_info("thres1", TensorProto.FLOAT, thres1_shape)
    ]

    # Per-channel thresholds: 14 values -> outputs 0..14, which fit in UINT4.
    thres2_shape = [ch, 14]
    value_info += [
        helper.make_tensor_value_info("thres2", TensorProto.FLOAT, thres2_shape)
    ]

    nodes = []
    nodes += [helper.make_node("MaxPool", ["top_in"], ["t1"], **maxpool_config)]
    nodes += [
        helper.make_node(
            "MultiThreshold",
            ["t1", "thres1"],
            ["t2"],
            domain="finn.custom_op.general",
            out_dtype="BIPOLAR",
            out_scale=1.0,
        )
    ]
    nodes += [helper.make_node("MaxPool", ["t2"], ["t3"], **maxpool_config)]
    nodes += [
        helper.make_node(
            "MultiThreshold",
            ["t3", "thres2"],
            ["top_out"],
            domain="finn.custom_op.general",
            out_dtype="UINT4",
        )
    ]

    modelproto = helper.make_model(
        helper.make_graph(
            name="test",
            inputs=[top_in],
            outputs=[top_out],
            value_info=value_info,
            nodes=nodes,
        )
    )
    model = ModelWrapper(modelproto)
    model = model.transform(InferShapes())
    model = model.transform(InferDataTypes())

    # Use float32 so the initializer matches the FLOAT value_info declared
    # above (an int64 array would carry an INT64 element type instead).
    model.set_initializer("thres1", np.array([[0]], dtype=np.float32))
    # Seeding here also fixes the global RNG state, so the random input
    # drawn below is reproducible across runs.
    model.set_initializer(
        "thres2", get_multithreshold_rand_params(*thres2_shape, seed=0)
    )

    # Transform
    new_model = model.transform(MoveMaxPoolPastMultiThreshold())
    inp_dict = {"top_in": np.random.rand(*input_shape).astype(np.float32)}

    # Test
    assert oxe.compare_execution(model, new_model, inp_dict)
    # The first MaxPool is expected to stay in front of its (BIPOLAR, i.e.
    # signed-output) MultiThreshold; presumably the transformation skips
    # padded MaxPools feeding signed thresholds. NOTE(review): confirm against
    # MoveMaxPoolPastMultiThreshold. The second MaxPool is moved past its
    # MultiThreshold. Comparing the full list also checks the node count
    # before any positional check, giving a clearer failure message.
    op_types = [node.op_type for node in new_model.graph.node]
    assert op_types == ["MaxPool", "MultiThreshold", "MultiThreshold", "MaxPool"]