From 64779fdcc17c14942e504ce22925e4837ff0d1db Mon Sep 17 00:00:00 2001
From: Tobi-Alonso <tobi.alonso@gmail.com>
Date: Mon, 18 May 2020 15:39:58 +0100
Subject: [PATCH] [TEST] Add test to verify that CollapseRepeatedOp only
 operates on linear segments. Add check to previous test to verify that the
 transformation is doing the required job

---
 .../test_collapse_repeated_op.py              | 58 +++++++++++++++++++
 1 file changed, 58 insertions(+)

diff --git a/tests/transformation/test_collapse_repeated_op.py b/tests/transformation/test_collapse_repeated_op.py
index 01d932ece..b74d868f9 100644
--- a/tests/transformation/test_collapse_repeated_op.py
+++ b/tests/transformation/test_collapse_repeated_op.py
@@ -34,6 +34,7 @@ import finn.core.onnx_exec as ox
 from finn.core.modelwrapper import ModelWrapper
 from finn.transformation.infer_shapes import InferShapes
 from finn.transformation.streamline import CollapseRepeatedAdd, CollapseRepeatedMul
+import pytest
 
 
 def test_collapse_repeated_op():
@@ -67,3 +68,60 @@ def test_collapse_repeated_op():
     new_model = new_model.transform(CollapseRepeatedMul())
     inp_dict = {"top_in": np.asarray([-1.0, 1.0], dtype=np.float32)}
     assert ox.compare_execution(model, new_model, inp_dict)
+    assert len(new_model.graph.node) == 2
+    assert new_model.graph.node[0].op_type == "Add"
+    assert new_model.graph.node[1].op_type == "Mul"
+
+
+@pytest.mark.parametrize(
+    "test_args", [("Add", CollapseRepeatedAdd()), ("Mul", CollapseRepeatedMul())],
+)
+def test_collapse_repeated_only_if_linear(test_args):
+    scalar_op = test_args[0]
+    transf_fxn = test_args[1]
+
+    input_shape = [4, 4]
+    output_shape = input_shape
+
+    top_in = oh.make_tensor_value_info("top_in", TensorProto.FLOAT, input_shape)
+    top_out = oh.make_tensor_value_info("top_out", TensorProto.FLOAT, output_shape)
+
+    value_info = [oh.make_tensor_value_info("p1", TensorProto.FLOAT, [1])]
+    value_info += [oh.make_tensor_value_info("p2", TensorProto.FLOAT, [1])]
+    value_info += [oh.make_tensor_value_info("p3", TensorProto.FLOAT, [1])]
+    value_info += [oh.make_tensor_value_info("p4", TensorProto.FLOAT, [1])]
+    value_info += [oh.make_tensor_value_info("p5", TensorProto.FLOAT, [1])]
+
+    modelproto = oh.make_model(
+        oh.make_graph(
+            name="test",
+            inputs=[top_in],
+            outputs=[top_out],
+            value_info=value_info,
+            nodes=[
+                oh.make_node(scalar_op, ["top_in", "p2"], ["t1"]),
+                oh.make_node(scalar_op, ["t1", "p1"], ["t2"]),
+                oh.make_node(scalar_op, ["t2", "p3"], ["t3"]),
+                oh.make_node(scalar_op, ["t2", "p4"], ["t4"]),
+                oh.make_node(scalar_op, ["t3", "t4"], ["t5"]),
+                oh.make_node(scalar_op, ["t5", "p5"], ["top_out"]),
+            ],
+        )
+    )
+    model = ModelWrapper(modelproto)
+    model = model.transform(InferShapes())
+
+    np.random.seed(0)
+    model.set_initializer("p1", *np.random.rand(1).astype(np.float32))
+    model.set_initializer("p2", *np.random.rand(1).astype(np.float32))
+    model.set_initializer("p3", *np.random.rand(1).astype(np.float32))
+    model.set_initializer("p4", *np.random.rand(1).astype(np.float32))
+    model.set_initializer("p5", *np.random.rand(1).astype(np.float32))
+
+    # Transform
+    new_model = model.transform(transf_fxn)
+
+    # Test
+    inp_dict = {"top_in": np.random.rand(*input_shape).astype(np.float32)}
+    assert ox.compare_execution(model, new_model, inp_dict)
+    assert len(new_model.graph.node) == 5
-- 
GitLab