From 898f2294ae2bc2beadc00a2f9ca1fdb3e29c4d7a Mon Sep 17 00:00:00 2001
From: mmrahorovic <mmrahorovic@hotmail.com>
Date: Sun, 20 Feb 2022 22:42:32 +0000
Subject: [PATCH] [custom_op]: MaxPool support for ceil mode

---
 .../fpgadataflow/streamingmaxpool_batch.py       | 16 ++++++++++++----
 1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/src/finn/custom_op/fpgadataflow/streamingmaxpool_batch.py b/src/finn/custom_op/fpgadataflow/streamingmaxpool_batch.py
index daa8319cd..d2e1406d2 100755
--- a/src/finn/custom_op/fpgadataflow/streamingmaxpool_batch.py
+++ b/src/finn/custom_op/fpgadataflow/streamingmaxpool_batch.py
@@ -32,6 +32,7 @@ import warnings
 
 from finn.core.datatype import DataType
 from finn.custom_op.fpgadataflow.hlscustomop import HLSCustomOp
+from finn.custom_op.general.maxpoolnhwc import compute_pool_output_dim
 from finn.util.data_packing import npy_to_rtlsim_input, rtlsim_output_to_npy
 
 
@@ -44,6 +45,7 @@ class StreamingMaxPool_Batch(HLSCustomOp):
             "PoolDim": ("ints", True, []),  # [H, W] = [Y, X]
             "NumChannels": ("i", True, 0),
             "PE": ("i", True, 0),
+            "CeilMode": ("i", False, 0),
             # FINN DataTypes for inputs/outputs
             "dataType": ("s", True, ""),
         }
@@ -96,6 +98,7 @@ class StreamingMaxPool_Batch(HLSCustomOp):
         ifm_dim_h, ifm_dim_w = self.get_nodeattr("ImgDim")
         k_h, k_w = tuple(self.get_nodeattr("PoolDim"))
         ifm_ch = self.get_nodeattr("NumChannels")
+        ceil_mode = self.get_nodeattr("CeilMode")
         if not self.is_1d():
             assert (
                 ifm_dim_h % k_h == 0
@@ -103,8 +106,8 @@ class StreamingMaxPool_Batch(HLSCustomOp):
             assert (
                 ifm_dim_w % k_w == 0
             ), "StreamingMaxPool needs ImgDim_w % PoolDim_w == 0"
-        ofm_dim_h = int(np.floor(ifm_dim_h / k_w))
-        ofm_dim_w = int(np.floor(ifm_dim_w / k_w))
+        ofm_dim_h = compute_pool_output_dim(ifm_dim_h, k_h, k_h, 0, ceil_mode)
+        ofm_dim_w = compute_pool_output_dim(ifm_dim_w, k_w, k_w, 0, ceil_mode)
         oshape = (1, ofm_dim_h, ofm_dim_w, ifm_ch)
         return oshape
 
@@ -197,15 +200,19 @@ class StreamingMaxPool_Batch(HLSCustomOp):
     def defines(self, var):
         numReps = 1
         ifm_dim, k, ifm_ch = self.get_1d_attrs_normalized()
+        ceil_mode = self.get_nodeattr("CeilMode")
+        output_size = compute_pool_output_dim(ifm_dim[1], k[1], k[1], 0, ceil_mode)
 
         if self.is_1d():
             self.code_gen_dict["$DEFINES$"] = [
                 """#define ImgDim {}\n #define PoolDim {}\n
-                #define NumChannels {}\n #define PE {}\n #define numReps {}""".format(
+                #define NumChannels {}\n #define PE {}\n #define OutputSize {}
+                \n #define numReps {}""".format(
                     ifm_dim[1],
                     k[1],
                     self.get_nodeattr("NumChannels"),
                     self.get_nodeattr("PE"),
+                    output_size,
                     numReps,
                 )
             ]
@@ -264,7 +271,8 @@ class StreamingMaxPool_Batch(HLSCustomOp):
             if self.is_1d():
                 op = "StreamingMaxPool_Precision_1d"
                 self.code_gen_dict["$DOCOMPUTE$"] = [
-                    "%s<ImgDim, PoolDim, NumChannels, PE, %s, %s>(in0, out);"
+                    """%s<ImgDim, PoolDim, NumChannels, PE,
+                     OutputSize, %s, %s>(in0, out);"""
                     % (op, dtype_hls, minval_str)
                 ]
             else:
-- 
GitLab