Skip to content
Snippets Groups Projects
Commit f1aa0e3b authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Test] add forked tranpose testcase to test_absorb_opposite_transposes

parent a8bb9754
No related branches found
No related tags found
No related merge requests found
......@@ -61,16 +61,23 @@ def test_absorb_opposite_transposes():
t2_out = Transpose<perm=[0,2,3,1]>(add1_out)
t3_out = Transpose<perm=[0,3,1,2]>(t2_out)
add2_out = Add(t1_out, t3_out)
out0 = Mul(add2_out, mul0_param)
t4_out = Transpose<perm=[0,2,3,1]>(add2_out)
t5_out = Transpose<perm=[0,3,1,2]>(t4_out)
t6_out = Transpose<perm=[0,3,1,2]>(t4_out)
m0_out = Mul(t5_out, mul0_param)
m1_out = Mul(t6_out, mul0_param)
out0 = Mul(m0_out, m1_out)
}}
"""
model = oprs.parse_model(input)
model = ModelWrapper(model)
model = model.transform(InferShapes())
model.save("dbg.onnx")
new_model = model.transform(AbsorbConsecutiveTransposes())
new_model = new_model.transform(InferShapes())
new_model.save("newdbg.onnx")
inp_dict = {"top_in": np.random.rand(*shp).astype(np.float32)}
assert ox.compare_execution(model, model, inp_dict)
assert len(new_model.graph.node) == 4
assert len(new_model.graph.node) == 6
for n in new_model.graph.node:
assert new_model.graph.node[0].op_type != "Transpose"
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment