diff --git a/tests/transformation/test_move_add_past_mul.py b/tests/transformation/test_move_add_past_mul.py
index a0516d6fb2ff985fc112185ce99ad8facd841caf..163b9d310a5f12bd0b854f9aa46f53a549bf109e 100644
--- a/tests/transformation/test_move_add_past_mul.py
+++ b/tests/transformation/test_move_add_past_mul.py
@@ -60,6 +60,9 @@ def test_move_add_past_mul_single():
     new_model = model.transform(MoveAddPastMul())
     inp_dict = {"top_in": np.asarray([-1.0, 1.0], dtype=np.float32)}
     assert ox.compare_execution(model, new_model, inp_dict)
+    assert new_model.graph.node[0].op_type == "Mul"
+    assert new_model.graph.node[1].op_type == "Add"
+    assert new_model.graph.node[0].output[0] == new_model.graph.node[1].input[0]
 
 
 def test_move_add_past_mul_multi():
@@ -92,3 +95,50 @@ def test_move_add_past_mul_multi():
     new_model = model.transform(MoveAddPastMul())
     inp_dict = {"top_in": np.asarray([-1.0, 1.0], dtype=np.float32)}
     assert ox.compare_execution(model, new_model, inp_dict)
+    assert new_model.graph.node[0].op_type == "Mul"
+    assert new_model.graph.node[1].op_type == "Mul"
+    assert new_model.graph.node[2].op_type == "Add"
+    assert new_model.graph.node[3].op_type == "Add"
+    for i in range(len(new_model.graph.node) - 1):
+        assert new_model.graph.node[i].output[0] == new_model.graph.node[i + 1].input[0]
+
+
+def test_move_add_past_mul_only_if_linear():
+    top_in = oh.make_tensor_value_info("top_in", TensorProto.FLOAT, [2])
+    top_out = oh.make_tensor_value_info("top_out", TensorProto.FLOAT, [2])
+
+    value_info = [oh.make_tensor_value_info("add1_param", TensorProto.FLOAT, [1])]
+    value_info += [oh.make_tensor_value_info("mul1_param", TensorProto.FLOAT, [1])]
+    value_info += [oh.make_tensor_value_info("mul2_param", TensorProto.FLOAT, [1])]
+    value_info += [oh.make_tensor_value_info("mul3_param", TensorProto.FLOAT, [1])]
+    modelproto = oh.make_model(
+        oh.make_graph(
+            name="test",
+            inputs=[top_in],
+            outputs=[top_out],
+            value_info=value_info,
+            nodes=[
+                oh.make_node("Add", ["top_in", "add1_param"], ["t1"]),
+                oh.make_node("Mul", ["t1", "mul1_param"], ["fork"]),
+                oh.make_node("Mul", ["fork", "mul2_param"], ["t3"]),
+                oh.make_node("Add", ["t3", "fork"], ["t4"]),
+                oh.make_node("Mul", ["t4", "mul3_param"], ["top_out"]),
+            ],
+        )
+    )
+    model = ModelWrapper(modelproto)
+    model = model.transform(InferShapes())
+
+    np.random.seed(0)
+    model.set_initializer("add1_param", np.random.rand(2).astype(np.float32))
+    model.set_initializer("mul1_param", np.random.rand(2).astype(np.float32))
+    model.set_initializer("mul2_param", np.random.rand(2).astype(np.float32))
+    model.set_initializer("mul3_param", np.random.rand(2).astype(np.float32))
+    new_model = model.transform(MoveAddPastMul())
+    inp_dict = {"top_in": np.random.rand(2).astype(np.float32)}
+    assert ox.compare_execution(model, new_model, inp_dict)
+    assert new_model.graph.node[0].op_type == "Mul"
+    assert new_model.graph.node[1].op_type == "Add"
+    assert new_model.graph.node[2].op_type == "Mul"
+    assert new_model.graph.node[3].op_type == "Add"
+    assert new_model.graph.node[4].op_type == "Mul"