From 6463cacb3ae4cdd5ef5f40bb708c07a5e6471706 Mon Sep 17 00:00:00 2001 From: Tobi-Alonso <tobi.alonso@gmail.com> Date: Mon, 18 May 2020 12:56:26 +0100 Subject: [PATCH] [TEST] Add test linear move_add_past_mul for non_linear graphs and improve previous test by checking if doing job --- .../transformation/test_move_add_past_mul.py | 50 +++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/tests/transformation/test_move_add_past_mul.py b/tests/transformation/test_move_add_past_mul.py index a0516d6fb..163b9d310 100644 --- a/tests/transformation/test_move_add_past_mul.py +++ b/tests/transformation/test_move_add_past_mul.py @@ -60,6 +60,9 @@ def test_move_add_past_mul_single(): new_model = model.transform(MoveAddPastMul()) inp_dict = {"top_in": np.asarray([-1.0, 1.0], dtype=np.float32)} assert ox.compare_execution(model, new_model, inp_dict) + assert new_model.graph.node[0].op_type == "Mul" + assert new_model.graph.node[1].op_type == "Add" + assert new_model.graph.node[0].output[0] == new_model.graph.node[1].input[0] def test_move_add_past_mul_multi(): @@ -92,3 +95,50 @@ def test_move_add_past_mul_multi(): new_model = model.transform(MoveAddPastMul()) inp_dict = {"top_in": np.asarray([-1.0, 1.0], dtype=np.float32)} assert ox.compare_execution(model, new_model, inp_dict) + assert new_model.graph.node[0].op_type == "Mul" + assert new_model.graph.node[1].op_type == "Mul" + assert new_model.graph.node[2].op_type == "Add" + assert new_model.graph.node[3].op_type == "Add" + for i in range(len(new_model.graph.node) - 1): + assert new_model.graph.node[i].output[0] == new_model.graph.node[i + 1].input[0] + + +def test_move_add_past_mul_only_if_linear(): + top_in = oh.make_tensor_value_info("top_in", TensorProto.FLOAT, [2]) + top_out = oh.make_tensor_value_info("top_out", TensorProto.FLOAT, [2]) + + value_info = [oh.make_tensor_value_info("add1_param", TensorProto.FLOAT, [1])] + value_info += [oh.make_tensor_value_info("mul1_param", TensorProto.FLOAT, [1])] + value_info += [oh.make_tensor_value_info("mul2_param", TensorProto.FLOAT, [1])] + value_info += [oh.make_tensor_value_info("mul3_param", TensorProto.FLOAT, [1])] + modelproto = oh.make_model( + oh.make_graph( + name="test", + inputs=[top_in], + outputs=[top_out], + value_info=value_info, + nodes=[ + oh.make_node("Add", ["top_in", "add1_param"], ["t1"]), + oh.make_node("Mul", ["t1", "mul1_param"], ["fork"]), + oh.make_node("Mul", ["fork", "mul2_param"], ["t3"]), + oh.make_node("Add", ["t3", "fork"], ["t4"]), + oh.make_node("Mul", ["t4", "mul3_param"], ["top_out"]), + ], + ) + ) + model = ModelWrapper(modelproto) + model = model.transform(InferShapes()) + + np.random.seed(0) + model.set_initializer("add1_param", np.random.rand(2).astype(np.float32)) + model.set_initializer("mul1_param", np.random.rand(2).astype(np.float32)) + model.set_initializer("mul2_param", np.random.rand(2).astype(np.float32)) + model.set_initializer("mul3_param", np.random.rand(2).astype(np.float32)) + new_model = model.transform(MoveAddPastMul()) + inp_dict = {"top_in": np.random.rand(2).astype(np.float32)} + assert ox.compare_execution(model, new_model, inp_dict) + assert new_model.graph.node[0].op_type == "Mul" + assert new_model.graph.node[1].op_type == "Add" + assert new_model.graph.node[2].op_type == "Mul" + assert new_model.graph.node[3].op_type == "Add" + assert new_model.graph.node[4].op_type == "Mul" -- GitLab