diff --git a/tests/transformation/streamline/test_absorb_opposite_transposes.py b/tests/transformation/streamline/test_absorb_opposite_transposes.py index d4d26f3cf8bd35982960eaffcb97ee5788557e5a..88cbd5657e2ae6c0946e59186c25d935595ad2ff 100644 --- a/tests/transformation/streamline/test_absorb_opposite_transposes.py +++ b/tests/transformation/streamline/test_absorb_opposite_transposes.py @@ -61,16 +61,23 @@ def test_absorb_opposite_transposes(): t2_out = Transpose<perm=[0,2,3,1]>(add1_out) t3_out = Transpose<perm=[0,3,1,2]>(t2_out) add2_out = Add(t1_out, t3_out) - out0 = Mul(add2_out, mul0_param) + t4_out = Transpose<perm=[0,2,3,1]>(add2_out) + t5_out = Transpose<perm=[0,3,1,2]>(t4_out) + t6_out = Transpose<perm=[0,3,1,2]>(t4_out) + m0_out = Mul(t5_out, mul0_param) + m1_out = Mul(t6_out, mul0_param) + out0 = Mul(m0_out, m1_out) }} """ model = oprs.parse_model(input) model = ModelWrapper(model) model = model.transform(InferShapes()) + model.save("dbg.onnx") new_model = model.transform(AbsorbConsecutiveTransposes()) new_model = new_model.transform(InferShapes()) + new_model.save("newdbg.onnx") inp_dict = {"top_in": np.random.rand(*shp).astype(np.float32)} assert ox.compare_execution(model, model, inp_dict) - assert len(new_model.graph.node) == 4 + assert len(new_model.graph.node) == 6 for n in new_model.graph.node: assert new_model.graph.node[0].op_type != "Transpose"