diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py
index 0b6259a61d3eb67b7b38d4c6939019ce2893a875..b46b82c77a3f1b70a3b05d87cd3c48fc1d94fd45 100644
--- a/src/finn/transformation/streamline/reorder.py
+++ b/src/finn/transformation/streamline/reorder.py
@@ -27,12 +27,14 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 import numpy as np
+import warnings
 from onnx import helper as oh
 
 from finn.transformation import Transformation
 from finn.transformation.infer_shapes import InferShapes
 from finn.core.onnx_exec import execute_node
 from finn.util.basic import get_by_name
+from finn.custom_op.registry import getCustomOp
 
 
 class MoveAddPastMul(Transformation):
@@ -531,3 +533,84 @@ class MoveMulPastFork(MoveOpPastFork):
 class MoveLinearPastFork(MoveOpPastFork):
     def __init__(self):
         super().__init__(["Add", "Mul"])
+
+
+class MoveMaxPoolPastMultiThreshold(Transformation):
+    """Move MaxPool nodes past MultiThreshold nodes on linear segments of the graph.
+
+    A MaxPool followed by a MultiThreshold is rewritten as that MultiThreshold
+    followed by the MaxPool. A pair is left untouched, with a warning, when the
+    MaxPool pads on any side and the MultiThreshold output datatype is signed,
+    when no threshold initializer is found, or when out_scale is negative.
+    Thresholds that are not non-decreasing along axis 1 fail an assertion.
+    """
+
+    def apply(self, model):
+        """Reorder eligible pairs in model; return (model, graph_modified)."""
+        graph = model.graph
+        node_ind = 0
+        graph_modified = False
+        # iterate over a snapshot, graph.node is mutated inside the loop
+        nodes = list(graph.node)
+        for n in nodes:
+            node_ind += 1
+            if n.op_type == "MaxPool" and not model.is_fork_node(n):
+                consumer = model.find_consumer(n.output[0])
+                pads = get_by_name(n.attribute, "pads")
+                has_padding = False
+                if pads is not None:
+                    # a single nonzero entry already means the MaxPool pads;
+                    # a product would be zero whenever one side is unpadded
+                    has_padding = any(p != 0 for p in pads.ints)
+                if consumer is not None and consumer.op_type == "MultiThreshold":
+                    mt_out = consumer.output[0]
+                    mt_odt = model.get_tensor_datatype(mt_out)
+                    if mt_odt.signed() and has_padding:
+                        warnings.warn(
+                            "Skipping padded MaxPool + signed-output MultiThreshold"
+                        )
+                        continue
+                    # check for non-decreasing thresholds and nonnegative
+                    # scale factor in MultiThreshold
+                    # otherwise we cannot do the reordering
+                    T = model.get_initializer(consumer.input[1])
+                    if T is None:
+                        # thresholds cannot be inspected, so the monotonicity
+                        # requirement cannot be verified
+                        warnings.warn(
+                            "Skipping MultiThreshold without threshold initializer"
+                        )
+                        continue
+                    T_sorted = np.sort(T, axis=1)
+                    assert (
+                        T == T_sorted
+                    ).all(), "MultiThreshold must have non-decreasing thresholds"
+                    mt_inst = getCustomOp(consumer)
+                    if mt_inst.get_nodeattr("out_scale") < 0:
+                        warnings.warn("Skipping MultiThreshold with negative out_scale")
+                        continue
+
+                    # remove old nodes
+                    graph.node.remove(n)
+                    graph.node.remove(consumer)
+
+                    # swap connections
+                    group_in = n.input[0]
+                    # new tensor because dims change
+                    group_middle = model.make_new_valueinfo_name()
+                    group_out = consumer.output[0]
+
+                    consumer.input[0] = group_in
+                    consumer.output[0] = group_middle
+
+                    n.input[0] = group_middle
+                    n.output[0] = group_out
+
+                    # insert them back in
+                    graph.node.insert(node_ind - 1, consumer)
+                    graph.node.insert(node_ind, n)
+
+                    graph_modified = True
+
+        model = model.transform(InferShapes())
+        return (model, graph_modified)
diff --git a/tests/transformation/test_move_maxpool_past_multithreshold.py b/tests/transformation/test_move_maxpool_past_multithreshold.py
new file mode 100644
index 0000000000000000000000000000000000000000..01005ddd16d2545f46b6b0f1f61538e9618e2a7b
--- /dev/null
+++ b/tests/transformation/test_move_maxpool_past_multithreshold.py
@@ -0,0 +1,101 @@
+from onnx import TensorProto, helper
+import numpy as np
+
+import finn.core.onnx_exec as oxe
+from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.streamline.reorder import MoveMaxPoolPastMultiThreshold
+from finn.transformation.infer_shapes import InferShapes
+from finn.transformation.infer_datatypes import InferDataTypes
+
+
+def get_multithreshold_rand_params(channels, num_of_thres, seed=None):
+    if seed is not None:
+        np.random.seed(seed)
+    steps = np.random.rand(channels, 1) * 2
+    bias = np.random.rand(channels, 1) * 10
+    thres = [np.arange(num_of_thres) for chn in range(channels)]
+    thres = ((thres - bias) * steps).astype(np.float32)
+    return thres
+
+
+def test_move_maxpool_past_multithreshold():
+    # generate test vectors of correct shape
+    ch = 64
+    ifmdim = 16
+    ofmdim = 16 // 4
+    input_shape = (1, ch, ifmdim, ifmdim)
+    output_shape = (1, ch, ofmdim, ofmdim)
+
+    top_in = helper.make_tensor_value_info("top_in", TensorProto.FLOAT, input_shape)
+    top_out = helper.make_tensor_value_info("top_out", TensorProto.FLOAT, output_shape)
+
+    maxpool_config = {}
+    maxpool_config["pads"] = [1, 1, 1, 1]
+    maxpool_config["kernel_shape"] = [3, 3]
+    maxpool_config["strides"] = [2, 2]
+
+    value_info = []
+    thres1_shape = [1, 1]
+    value_info += [
+        helper.make_tensor_value_info("thres1", TensorProto.FLOAT, thres1_shape)
+    ]
+
+    thres2_shape = [ch, 14]
+    value_info += [
+        helper.make_tensor_value_info("thres2", TensorProto.FLOAT, thres2_shape)
+    ]
+
+    nodes = []
+    nodes += [helper.make_node("MaxPool", ["top_in"], ["t1"], **maxpool_config)]
+    nodes += [
+        helper.make_node(
+            "MultiThreshold",
+            ["t1", "thres1"],
+            ["t2"],
+            domain="finn",
+            out_dtype="BIPOLAR",
+            out_bias=-1.0,
+            out_scale=1.0,
+        )
+    ]
+    nodes += [helper.make_node("MaxPool", ["t2"], ["t3"], **maxpool_config)]
+    nodes += [
+        helper.make_node(
+            "MultiThreshold",
+            ["t3", "thres2"],
+            ["top_out"],
+            domain="finn",
+            out_dtype="UINT4",
+            out_scale=-1.0,
+        )
+    ]
+
+    modelproto = helper.make_model(
+        helper.make_graph(
+            name="test",
+            inputs=[top_in],
+            outputs=[top_out],
+            value_info=value_info,
+            nodes=nodes,
+        )
+    )
+    model = ModelWrapper(modelproto)
+    model = model.transform(InferShapes())
+    model = model.transform(InferDataTypes())
+
+    model.set_initializer("thres1", np.array([[0]]))
+    model.set_initializer(
+        "thres2", get_multithreshold_rand_params(*thres2_shape, seed=0)
+    )
+
+    # Transform
+    new_model = model.transform(MoveMaxPoolPastMultiThreshold())
+    inp_dict = {"top_in": np.random.rand(*input_shape).astype(np.float32)}
+
+    # Test
+    assert oxe.compare_execution(model, new_model, inp_dict)
+    assert new_model.graph.node[0].op_type == "MaxPool"
+    assert new_model.graph.node[1].op_type == "MultiThreshold"
+    assert new_model.graph.node[2].op_type == "MultiThreshold"
+    assert new_model.graph.node[3].op_type == "MaxPool"
+    assert len(new_model.graph.node) == 4