class MoveIdenticalOpPastJoinOp(Transformation):
    """
    Move identical operations on different branches past the common join node.

    This transformation assumes that the identical operations only change the
    data layout. For linear operations, see the transformation
    MoveLinearPastEltwiseAdd. Specifically, this transformation matches and
    transforms the following pattern:
        f(x) + f(y) -> f(x + y)
    where f(.) is currently only supporting 'Transpose', and an 'Add' node is
    the join node.
    """

    def __init__(self, identical_op_list, join_node_list):
        """
        :param identical_op_list: list of op_types that may be moved past the
            join node (e.g. ["Transpose"])
        :param join_node_list: list of op_types that act as the join node
            (e.g. ["Add"])
        """
        super().__init__()
        self.ops_to_move = identical_op_list
        self.join_node_op = join_node_list

    @staticmethod
    def _same_attributes(node0, node1):
        """Return True iff both nodes carry identical attribute lists.

        op_type equality alone is not enough: e.g. two Transpose nodes with
        different 'perm' attributes are NOT interchangeable, and merging them
        would change the network's behavior.
        """
        if len(node0.attribute) != len(node1.attribute):
            return False
        # AttributeProto supports protobuf message equality via ==
        attrs1 = {a.name: a for a in node1.attribute}
        for a0 in node0.attribute:
            a1 = attrs1.get(a0.name)
            if a1 is None or a0 != a1:
                return False
        return True

    def move_node(self, model, n, prod0, prod1):
        """Rewire the graph so the join node consumes the producers' inputs
        directly, and a single surviving copy of the moved op (prod0)
        consumes the join node's output. prod1 is removed.

        :param model: the ModelWrapper being transformed
        :param n: the join node (e.g. Add)
        :param prod0: producer of n.input[0]; survives and moves past n
        :param prod1: producer of n.input[1]; removed from the graph
        """
        identical_op0_in0 = prod0.input[0]
        identical_op1_in0 = prod1.input[0]
        add_in0 = n.input[0]
        add_out = n.output[0]

        # Connect the join node directly to the producers' inputs
        n.input[0] = identical_op0_in0
        n.input[1] = identical_op1_in0

        # Output tensor of the join node must now carry the same shape as
        # its (new) input tensors, since the layout-changing op no longer
        # runs before the join
        new_shape = model.get_tensor_shape(identical_op0_in0)
        model.set_tensor_shape(tensor_name=add_in0, tensor_shape=new_shape)

        # Re-use the old join-input tensor as the join output, and let the
        # surviving op produce the original join output
        n.output[0] = add_in0
        prod0.input[0] = add_in0
        prod0.output[0] = add_out

        model.graph.node.remove(prod1)

    def apply(self, model):
        """Scan for join nodes preceded by two identical movable ops and
        move a single copy of the op past the join. Returns (model, bool)
        per the Transformation contract."""
        graph = model.graph
        graph_modified = False
        for n in graph.node:
            if n.op_type not in self.join_node_op or not model.is_join_node(n):
                continue
            in0 = n.input[0]
            in1 = n.input[1]
            if in0 is None or in1 is None:
                continue

            prod0 = model.find_producer(in0)
            prod1 = model.find_producer(in1)
            # A graph input / initializer has no producer node; skip to
            # avoid AttributeError on prod0.op_type below
            if prod0 is None or prod1 is None:
                continue
            # The join node must be preceded by two different, but
            # identical, operations
            if prod0 == prod1:
                continue
            if prod0.op_type != prod1.op_type:
                continue
            if prod0.op_type not in self.ops_to_move:
                continue
            # "Identical" includes attributes (e.g. Transpose perm),
            # not just the op_type
            if not self._same_attributes(prod0, prod1):
                continue

            self.move_node(model, n, prod0, prod1)
            graph_modified = True

        if graph_modified:
            # Node order may be invalid after rewiring; restore
            # topological order without an expensive deep copy
            model = model.transform(SortGraph(), make_deepcopy=False, cleanup=False)

        return (model, graph_modified)


class MoveTransposePastJoinAdd(MoveIdenticalOpPastJoinOp):
    """Move two identical Transpose nodes past a joining Add node."""

    def __init__(self):
        super().__init__(["Transpose"], ["Add"])
finn.core.modelwrapper import ModelWrapper +from finn.transformation.streamline.reorder import MoveTransposePastJoinAdd +from finn.util.basic import gen_finn_dt_tensor +import finn.core.onnx_exec as oxe + + +def create_model(perm): + if perm == [0, 3, 1, 2]: + in_shape = [1, 128, 1, 256] + out_shape = [1, 256, 128, 1] + if perm == [0, 2, 3, 1]: + in_shape = [1, 256, 128, 1] + out_shape = [1, 128, 1, 256] + + Transpose1_node = oh.make_node( + "Transpose", inputs=["in_transpose1"], outputs=["out_transpose1"], perm=perm + ) + + Transpose2_node = oh.make_node( + "Transpose", inputs=["in_transpose2"], outputs=["out_transpose2"], perm=perm + ) + + Join1_node = oh.make_node( + "Add", inputs=["out_transpose1", "out_transpose2"], outputs=["out_join1"] + ) + + in_transpose1 = oh.make_tensor_value_info( + "in_transpose1", TensorProto.FLOAT, in_shape + ) + in_transpose2 = oh.make_tensor_value_info( + "in_transpose2", TensorProto.FLOAT, in_shape + ) + out_transpose1 = oh.make_tensor_value_info( + "out_transpose1", TensorProto.FLOAT, out_shape + ) + out_transpose2 = oh.make_tensor_value_info( + "out_transpose2", TensorProto.FLOAT, out_shape + ) + out_join1 = oh.make_tensor_value_info("out_join1", TensorProto.FLOAT, out_shape) + + graph = oh.make_graph( + nodes=[Transpose1_node, Transpose2_node, Join1_node], + name="test_graph", + inputs=[in_transpose1, in_transpose2], + outputs=[out_join1], + value_info=[ + out_transpose1, + out_transpose2, + ], + ) + + onnx_model = oh.make_model(graph, producer_name="test_model") + model = ModelWrapper(onnx_model) + + return model + + +# Permutation of transpose node +@pytest.mark.parametrize("perm", [[0, 3, 1, 2], [0, 2, 3, 1]]) +def test_move_identical_op_past_join_op(perm): + model = create_model(perm) + + # Create input data + input0_tensor_name = model.graph.input[0].name + input1_tensor_name = model.graph.input[1].name + + # Note: it is assumed that both tensors have the same shape and data type + input_shape = 
model.get_tensor_shape(input0_tensor_name) + input_dtype = model.get_tensor_datatype(input0_tensor_name) + input_val = gen_finn_dt_tensor(input_dtype, input_shape) + input_dict = {} + input_dict[input0_tensor_name] = input_val + input_dict[input1_tensor_name] = input_val + + model_transformed = model.transform(MoveTransposePastJoinAdd()) + + assert oxe.compare_execution(model, model_transformed, input_dict) + + # Check if order changed + node0_input0_model = model.find_consumers(model.graph.input[0].name)[0].op_type + node1_input1_model = model.find_consumers(model.graph.input[1].name)[0].op_type + node0_input0_model_transformed = model_transformed.find_consumers( + model_transformed.graph.input[0].name + )[0].op_type + node1_input1_model_transformed = model_transformed.find_consumers( + model_transformed.graph.input[1].name + )[0].op_type + assert node0_input0_model != node0_input0_model_transformed + assert node1_input1_model != node1_input1_model_transformed