diff --git a/tests/transformation/streamline/test_move_past_fork.py b/tests/transformation/streamline/test_move_past_fork.py index 5064fa3fca869a245c87cf0c1680d1357e5de60b..543a43d64d209edbfe5cb0944e7fcb084b96838a 100644 --- a/tests/transformation/streamline/test_move_past_fork.py +++ b/tests/transformation/streamline/test_move_past_fork.py @@ -28,9 +28,11 @@ import pytest import numpy as np -from onnx import TensorProto, helper +import onnx.parser as oprs from qonnx.core.modelwrapper import ModelWrapper +from qonnx.transformation.general import GiveUniqueNodeNames from qonnx.transformation.infer_shapes import InferShapes +from qonnx.util.basic import get_by_name import finn.core.onnx_exec as oxe from finn.transformation.streamline.reorder import MoveLinearPastFork @@ -41,67 +43,65 @@ from finn.transformation.streamline.reorder import MoveLinearPastFork # ifmdim @pytest.mark.parametrize("ifmdim", [-1, 7]) def test_move_past_fork(ch, ifmdim): - # generate test vectors of correct shape if ifmdim == -1: - input_shape = (1, ch) + shp = [1, ch] else: - input_shape = (1, ch, ifmdim, ifmdim) + shp = [1, ch, ifmdim, ifmdim] + shp_str = str(shp) + input = f""" + < + ir_version: 7, + opset_import: ["" : 9] + > + agraph (float{shp_str} in0) => (float{shp_str} out0) + < + float{shp_str} add0_param, + float{shp_str} mul_shared_param, + float{shp_str} add2_param, + float{shp_str} mul2_param, + float{shp_str} add3_param, + float{shp_str} add4_param, + float{shp_str} mul3_param, + float{shp_str} add6_param + > + {{ - top_in = helper.make_tensor_value_info("top_in", TensorProto.FLOAT, input_shape) - top_out = helper.make_tensor_value_info("top_out", TensorProto.FLOAT, input_shape) - - num_of_params = 8 - value_info = [] - for i in range(num_of_params): - value_info += [ - helper.make_tensor_value_info("p" + str(i), TensorProto.FLOAT, input_shape) - ] - - add_1_to_move = helper.make_node("Add", ["top_in", "p0"], ["fork1"]) - mul_1_to_move = helper.make_node("Mul", ["t5", "p4"], ["fork2"]) - add_2_to_move = helper.make_node("Add", ["fork2", "p5"], ["t6"]) - mul_1_not_to_move = helper.make_node("Mul", ["t8", "p7"], ["fork3"]) - modelproto = helper.make_model( - helper.make_graph( - name="test", - inputs=[top_in], - outputs=[top_out], - value_info=value_info, - nodes=[ - # fork1 - add_1_to_move, - helper.make_node("Mul", ["fork1", "p1"], ["t2"]), - helper.make_node("Mul", ["fork1", "p2"], ["t3"]), - helper.make_node("Add", ["t2", "t3"], ["t4"]), - helper.make_node("Add", ["t4", "p3"], ["t5"]), - # fork2 - mul_1_to_move, - add_2_to_move, - helper.make_node("Add", ["fork2", "p6"], ["t7"]), - helper.make_node("Add", ["t6", "t7"], ["t8"]), - # empty branches: do nothing - mul_1_not_to_move, - helper.make_node("Add", ["fork3", "fork3"], ["top_out"]), - ], - ) - ) - model = ModelWrapper(modelproto) + add0_out = Add(in0, add0_param) + mul0_out = Mul(add0_out, mul_shared_param) + mul1_out = Mul(add0_out, mul_shared_param) + add1_out = Add(mul0_out, mul1_out) + add2_out = Add(add1_out, add2_param) + mul2_out = Mul(add2_out, mul2_param) + add3_out = Add(mul2_out, add3_param) + add4_out = Add(mul2_out, add4_param) + add5_out = Add(add3_out, add4_out) + mul3_out = Mul(add5_out, mul3_param) + out0 = Add(mul3_out, add6_param) + }} + """ + model = oprs.parse_model(input) + model = ModelWrapper(model) model = model.transform(InferShapes()) np.random.seed(0) - for i in range(num_of_params): - model.set_initializer( - "p" + str(i), np.random.rand(*input_shape).astype(np.float32) - ) - + for tensor_name in model.get_all_tensor_names(): + if tensor_name.endswith("_param"): + pshape = model.get_tensor_shape(tensor_name) + model.set_initializer( + tensor_name, np.random.rand(*pshape).astype(np.float32) + ) + model = model.transform(GiveUniqueNodeNames()) # Transform new_model = model.transform(MoveLinearPastFork()) - inp_dict = {"top_in": np.random.rand(*input_shape).astype(np.float32)} - + new_model = new_model.transform(GiveUniqueNodeNames()) + inp_dict = {"top_in": np.random.rand(*shp).astype(np.float32)} # Test assert oxe.compare_execution(model, new_model, inp_dict) - assert not new_model.is_fork_node(add_1_to_move) - assert not new_model.is_fork_node(mul_1_to_move) - assert not new_model.is_fork_node(add_2_to_move) - assert new_model.is_fork_node(mul_1_not_to_move) + nodes = new_model.graph.node + assert len(new_model.get_nodes_by_op_type("Add")) == 9 + assert len(new_model.get_nodes_by_op_type("Mul")) == 5 + assert not new_model.is_fork_node(get_by_name(nodes, "Add_0")) + assert new_model.is_join_node(get_by_name(nodes, "Add_2")) + assert not new_model.is_fork_node(get_by_name(nodes, "Mul_2")) + assert not new_model.is_join_node(get_by_name(nodes, "Add_5")) assert len(new_model.graph.node) == 14