diff --git a/src/finn/transformation/streamline/absorb.py b/src/finn/transformation/streamline/absorb.py
index 0299c4f4d89d1fdd94434db77c77a0e529c86d26..a983e67750a0a860eeeb4b429f7d6b181fc84fe3 100644
--- a/src/finn/transformation/streamline/absorb.py
+++ b/src/finn/transformation/streamline/absorb.py
@@ -473,7 +473,7 @@ class AbsorbConsecutiveTransposes(Transformation):
     """Remove (Transpose -> Transpose) patterns when the input and output
     of the pattern have the same layout."""

-    def Are_opposite_permutations(self, perms1, perms2):
+    def are_opposite_permutations(self, perms1, perms2):
         if len(perms1) != len(perms2):
             return False
         assert 0 <= max(perms2) < len(perms2), "invalid permutation"
@@ -488,72 +488,43 @@ class AbsorbConsecutiveTransposes(Transformation):
     def apply(self, model):
         graph = model.graph
         graph_modified = False
-        for n in graph.node:
-            if n.op_type == "Transpose":
-                if model.is_fork_node(n):
-                    next_nodes = model.find_direct_successors(n)
-                    perms1 = list(get_by_name(n.attribute, "perm").ints)
-
-                    # check if all nodes after fork are opposite transposes
-                    all_opposite_transposes = True
-                    for next_node in next_nodes:
-                        if next_node is not None and next_node.op_type == "Transpose":
-                            perms2 = list(get_by_name(next_node.attribute, "perm").ints)
-                            if not self.Are_opposite_permutations(perms1, perms2):
-                                all_opposite_transposes = False
-                                break
-                        else:
-                            all_opposite_transposes = False
-                            break
-
-                    if not all_opposite_transposes:
-                        continue
-
-                    prod = model.find_producer(n.input[0])
-                    for next_node in next_nodes:
-                        # connect next_node's consumer input to n's producer output
-                        # TODO implement this to allow for forks as producers and
-                        # joins as consumers
-                        cons = model.find_consumer(next_node.output[0])
-                        cons.input[0] = prod.output[0]
-
-                        # remove consumer transpose
-                        graph.node.remove(next_node)
-
-                    # remove producer transpose
-                    graph.node.remove(n)
-                    graph_modified = True
-
-                else:
-                    next_node = model.find_consumer(n.output[0])
-                    if next_node is not None and next_node.op_type == "Transpose":
-                        perms1 = list(get_by_name(n.attribute, "perm").ints)
-                        perms2 = list(get_by_name(next_node.attribute, "perm").ints)
-                        if self.Are_opposite_permutations(perms1, perms2):
-
-                            # connect next_node's consumer input to n's producer output
-                            # TODO implement this to allow for forks as producers
-                            consumers = model.find_direct_successors(next_node)
-                            prod = model.find_producer(n.input[0])
-                            if prod is not None:
-                                for cons in consumers:
-                                    for cons_in in cons.input:
-                                        if cons_in == next_node.output[0]:
-                                            prod.output[0] = cons_in
-                                            break
-                            else:
-                                # n.input[0] is top-level graph input
-                                # wire consumers directly to that
-                                for cons in consumers:
-                                    for i, iname in enumerate(cons.input):
-                                        if iname == next_node.output[0]:
-                                            cons.input[i] = n.input[0]
-
-                            # remove both transposes
-                            graph.node.remove(n)
-                            graph.node.remove(next_node)
-
-                            graph_modified = True
+        for node in graph.node:
+            if node.op_type == "Transpose":
+                next_nodes = model.find_consumers(node.output[0])
+                perms1 = list(get_by_name(node.attribute, "perm").ints)
+                # do not touch Transposes that only feed top-level outputs
+                if next_nodes is None or len(next_nodes) == 0:
+                    continue
+                # check if all consumers are opposite transposes
+                all_opposite_transposes = True
+                for next_node in next_nodes:
+                    if next_node.op_type == "Transpose":
+                        perms2 = list(get_by_name(next_node.attribute, "perm").ints)
+                        if not self.are_opposite_permutations(perms1, perms2):
+                            all_opposite_transposes = False
+                            break
+                    else:
+                        all_opposite_transposes = False
+                        break
+                if not all_opposite_transposes:
+                    continue
+                source_tensor = node.input[0]
+                for next_node in next_nodes:
+                    # rewire next_node's consumers to node's input tensor
+                    # TODO how to handle top-level outputs if any?
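Reviewer note, not part of the patch: a minimal sketch of the fork pattern the rewritten pass now handles, assuming the absorb.py changes above are applied. All tensor and node names are illustrative.

```python
import onnx.helper as oh
from onnx import TensorProto
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.transformation.infer_shapes import InferShapes

from finn.transformation.streamline.absorb import AbsorbConsecutiveTransposes

# one Transpose forking into two opposite Transposes: the whole trio cancels
inp = oh.make_tensor_value_info("in0", TensorProto.FLOAT, [1, 3, 4, 2])
out = oh.make_tensor_value_info("out0", TensorProto.FLOAT, [1, 3, 4, 2])
graph = oh.make_graph(
    nodes=[
        oh.make_node("Transpose", ["in0"], ["t0"], perm=[0, 2, 3, 1]),
        oh.make_node("Transpose", ["t0"], ["t1"], perm=[0, 3, 1, 2]),
        oh.make_node("Transpose", ["t0"], ["t2"], perm=[0, 3, 1, 2]),
        oh.make_node("Add", ["t1", "t2"], ["out0"]),
    ],
    name="sketch",
    inputs=[inp],
    outputs=[out],
)
model = ModelWrapper(oh.make_model(graph))
model = model.transform(InferShapes())
model = model.transform(AbsorbConsecutiveTransposes())
# all three Transposes are rewired away; only the Add remains
assert [n.op_type for n in model.graph.node] == ["Add"]
```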
+                    nextnode_out = next_node.output[0]
+                    assert nextnode_out not in [x.name for x in model.graph.output]
+                    consumers = model.find_consumers(nextnode_out)
+                    for cons in consumers:
+                        for i, iname in enumerate(cons.input):
+                            if iname == nextnode_out:
+                                cons.input[i] = source_tensor
+                    # remove consumer transpose
+                    graph.node.remove(next_node)
+                # remove producer transpose
+                graph.node.remove(node)
+                graph_modified = True

         if graph_modified:
             model = model.transform(InferDataTypes())
         return (model, graph_modified)
diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py
index 9ff8a2173ce81e2a19c56bbd20a326759c3b9df2..3e815c1537353cc2be970a2068d4ded30cc48bc8 100644
--- a/src/finn/transformation/streamline/reorder.py
+++ b/src/finn/transformation/streamline/reorder.py
@@ -553,6 +553,8 @@ class MoveLinearPastEltwiseAdd(Transformation):
             # Other transform should handle that
             if prod0 is None or prod1 is None or (prod0 == prod1):
                 continue
+            if len(prod0.input) < 2 or len(prod1.input) < 2:
+                continue
             init0 = model.get_initializer(prod0.input[1])
             init1 = model.get_initializer(prod1.input[1])
             # if either initializer is None, skip
@@ -728,9 +730,10 @@ class MoveOpPastFork(Transformation):
     can be merged with nodes in the branches
     """

-    def __init__(self, op_name_list):
+    def __init__(self, op_name_list, get_attrs_fxn=lambda x: {}):
         super().__init__()
         self.ops_to_move = op_name_list
+        self.get_attrs_fxn = get_attrs_fxn

     def apply(self, model):
         graph = model.graph
@@ -747,9 +750,10 @@ class MoveOpPastFork(Transformation):

             # Restrict this transform to operations with constant parameters
             # Assuming parameters is in input 1
-            op_init_param = model.get_initializer(n.input[1])
-            if op_init_param is None:
-                continue
+            if len(n.input) > 1:
+                op_init_param = model.get_initializer(n.input[1])
+            else:
+                op_init_param = None

             # Check case when branches are empty and go
             # to the same node
@@ -766,16 +770,20 @@ class MoveOpPastFork(Transformation):

             for consumer_node in consumers[1:]:
                 # create new node
-                new_param_name = model.make_new_valueinfo_name()
                 new_output_tensor_name = model.make_new_valueinfo_name()
+                if op_init_param is None:
+                    new_inp_list = [n.input[0]]
+                else:
+                    new_param_name = model.make_new_valueinfo_name()
+                    new_inp_list = [n.input[0], new_param_name]
+                    model.set_initializer(new_param_name, op_init_param)
+                attrs = self.get_attrs_fxn(n)
+                # TODO: use a copy of the original node instead to get attrs?
                 new_node = oh.make_node(
-                    n.op_type,
-                    [n.input[0], new_param_name],
-                    [new_output_tensor_name],
+                    n.op_type, new_inp_list, [new_output_tensor_name], **attrs
                 )
                 graph.node.insert(node_ind, new_node)
                 node_ind += 1
-                model.set_initializer(new_param_name, op_init_param)

                 # change consumer input tensor
                 graph.node.remove(consumer_node)
@@ -811,6 +819,13 @@ class MoveLinearPastFork(MoveOpPastFork):
         super().__init__(["Add", "Mul"])


+class MoveTransposePastFork(MoveOpPastFork):
+    def __init__(self):
+        super().__init__(
+            ["Transpose"], lambda x: {"perm": get_by_name(x.attribute, "perm").ints}
+        )
+
+
 class MoveMaxPoolPastMultiThreshold(Transformation):
     """Move MaxPool nodes past MultiThreshold nodes on linear segments of the graph."""

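For clarity: the `get_attrs_fxn` hook is what lets attribute-carrying ops be duplicated per branch. A sketch of a direct instantiation, equivalent to the `MoveTransposePastFork` subclass added above:

```python
from qonnx.util.basic import get_by_name

from finn.transformation.streamline.reorder import MoveOpPastFork

# the extractor returns the attributes every per-branch copy must carry;
# attribute-free ops keep the default (lambda x: {})
move_transpose = MoveOpPastFork(
    ["Transpose"], lambda x: {"perm": get_by_name(x.attribute, "perm").ints}
)
```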
model.set_initializer("add_param_0", np.asarray([1], dtype=np.float32)) - model.set_initializer("add_param_1", np.asarray([3], dtype=np.float32)) - model.set_initializer("mul_param_0", np.asarray([2], dtype=np.float32)) + model.save("dbg.onnx") new_model = model.transform(AbsorbConsecutiveTransposes()) new_model = new_model.transform(InferShapes()) - inp_dict = {"top_in": np.random.rand(*input_shape).astype(np.float32)} + new_model.save("newdbg.onnx") + inp_dict = {"top_in": np.random.rand(*shp).astype(np.float32)} assert ox.compare_execution(model, model, inp_dict) - assert len(new_model.graph.node) == 4 + assert len(new_model.graph.node) == 6 for n in new_model.graph.node: assert new_model.graph.node[0].op_type != "Transpose" diff --git a/tests/transformation/streamline/test_move_past_fork.py b/tests/transformation/streamline/test_move_past_fork.py index 5064fa3fca869a245c87cf0c1680d1357e5de60b..7e77d7f9b3502429f08c40558e330b6261d0dbad 100644 --- a/tests/transformation/streamline/test_move_past_fork.py +++ b/tests/transformation/streamline/test_move_past_fork.py @@ -28,80 +28,113 @@ import pytest import numpy as np -from onnx import TensorProto, helper +import onnx.parser as oprs from qonnx.core.modelwrapper import ModelWrapper +from qonnx.transformation.general import GiveUniqueNodeNames from qonnx.transformation.infer_shapes import InferShapes +from qonnx.util.basic import get_by_name import finn.core.onnx_exec as oxe -from finn.transformation.streamline.reorder import MoveLinearPastFork +from finn.transformation.streamline.reorder import ( + MoveLinearPastFork, + MoveTransposePastFork, +) + + +@pytest.mark.streamline +def test_move_past_fork_transpose(): + shp = [1, 3, 32, 32] + shp_str = str(shp) + input = f""" + < + ir_version: 7, + opset_import: ["" : 9] + > + agraph (float{shp_str} in0) => (float{shp_str} out0) + {{ + t0_out = Transpose<perm=[0,2,3,1]>(in0) + t1_out = Transpose<perm=[0,3,1,2]>(t0_out) + t2_out = Transpose<perm=[0,3,1,2]>(t0_out) + out0 = Add(t1_out, t2_out) + }} + """ + model = oprs.parse_model(input) + model = ModelWrapper(model) + model = model.transform(InferShapes()) + new_model = model.transform(MoveTransposePastFork()) + new_model = new_model.transform(GiveUniqueNodeNames()) + nodes = new_model.graph.node + assert oxe.compare_execution( + model, new_model, {"in0": np.random.rand(*shp).astype(np.float32)} + ) + assert len(nodes) == 5 + assert not new_model.is_fork_node(get_by_name(nodes, "Transpose_0")) @pytest.mark.streamline @pytest.mark.parametrize("ch", [64, 1]) # ifmdim @pytest.mark.parametrize("ifmdim", [-1, 7]) -def test_move_past_fork(ch, ifmdim): - # generate test vectors of correct shape +def test_move_past_fork_linear(ch, ifmdim): if ifmdim == -1: - input_shape = (1, ch) + shp = [1, ch] else: - input_shape = (1, ch, ifmdim, ifmdim) + shp = [1, ch, ifmdim, ifmdim] + shp_str = str(shp) + input = f""" + < + ir_version: 7, + opset_import: ["" : 9] + > + agraph (float{shp_str} in0) => (float{shp_str} out0) + < + float{shp_str} add0_param, + float{shp_str} mul_shared_param, + float{shp_str} add2_param, + float{shp_str} mul2_param, + float{shp_str} add3_param, + float{shp_str} add4_param, + float{shp_str} mul3_param, + float{shp_str} add6_param + > + {{ - top_in = helper.make_tensor_value_info("top_in", TensorProto.FLOAT, input_shape) - top_out = helper.make_tensor_value_info("top_out", TensorProto.FLOAT, input_shape) - - num_of_params = 8 - value_info = [] - for i in range(num_of_params): - value_info += [ - helper.make_tensor_value_info("p" + str(i), 
diff --git a/tests/transformation/streamline/test_move_past_fork.py b/tests/transformation/streamline/test_move_past_fork.py
index 5064fa3fca869a245c87cf0c1680d1357e5de60b..7e77d7f9b3502429f08c40558e330b6261d0dbad 100644
--- a/tests/transformation/streamline/test_move_past_fork.py
+++ b/tests/transformation/streamline/test_move_past_fork.py
@@ -28,80 +28,113 @@ import pytest

 import numpy as np
-from onnx import TensorProto, helper
+import onnx.parser as oprs
 from qonnx.core.modelwrapper import ModelWrapper
+from qonnx.transformation.general import GiveUniqueNodeNames
 from qonnx.transformation.infer_shapes import InferShapes
+from qonnx.util.basic import get_by_name

 import finn.core.onnx_exec as oxe
-from finn.transformation.streamline.reorder import MoveLinearPastFork
+from finn.transformation.streamline.reorder import (
+    MoveLinearPastFork,
+    MoveTransposePastFork,
+)
+
+
+@pytest.mark.streamline
+def test_move_past_fork_transpose():
+    shp = [1, 3, 32, 32]
+    shp_str = str(shp)
+    input = f"""
+    <
+        ir_version: 7,
+        opset_import: ["" : 9]
+    >
+    agraph (float{shp_str} in0) => (float{shp_str} out0)
+    {{
+        t0_out = Transpose<perm=[0,2,3,1]>(in0)
+        t1_out = Transpose<perm=[0,3,1,2]>(t0_out)
+        t2_out = Transpose<perm=[0,3,1,2]>(t0_out)
+        out0 = Add(t1_out, t2_out)
+    }}
+    """
+    model = oprs.parse_model(input)
+    model = ModelWrapper(model)
+    model = model.transform(InferShapes())
+    new_model = model.transform(MoveTransposePastFork())
+    new_model = new_model.transform(GiveUniqueNodeNames())
+    nodes = new_model.graph.node
+    assert oxe.compare_execution(
+        model, new_model, {"in0": np.random.rand(*shp).astype(np.float32)}
+    )
+    assert len(nodes) == 5
+    assert not new_model.is_fork_node(get_by_name(nodes, "Transpose_0"))


 @pytest.mark.streamline
 @pytest.mark.parametrize("ch", [64, 1])
 # ifmdim
 @pytest.mark.parametrize("ifmdim", [-1, 7])
-def test_move_past_fork(ch, ifmdim):
-    # generate test vectors of correct shape
+def test_move_past_fork_linear(ch, ifmdim):
     if ifmdim == -1:
-        input_shape = (1, ch)
+        shp = [1, ch]
     else:
-        input_shape = (1, ch, ifmdim, ifmdim)
+        shp = [1, ch, ifmdim, ifmdim]
+    shp_str = str(shp)
+    input = f"""
+    <
+        ir_version: 7,
+        opset_import: ["" : 9]
+    >
+    agraph (float{shp_str} in0) => (float{shp_str} out0)
+    <
+        float{shp_str} add0_param,
+        float{shp_str} mul_shared_param,
+        float{shp_str} add2_param,
+        float{shp_str} mul2_param,
+        float{shp_str} add3_param,
+        float{shp_str} add4_param,
+        float{shp_str} mul3_param,
+        float{shp_str} add6_param
+    >
+    {{
-    top_in = helper.make_tensor_value_info("top_in", TensorProto.FLOAT, input_shape)
-    top_out = helper.make_tensor_value_info("top_out", TensorProto.FLOAT, input_shape)
-
-    num_of_params = 8
-    value_info = []
-    for i in range(num_of_params):
-        value_info += [
-            helper.make_tensor_value_info("p" + str(i), TensorProto.FLOAT, input_shape)
-        ]
-
-    add_1_to_move = helper.make_node("Add", ["top_in", "p0"], ["fork1"])
-    mul_1_to_move = helper.make_node("Mul", ["t5", "p4"], ["fork2"])
-    add_2_to_move = helper.make_node("Add", ["fork2", "p5"], ["t6"])
-    mul_1_not_to_move = helper.make_node("Mul", ["t8", "p7"], ["fork3"])
-    modelproto = helper.make_model(
-        helper.make_graph(
-            name="test",
-            inputs=[top_in],
-            outputs=[top_out],
-            value_info=value_info,
-            nodes=[
-                # fork1
-                add_1_to_move,
-                helper.make_node("Mul", ["fork1", "p1"], ["t2"]),
-                helper.make_node("Mul", ["fork1", "p2"], ["t3"]),
-                helper.make_node("Add", ["t2", "t3"], ["t4"]),
-                helper.make_node("Add", ["t4", "p3"], ["t5"]),
-                # fork2
-                mul_1_to_move,
-                add_2_to_move,
-                helper.make_node("Add", ["fork2", "p6"], ["t7"]),
-                helper.make_node("Add", ["t6", "t7"], ["t8"]),
-                # empty branches: do nothing
-                mul_1_not_to_move,
-                helper.make_node("Add", ["fork3", "fork3"], ["top_out"]),
-            ],
-        )
-    )
-    model = ModelWrapper(modelproto)
+        add0_out = Add(in0, add0_param)
+        mul0_out = Mul(add0_out, mul_shared_param)
+        mul1_out = Mul(add0_out, mul_shared_param)
+        add1_out = Add(mul0_out, mul1_out)
+        add2_out = Add(add1_out, add2_param)
+        mul2_out = Mul(add2_out, mul2_param)
+        add3_out = Add(mul2_out, add3_param)
+        add4_out = Add(mul2_out, add4_param)
+        add5_out = Add(add3_out, add4_out)
+        mul3_out = Mul(add5_out, mul3_param)
+        out0 = Add(mul3_out, add6_param)
+    }}
+    """
+    model = oprs.parse_model(input)
+    model = ModelWrapper(model)
     model = model.transform(InferShapes())
     np.random.seed(0)
-    for i in range(num_of_params):
-        model.set_initializer(
-            "p" + str(i), np.random.rand(*input_shape).astype(np.float32)
-        )
-
+    for tensor_name in model.get_all_tensor_names():
+        if tensor_name.endswith("_param"):
+            pshape = model.get_tensor_shape(tensor_name)
+            model.set_initializer(
+                tensor_name, np.random.rand(*pshape).astype(np.float32)
+            )
+    model = model.transform(GiveUniqueNodeNames())
     # Transform
     new_model = model.transform(MoveLinearPastFork())
-    inp_dict = {"top_in": np.random.rand(*input_shape).astype(np.float32)}
-
+    new_model = new_model.transform(GiveUniqueNodeNames())
+    inp_dict = {"in0": np.random.rand(*shp).astype(np.float32)}
     # Test
     assert oxe.compare_execution(model, new_model, inp_dict)
-    assert not new_model.is_fork_node(add_1_to_move)
-    assert not new_model.is_fork_node(mul_1_to_move)
-    assert not new_model.is_fork_node(add_2_to_move)
-    assert new_model.is_fork_node(mul_1_not_to_move)
+    nodes = new_model.graph.node
+    assert len(new_model.get_nodes_by_op_type("Add")) == 9
+    assert len(new_model.get_nodes_by_op_type("Mul")) == 5
+    assert not new_model.is_fork_node(get_by_name(nodes, "Add_0"))
+    assert new_model.is_join_node(get_by_name(nodes, "Add_2"))
+    assert not new_model.is_fork_node(get_by_name(nodes, "Mul_2"))
+    assert not new_model.is_join_node(get_by_name(nodes, "Add_5"))
     assert len(new_model.graph.node) == 14
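The reworked assertions find nodes via the names that `GiveUniqueNodeNames` assigns ("<OpType>_<index>" in graph order). A sketch of that lookup pattern as a reusable check; the helper `check_not_fork` is illustrative, not part of the patch:

```python
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.transformation.general import GiveUniqueNodeNames
from qonnx.util.basic import get_by_name


def check_not_fork(model: ModelWrapper, node_name: str) -> bool:
    # names like "Add_0" only exist after GiveUniqueNodeNames has run
    model = model.transform(GiveUniqueNodeNames())
    node = get_by_name(model.graph.node, node_name)
    return node is not None and not model.is_fork_node(node)
```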