From b6a5b35e3cc21c143ba57a2584f660f82413852f Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <yamanu@xilinx.com>
Date: Tue, 23 Nov 2021 14:39:54 +0100
Subject: [PATCH] [Test] support creating 1d maxpool in
 convert_to_hls_pool_batch

---
 .../test_convert_to_hls_pool_batch.py         | 31 ++++++++++++-------
 1 file changed, 20 insertions(+), 11 deletions(-)

diff --git a/tests/fpgadataflow/test_convert_to_hls_pool_batch.py b/tests/fpgadataflow/test_convert_to_hls_pool_batch.py
index 3efafc040..cea6a3189 100644
--- a/tests/fpgadataflow/test_convert_to_hls_pool_batch.py
+++ b/tests/fpgadataflow/test_convert_to_hls_pool_batch.py
@@ -48,22 +48,31 @@ from finn.transformation.infer_shapes import InferShapes
 from finn.util.basic import gen_finn_dt_tensor
 
 
-def make_single_maxpool_modelwrapper(k, stride, pad, ifm_ch, ifm_dim, ofm_dim, idt):
+def make_single_maxpool_modelwrapper(
+    k, stride, pad, ifm_ch, ifm_dim, ofm_dim, idt, use_1d=False
+):
     odt = idt
-    inp = helper.make_tensor_value_info(
-        "inp", TensorProto.FLOAT, [1, ifm_ch, ifm_dim, ifm_dim]
-    )
-    outp = helper.make_tensor_value_info(
-        "outp", TensorProto.FLOAT, [1, ifm_ch, ofm_dim, ofm_dim]
-    )
-
+    if use_1d:
+        ishape = [1, ifm_ch, 1, ifm_dim]
+        oshape = [1, ifm_ch, 1, ofm_dim]
+        kshape = [1, k]
+        pads = [0, pad, 0, pad]
+        strides = [1, stride]
+    else:
+        ishape = [1, ifm_ch, ifm_dim, ifm_dim]
+        oshape = [1, ifm_ch, ofm_dim, ofm_dim]
+        kshape = [1, k]
+        pads = [pad, pad, pad, pad]
+        strides = [stride, stride]
+    inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, ishape)
+    outp = helper.make_tensor_value_info("outp", TensorProto.FLOAT, oshape)
     mp_node = helper.make_node(
         "MaxPool",
         ["inp"],
         ["outp"],
-        kernel_shape=[k, k],
-        pads=[pad, pad, pad, pad],
-        strides=[stride, stride],
+        kernel_shape=kshape,
+        pads=pads,
+        strides=strides,
     )
     graph = helper.make_graph(
         nodes=[mp_node], name="mp_graph", inputs=[inp], outputs=[outp]
-- 
GitLab