Skip to content
Snippets Groups Projects
Commit 6463cacb authored by Tobi-Alonso's avatar Tobi-Alonso
Browse files

[TEST] Add test linear move_add_past_mul for non_linear graphs and improve...

[TEST] Add test linear move_add_past_mul for non_linear graphs and improve previous test by checking if doing job
parent 2a5793b8
No related branches found
No related tags found
No related merge requests found
......@@ -60,6 +60,9 @@ def test_move_add_past_mul_single():
new_model = model.transform(MoveAddPastMul())
inp_dict = {"top_in": np.asarray([-1.0, 1.0], dtype=np.float32)}
assert ox.compare_execution(model, new_model, inp_dict)
assert new_model.graph.node[0].op_type == "Mul"
assert new_model.graph.node[1].op_type == "Add"
assert new_model.graph.node[0].output[0] == new_model.graph.node[1].input[0]
def test_move_add_past_mul_multi():
......@@ -92,3 +95,50 @@ def test_move_add_past_mul_multi():
new_model = model.transform(MoveAddPastMul())
inp_dict = {"top_in": np.asarray([-1.0, 1.0], dtype=np.float32)}
assert ox.compare_execution(model, new_model, inp_dict)
assert new_model.graph.node[0].op_type == "Mul"
assert new_model.graph.node[1].op_type == "Mul"
assert new_model.graph.node[2].op_type == "Add"
assert new_model.graph.node[3].op_type == "Add"
for i in range(len(new_model.graph.node) - 1):
assert new_model.graph.node[i].output[0] == new_model.graph.node[i + 1].input[0]
def test_move_add_past_mul_only_if_linear():
top_in = oh.make_tensor_value_info("top_in", TensorProto.FLOAT, [2])
top_out = oh.make_tensor_value_info("top_out", TensorProto.FLOAT, [2])
value_info = [oh.make_tensor_value_info("add1_param", TensorProto.FLOAT, [1])]
value_info += [oh.make_tensor_value_info("mul1_param", TensorProto.FLOAT, [1])]
value_info += [oh.make_tensor_value_info("mul2_param", TensorProto.FLOAT, [1])]
value_info += [oh.make_tensor_value_info("mul3_param", TensorProto.FLOAT, [1])]
modelproto = oh.make_model(
oh.make_graph(
name="test",
inputs=[top_in],
outputs=[top_out],
value_info=value_info,
nodes=[
oh.make_node("Add", ["top_in", "add1_param"], ["t1"]),
oh.make_node("Mul", ["t1", "mul1_param"], ["fork"]),
oh.make_node("Mul", ["fork", "mul2_param"], ["t3"]),
oh.make_node("Add", ["t3", "fork"], ["t4"]),
oh.make_node("Mul", ["t4", "mul3_param"], ["top_out"]),
],
)
)
model = ModelWrapper(modelproto)
model = model.transform(InferShapes())
np.random.seed(0)
model.set_initializer("add1_param", np.random.rand(2).astype(np.float32))
model.set_initializer("mul1_param", np.random.rand(2).astype(np.float32))
model.set_initializer("mul2_param", np.random.rand(2).astype(np.float32))
model.set_initializer("mul3_param", np.random.rand(2).astype(np.float32))
new_model = model.transform(MoveAddPastMul())
inp_dict = {"top_in": np.random.rand(2).astype(np.float32)}
assert ox.compare_execution(model, new_model, inp_dict)
assert new_model.graph.node[0].op_type == "Mul"
assert new_model.graph.node[1].op_type == "Add"
assert new_model.graph.node[2].op_type == "Mul"
assert new_model.graph.node[3].op_type == "Add"
assert new_model.graph.node[4].op_type == "Mul"
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment