From a81fc1e6e970bc27734f6495a0b86f3bdfa11159 Mon Sep 17 00:00:00 2001
From: Tobi-Alonso <tobi.alonso@gmail.com>
Date: Wed, 1 Jul 2020 09:43:59 +0100
Subject: [PATCH] [Test] Add test cases for QuantAvgPool2d to
 test_convert_to_hls_pool_batch

---
 .../test_convert_to_hls_pool_batch.py         | 71 +++++++++++++++++--
 1 file changed, 65 insertions(+), 6 deletions(-)

diff --git a/tests/fpgadataflow/test_convert_to_hls_pool_batch.py b/tests/fpgadataflow/test_convert_to_hls_pool_batch.py
index c9f78dcea..809147be2 100644
--- a/tests/fpgadataflow/test_convert_to_hls_pool_batch.py
+++ b/tests/fpgadataflow/test_convert_to_hls_pool_batch.py
@@ -77,27 +77,70 @@ def make_single_maxpool_modelwrapper(k, stride, pad, ifm_ch, ifm_dim, ofm_dim, i
     return model
 
 
+def make_single_quantavpool_modelwrapper(k, stride, ifm_ch, ifm_dim, ofm_dim, idt, odt):
+    inp = helper.make_tensor_value_info(
+        "inp", TensorProto.FLOAT, [1, ifm_ch, ifm_dim, ifm_dim]
+    )
+    outp = helper.make_tensor_value_info(
+        "outp", TensorProto.FLOAT, [1, ifm_ch, ofm_dim, ofm_dim]
+    )
+
+    mp_node = helper.make_node(
+        "QuantAvgPool2d",
+        ["inp"],
+        ["outp"],
+        domain="finn",
+        stride=stride,
+        kernel=k,
+        ibits=idt.bitwidth(),
+        obits=odt.bitwidth(),
+        signed=1 if idt.signed() else 0,
+        data_layout="NCHW",
+    )
+    graph = helper.make_graph(
+        nodes=[mp_node], name="mp_graph", inputs=[inp], outputs=[outp]
+    )
+
+    model = helper.make_model(graph, producer_name="mp-model")
+    model = ModelWrapper(model)
+
+    model.set_tensor_datatype("inp", idt)
+    model.set_tensor_datatype("outp", odt)
+    model = model.transform(InferShapes())
+
+    return model
+
+
 def prepare_inputs(input_tensor):
     return {"inp": input_tensor}
 
 
+# Note: QuantAvgPool2d with idt = DataType.UINT4 and odt = DataType.UINT8
+# (And in general seems to be a problem when odt.bitwidth() > idt.bitwidth())
+# passes cppsim but fails rtlsim(Verilator). Cosim with same parameters in
+# Vivado_HLS passes.
+
 # input datatype
 @pytest.mark.parametrize("idt", [DataType.UINT4, DataType.INT4])
+# output datatype
+@pytest.mark.parametrize("odt", [DataType.UINT4, DataType.INT4])
 # pool configuration:                   ( k,stride, pad, ifm_dim )
 @pytest.mark.parametrize(
-    "pool_config", [(3, 2, 0, 5), (3, 2, 1, 5), (2, 2, 0, 8), (5, 2, 2, 7)]
+    "pool_config", [(7, 7, 0, 7), (3, 2, 0, 5), (3, 2, 1, 5), (2, 2, 0, 8)]
 )
 # input channels
 @pytest.mark.parametrize("ifm_ch", [1, 4, 20])
 # number of out channel computed in parallel
-@pytest.mark.parametrize("pe", [1, 4, 20])
+@pytest.mark.parametrize("pe", [1, 4])
+# pool type
+@pytest.mark.parametrize("op_type", ["QuantAvgPool2d", "MaxPool"])
 # execution mode
 @pytest.mark.parametrize("exec_mode", ["cppsim", "rtlsim"])
-# pool type
-@pytest.mark.parametrize("op_type", ["MaxPool"])
 @pytest.mark.slow
 @pytest.mark.vivado
-def test_convert_to_hls_pool_batch(idt, pool_config, ifm_ch, pe, exec_mode, op_type):
+def test_convert_to_hls_pool_batch(
+    idt, odt, pool_config, ifm_ch, pe, op_type, exec_mode
+):
     k, stride, pad, ifm_dim = pool_config
 
     if ifm_ch % pe != 0:
@@ -113,9 +156,25 @@ def test_convert_to_hls_pool_batch(idt, pool_config, ifm_ch, pe, exec_mode, op_t
     # prepare input data
     input_dict = prepare_inputs(x)
     if op_type == "MaxPool":
+        # if idt.signed():
+        #     pytest.skip("""No support for signed input (see accu initialization
+        #         in Pool_batch HLSLIB function). Skipping""")
+
+        if idt != odt:
+            pytest.skip("Skipping Maxpool with idt != odt")
+
         model = make_single_maxpool_modelwrapper(
             k, stride, pad, ifm_ch, ifm_dim, ofm_dim, idt
         )
+    elif op_type == "QuantAvgPool2d":
+        if pad != 0:
+            pytest.skip("No padding support for QuantAvgPool2d. Skipping")
+
+        if idt.signed() != odt.signed():
+            pytest.skip("Skipping QuantAvgPool2d with idt.signed() != odt.signed()")
+        model = make_single_quantavpool_modelwrapper(
+            k, stride, ifm_ch, ifm_dim, ofm_dim, idt, odt
+        )
     else:
         assert False, "{} is not a supported op_type".format(op_type)
 
@@ -151,7 +210,7 @@ def test_convert_to_hls_pool_batch(idt, pool_config, ifm_ch, pe, exec_mode, op_t
     # execute new_model
     y_produced = oxe.execute_onnx(new_model, input_dict)["outp"]
     assert (y_produced == y_expected).all()
-    if stride != k:
+    if stride <= k:
         if pad == 0 or ifm_ch == pe:
             assert len(new_model.graph.node) == 4
         else:
-- 
GitLab