From f1aa0e3b5db76829ec07f8d482b5f25656fbb58a Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Wed, 20 Jul 2022 11:11:30 +0200
Subject: [PATCH] [Test] add forked tranpose testcase to
 test_absorb_opposite_transposes

---
 .../streamline/test_absorb_opposite_transposes.py     | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/tests/transformation/streamline/test_absorb_opposite_transposes.py b/tests/transformation/streamline/test_absorb_opposite_transposes.py
index d4d26f3cf..88cbd5657 100644
--- a/tests/transformation/streamline/test_absorb_opposite_transposes.py
+++ b/tests/transformation/streamline/test_absorb_opposite_transposes.py
@@ -61,16 +61,23 @@ def test_absorb_opposite_transposes():
         t2_out = Transpose<perm=[0,2,3,1]>(add1_out)
         t3_out = Transpose<perm=[0,3,1,2]>(t2_out)
         add2_out = Add(t1_out, t3_out)
-        out0 = Mul(add2_out, mul0_param)
+        t4_out = Transpose<perm=[0,2,3,1]>(add2_out)
+        t5_out = Transpose<perm=[0,3,1,2]>(t4_out)
+        t6_out = Transpose<perm=[0,3,1,2]>(t4_out)
+        m0_out = Mul(t5_out, mul0_param)
+        m1_out = Mul(t6_out, mul0_param)
+        out0 = Mul(m0_out, m1_out)
     }}
     """
     model = oprs.parse_model(input)
     model = ModelWrapper(model)
     model = model.transform(InferShapes())
+    model.save("dbg.onnx")
     new_model = model.transform(AbsorbConsecutiveTransposes())
     new_model = new_model.transform(InferShapes())
+    new_model.save("newdbg.onnx")
     inp_dict = {"top_in": np.random.rand(*shp).astype(np.float32)}
     assert ox.compare_execution(model, model, inp_dict)
-    assert len(new_model.graph.node) == 4
+    assert len(new_model.graph.node) == 6
     for n in new_model.graph.node:
         assert new_model.graph.node[0].op_type != "Transpose"
-- 
GitLab