From 255cec172912363beec46e8dba2486b2c2fc56be Mon Sep 17 00:00:00 2001
From: Tobi-Alonso <tobi.alonso@gmail.com>
Date: Wed, 20 May 2020 15:37:19 +0100
Subject: [PATCH] [Test] Add test for MoveLinearPastFork specialization of
 MoveOpPastFork

---
 tests/transformation/test_move_past_fork.py | 79 +++++++++++++++++++++
 1 file changed, 79 insertions(+)
 create mode 100644 tests/transformation/test_move_past_fork.py

diff --git a/tests/transformation/test_move_past_fork.py b/tests/transformation/test_move_past_fork.py
new file mode 100644
index 000000000..f3d37bd60
--- /dev/null
+++ b/tests/transformation/test_move_past_fork.py
@@ -0,0 +1,79 @@
+from onnx import TensorProto, helper
+import numpy as np
+
+import finn.core.onnx_exec as oxe
+from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.streamline.reorder import MoveLinearPastFork
+from finn.transformation.infer_shapes import InferShapes
+
+import pytest
+
+
+@pytest.mark.parametrize("ch", [64, 1])
+# ifmdim
+@pytest.mark.parametrize("ifmdim", [-1, 7])
+def test_move_past_fork(ch, ifmdim):
+    # generate test vectors of correct shape
+    if ifmdim == -1:
+        input_shape = (1, ch)
+    else:
+        input_shape = (1, ch, ifmdim, ifmdim)
+
+    top_in = helper.make_tensor_value_info("top_in", TensorProto.FLOAT, input_shape)
+    top_out = helper.make_tensor_value_info("top_out", TensorProto.FLOAT, input_shape)
+
+    num_of_params = 8
+    value_info = []
+    for i in range(num_of_params):
+        value_info += [
+            helper.make_tensor_value_info("p" + str(i), TensorProto.FLOAT, input_shape)
+        ]
+
+    add_1_to_move = helper.make_node("Add", ["top_in", "p0"], ["fork1"])
+    mul_1_to_move = helper.make_node("Mul", ["t5", "p4"], ["fork2"])
+    add_2_to_move = helper.make_node("Add", ["fork2", "p5"], ["t6"])
+    mul_1_not_to_move = helper.make_node("Mul", ["t8", "p7"], ["fork3"])
+    modelproto = helper.make_model(
+        helper.make_graph(
+            name="test",
+            inputs=[top_in],
+            outputs=[top_out],
+            value_info=value_info,
+            nodes=[
+                # fork1
+                add_1_to_move,
+                helper.make_node("Mul", ["fork1", "p1"], ["t2"]),
+                helper.make_node("Mul", ["fork1", "p2"], ["t3"]),
+                helper.make_node("Add", ["t2", "t3"], ["t4"]),
+                helper.make_node("Add", ["t4", "p3"], ["t5"]),
+                # fork2
+                mul_1_to_move,
+                add_2_to_move,
+                helper.make_node("Add", ["fork2", "p6"], ["t7"]),
+                helper.make_node("Add", ["t6", "t7"], ["t8"]),
+                # empty branches: do nothing
+                mul_1_not_to_move,
+                helper.make_node("Add", ["fork3", "fork3"], ["top_out"]),
+            ],
+        )
+    )
+    model = ModelWrapper(modelproto)
+    model = model.transform(InferShapes())
+
+    np.random.seed(0)
+    for i in range(num_of_params):
+        model.set_initializer(
+            "p" + str(i), np.random.rand(*input_shape).astype(np.float32)
+        )
+
+    # Transform
+    new_model = model.transform(MoveLinearPastFork())
+    inp_dict = {"top_in": np.random.rand(*input_shape).astype(np.float32)}
+
+    # Test
+    assert oxe.compare_execution(model, new_model, inp_dict)
+    assert not new_model.is_fork_node(add_1_to_move)
+    assert not new_model.is_fork_node(mul_1_to_move)
+    assert not new_model.is_fork_node(add_2_to_move)
+    assert new_model.is_fork_node(mul_1_not_to_move)
+    assert len(new_model.graph.node) == 14
-- 
GitLab