From eb2652afa085fc995568ca33095c3909b17fa266 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <yamanu@xilinx.com>
Date: Tue, 23 Nov 2021 14:25:41 +0100
Subject: [PATCH] [Pool] change Pool_Batch to use kernel size as list, not
 scalar

---
 src/finn/custom_op/fpgadataflow/pool_batch.py       | 13 ++++++++-----
 .../fpgadataflow/convert_to_hls_layers.py           |  2 +-
 2 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/src/finn/custom_op/fpgadataflow/pool_batch.py b/src/finn/custom_op/fpgadataflow/pool_batch.py
index 92679e306..bcf25c648 100644
--- a/src/finn/custom_op/fpgadataflow/pool_batch.py
+++ b/src/finn/custom_op/fpgadataflow/pool_batch.py
@@ -38,7 +38,7 @@ class Pool_Batch(HLSCustomOp):
     """Class that corresponds to finn-hlslib Pool_batch function.
     Requires ConvolutionInputGenerator(depthwise == 1) to format its input
 
-    Input shape (BatchSize,OutImgDim,OutImgDim,KernelSize^2*Channels)
+    Input shape (BatchSize,OutImgDim,OutImgDim,TotalKernelSize*Channels)
     Output shape (BatchSize,OutImgDim,OutImgDim,Channels)
 
     Notes:
@@ -56,7 +56,7 @@ class Pool_Batch(HLSCustomOp):
         my_attrs = {
             "Channels": ("i", True, 0),
             "PE": ("i", True, 1),
-            "KernelSize": ("i", True, 0),
+            "KernelSize": ("ints", True, []),
             # Function:
             #  - MaxPool
             #  - QuantAvgPool
@@ -103,7 +103,8 @@ class Pool_Batch(HLSCustomOp):
         odim = self.get_nodeattr("OutImgDim")
         batch_size = self.get_nodeattr("BatchSize")
         k = self.get_nodeattr("KernelSize")
-        ishape = (batch_size, odim, odim, k * k * ifm_ch)
+        k_prod = int(np.prod(k))
+        ishape = (batch_size, odim, odim, k_prod * ifm_ch)
         return ishape
 
     def get_folded_input_shape(self):
@@ -140,9 +141,10 @@ class Pool_Batch(HLSCustomOp):
         ifm_ch = self.get_nodeattr("Channels")
         pe = self.get_nodeattr("PE")
         k = self.get_nodeattr("KernelSize")
+        k_prod = int(np.prod(k))
         odim = self.get_nodeattr("OutImgDim")
         batch_size = self.get_nodeattr("BatchSize")
-        exp_cycles = ((ifm_ch * k * k) / pe) * odim * odim * batch_size
+        exp_cycles = ((ifm_ch * k_prod) / pe) * odim * odim * batch_size
         return int(exp_cycles)
 
     def get_instream_width(self):
@@ -211,7 +213,8 @@ class Pool_Batch(HLSCustomOp):
         self.code_gen_dict["$DEFINES$"] += ["#define PE {}".format(pe)]
 
         k = self.get_nodeattr("KernelSize")
-        self.code_gen_dict["$DEFINES$"] += ["#define KernelSize {}".format(k * k)]
+        k_prod = int(np.prod(k))
+        self.code_gen_dict["$DEFINES$"] += ["#define KernelSize {}".format(k_prod)]
 
         odim = self.get_nodeattr("OutImgDim")
         self.code_gen_dict["$DEFINES$"] += ["#define OFMDim {}".format(odim)]
diff --git a/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py b/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py
index 60ae3fc75..54bd01e2b 100644
--- a/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py
+++ b/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py
@@ -534,7 +534,7 @@ class InferPool_Batch(Transformation):
                     OutputDataType=odt.name,
                     Channels=ifm_ch,
                     PE=ifm_ch,
-                    KernelSize=k,
+                    KernelSize=[k, k],
                     Function=pool_fxn,
                     OutImgDim=ofm_dim,
                     AccumBits=accum_bits,
-- 
GitLab