Skip to content
Snippets Groups Projects
Commit f29a92fc authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Test] add test_move_scalar_add_past_matmul

parent efe1a7a3
No related branches found
No related tags found
No related merge requests found
......@@ -39,3 +39,36 @@ def test_move_scalar_mul_past_matmul():
assert new_model.graph.node[0].op_type == "MatMul"
assert new_model.graph.node[1].op_type == "Mul"
assert new_model.graph.node[0].output[0] == new_model.graph.node[1].input[0]
def test_move_scalar_add_past_matmul():
top_in = oh.make_tensor_value_info("top_in", TensorProto.FLOAT, [1, 2])
add_param = oh.make_tensor_value_info("add_param", TensorProto.FLOAT, [1, 1])
matmul_param = oh.make_tensor_value_info("matmul_param", TensorProto.FLOAT, [2, 2])
top_out = oh.make_tensor_value_info("top_out", TensorProto.FLOAT, [1, 2])
modelproto = oh.make_model(
oh.make_graph(
name="test",
inputs=[top_in],
outputs=[top_out],
value_info=[add_param, matmul_param],
nodes=[
oh.make_node("Add", ["top_in", "add_param"], ["middle"]),
oh.make_node("MatMul", ["middle", "matmul_param"], ["top_out"]),
],
)
)
model = ModelWrapper(modelproto)
model = model.transform_single(si.infer_shapes)
model.set_initializer("add_param", np.asarray([[3]], dtype=np.float32))
model.set_initializer(
"matmul_param", np.asarray([[2, 4], [-1, 1]], dtype=np.float32)
)
new_model = model.transform_repeated(tx.move_scalar_add_past_matmul)
inp_dict = {"top_in": np.asarray([[-1.0, 1.0]], dtype=np.float32)}
out_orig = ox.execute_onnx(model, inp_dict)["top_out"]
out_transformed = ox.execute_onnx(new_model, inp_dict)["top_out"]
assert np.isclose(out_orig, out_transformed).all()
assert new_model.graph.node[0].op_type == "MatMul"
assert new_model.graph.node[1].op_type == "Add"
assert new_model.graph.node[0].output[0] == new_model.graph.node[1].input[0]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment