From f1aa0e3b5db76829ec07f8d482b5f25656fbb58a Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <maltanar@gmail.com> Date: Wed, 20 Jul 2022 11:11:30 +0200 Subject: [PATCH] [Test] add forked tranpose testcase to test_absorb_opposite_transposes --- .../streamline/test_absorb_opposite_transposes.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/transformation/streamline/test_absorb_opposite_transposes.py b/tests/transformation/streamline/test_absorb_opposite_transposes.py index d4d26f3cf..88cbd5657 100644 --- a/tests/transformation/streamline/test_absorb_opposite_transposes.py +++ b/tests/transformation/streamline/test_absorb_opposite_transposes.py @@ -61,16 +61,23 @@ def test_absorb_opposite_transposes(): t2_out = Transpose<perm=[0,2,3,1]>(add1_out) t3_out = Transpose<perm=[0,3,1,2]>(t2_out) add2_out = Add(t1_out, t3_out) - out0 = Mul(add2_out, mul0_param) + t4_out = Transpose<perm=[0,2,3,1]>(add2_out) + t5_out = Transpose<perm=[0,3,1,2]>(t4_out) + t6_out = Transpose<perm=[0,3,1,2]>(t4_out) + m0_out = Mul(t5_out, mul0_param) + m1_out = Mul(t6_out, mul0_param) + out0 = Mul(m0_out, m1_out) }} """ model = oprs.parse_model(input) model = ModelWrapper(model) model = model.transform(InferShapes()) + model.save("dbg.onnx") new_model = model.transform(AbsorbConsecutiveTransposes()) new_model = new_model.transform(InferShapes()) + new_model.save("newdbg.onnx") inp_dict = {"top_in": np.random.rand(*shp).astype(np.float32)} assert ox.compare_execution(model, model, inp_dict) - assert len(new_model.graph.node) == 4 + assert len(new_model.graph.node) == 6 for n in new_model.graph.node: assert new_model.graph.node[0].op_type != "Transpose" -- GitLab