diff --git a/tests/transformation/test_move_add_past_mul.py b/tests/transformation/test_move_add_past_mul.py index a0516d6fb2ff985fc112185ce99ad8facd841caf..163b9d310a5f12bd0b854f9aa46f53a549bf109e 100644 --- a/tests/transformation/test_move_add_past_mul.py +++ b/tests/transformation/test_move_add_past_mul.py @@ -60,6 +60,9 @@ def test_move_add_past_mul_single(): new_model = model.transform(MoveAddPastMul()) inp_dict = {"top_in": np.asarray([-1.0, 1.0], dtype=np.float32)} assert ox.compare_execution(model, new_model, inp_dict) + assert new_model.graph.node[0].op_type == "Mul" + assert new_model.graph.node[1].op_type == "Add" + assert new_model.graph.node[0].output[0] == new_model.graph.node[1].input[0] def test_move_add_past_mul_multi(): @@ -92,3 +95,50 @@ def test_move_add_past_mul_multi(): new_model = model.transform(MoveAddPastMul()) inp_dict = {"top_in": np.asarray([-1.0, 1.0], dtype=np.float32)} assert ox.compare_execution(model, new_model, inp_dict) + assert new_model.graph.node[0].op_type == "Mul" + assert new_model.graph.node[1].op_type == "Mul" + assert new_model.graph.node[2].op_type == "Add" + assert new_model.graph.node[3].op_type == "Add" + for i in range(len(new_model.graph.node) - 1): + assert new_model.graph.node[i].output[0] == new_model.graph.node[i + 1].input[0] + + +def test_move_add_past_mul_only_if_linear(): + top_in = oh.make_tensor_value_info("top_in", TensorProto.FLOAT, [2]) + top_out = oh.make_tensor_value_info("top_out", TensorProto.FLOAT, [2]) + + value_info = [oh.make_tensor_value_info("add1_param", TensorProto.FLOAT, [1])] + value_info += [oh.make_tensor_value_info("mul1_param", TensorProto.FLOAT, [1])] + value_info += [oh.make_tensor_value_info("mul2_param", TensorProto.FLOAT, [1])] + value_info += [oh.make_tensor_value_info("mul3_param", TensorProto.FLOAT, [1])] + modelproto = oh.make_model( + oh.make_graph( + name="test", + inputs=[top_in], + outputs=[top_out], + value_info=value_info, + nodes=[ + oh.make_node("Add", ["top_in", "add1_param"], ["t1"]), + oh.make_node("Mul", ["t1", "mul1_param"], ["fork"]), + oh.make_node("Mul", ["fork", "mul2_param"], ["t3"]), + oh.make_node("Add", ["t3", "fork"], ["t4"]), + oh.make_node("Mul", ["t4", "mul3_param"], ["top_out"]), + ], + ) + ) + model = ModelWrapper(modelproto) + model = model.transform(InferShapes()) + + np.random.seed(0) + model.set_initializer("add1_param", np.random.rand(2).astype(np.float32)) + model.set_initializer("mul1_param", np.random.rand(2).astype(np.float32)) + model.set_initializer("mul2_param", np.random.rand(2).astype(np.float32)) + model.set_initializer("mul3_param", np.random.rand(2).astype(np.float32)) + new_model = model.transform(MoveAddPastMul()) + inp_dict = {"top_in": np.random.rand(2).astype(np.float32)} + assert ox.compare_execution(model, new_model, inp_dict) + assert new_model.graph.node[0].op_type == "Mul" + assert new_model.graph.node[1].op_type == "Add" + assert new_model.graph.node[2].op_type == "Mul" + assert new_model.graph.node[3].op_type == "Add" + assert new_model.graph.node[4].op_type == "Mul"