diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py index ec58084df1589aaa4db5e154832b3fbc4eddb9de..0b6259a61d3eb67b7b38d4c6939019ce2893a875 100644 --- a/src/finn/transformation/streamline/reorder.py +++ b/src/finn/transformation/streamline/reorder.py @@ -444,3 +444,90 @@ class MakeMaxPoolNHWC(Transformation): graph.node.insert(node_ind - 1, consumer) graph_modified = True return (model, graph_modified) + + +class MoveOpPastFork(Transformation): + """Move node operations past graph forks. Used when a node before a fork + can be merged with nodes in the branches + """ + + def __init__(self, op_name_list): + super().__init__() + self.ops_to_move = op_name_list + + def apply(self, model): + graph = model.graph + graph_modified = False + nodes = [n for n in graph.node] + node_ind = 0 + for n in nodes: + node_ind += 1 + if ( + n.op_type in self.ops_to_move + and model.is_fork_node(n) + and not model.is_join_node(n) + ): + + # Restrict this transform to operations with constant parameters + # Assuming parameters is in input 1 + op_init_param = model.get_initializer(n.input[1]) + if op_init_param is None: + continue + + # Check case when branches are empty and go + # to the same node + consumers = model.find_consumers(n.output[0]) + unique_consumer = True + for consum_node in consumers[1:]: + if consumers[0] != consum_node: + unique_consumer = False + break + + if unique_consumer: + continue + + for consumer_node in consumers[1:]: + # create new node + new_param_name = model.make_new_valueinfo_name() + new_output_tensor_name = model.make_new_valueinfo_name() + new_node = oh.make_node( + n.op_type, + [n.input[0], new_param_name], + [new_output_tensor_name], + ) + graph.node.insert(node_ind, new_node) + node_ind += 1 + model.set_initializer(new_param_name, op_init_param) + + # change consumer input tensor + graph.node.remove(consumer_node) + for idx, consumer_input in enumerate(consumer_node.input): + if consumer_input == n.output[0]: + consumer_node.input[idx] = new_output_tensor_name + break + else: + raise Exception( + "Consumer should have the current node output as input" + ) + + graph.node.insert(node_ind, consumer_node) + + graph_modified = True + + model = model.transform(InferShapes()) + return (model, graph_modified) + + +class MoveAddPastFork(MoveOpPastFork): + def __init__(self): + super().__init__(["Add"]) + + +class MoveMulPastFork(MoveOpPastFork): + def __init__(self): + super().__init__(["Mul"]) + + +class MoveLinearPastFork(MoveOpPastFork): + def __init__(self): + super().__init__(["Add", "Mul"]) diff --git a/tests/transformation/test_move_past_fork.py b/tests/transformation/test_move_past_fork.py new file mode 100644 index 0000000000000000000000000000000000000000..f3d37bd60c9e2580ca4499daafa8693f39fec810 --- /dev/null +++ b/tests/transformation/test_move_past_fork.py @@ -0,0 +1,79 @@ +from onnx import TensorProto, helper +import numpy as np + +import finn.core.onnx_exec as oxe +from finn.core.modelwrapper import ModelWrapper +from finn.transformation.streamline.reorder import MoveLinearPastFork +from finn.transformation.infer_shapes import InferShapes + +import pytest + + +@pytest.mark.parametrize("ch", [64, 1]) +# ifmdim +@pytest.mark.parametrize("ifmdim", [-1, 7]) +def test_move_past_fork(ch, ifmdim): + # generate test vectors of correct shape + if ifmdim == -1: + input_shape = (1, ch) + else: + input_shape = (1, ch, ifmdim, ifmdim) + + top_in = helper.make_tensor_value_info("top_in", TensorProto.FLOAT, input_shape) + top_out = helper.make_tensor_value_info("top_out", TensorProto.FLOAT, input_shape) + + num_of_params = 8 + value_info = [] + for i in range(num_of_params): + value_info += [ + helper.make_tensor_value_info("p" + str(i), TensorProto.FLOAT, input_shape) + ] + + add_1_to_move = helper.make_node("Add", ["top_in", "p0"], ["fork1"]) + mul_1_to_move = helper.make_node("Mul", ["t5", "p4"], ["fork2"]) + add_2_to_move = helper.make_node("Add", ["fork2", "p5"], ["t6"]) + mul_1_not_to_move = helper.make_node("Mul", ["t8", "p7"], ["fork3"]) + modelproto = helper.make_model( + helper.make_graph( + name="test", + inputs=[top_in], + outputs=[top_out], + value_info=value_info, + nodes=[ + # fork1 + add_1_to_move, + helper.make_node("Mul", ["fork1", "p1"], ["t2"]), + helper.make_node("Mul", ["fork1", "p2"], ["t3"]), + helper.make_node("Add", ["t2", "t3"], ["t4"]), + helper.make_node("Add", ["t4", "p3"], ["t5"]), + # fork2 + mul_1_to_move, + add_2_to_move, + helper.make_node("Add", ["fork2", "p6"], ["t7"]), + helper.make_node("Add", ["t6", "t7"], ["t8"]), + # empty branches: do nothing + mul_1_not_to_move, + helper.make_node("Add", ["fork3", "fork3"], ["top_out"]), + ], + ) + ) + model = ModelWrapper(modelproto) + model = model.transform(InferShapes()) + + np.random.seed(0) + for i in range(num_of_params): + model.set_initializer( + "p" + str(i), np.random.rand(*input_shape).astype(np.float32) + ) + + # Transform + new_model = model.transform(MoveLinearPastFork()) + inp_dict = {"top_in": np.random.rand(*input_shape).astype(np.float32)} + + # Test + assert oxe.compare_execution(model, new_model, inp_dict) + assert not new_model.is_fork_node(add_1_to_move) + assert not new_model.is_fork_node(mul_1_to_move) + assert not new_model.is_fork_node(add_2_to_move) + assert new_model.is_fork_node(mul_1_not_to_move) + assert len(new_model.graph.node) == 14