From bf2b6291194f713b01e241407f0061c66fcd0395 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <yamanu@xilinx.com>
Date: Fri, 26 Nov 2021 10:47:15 +0100
Subject: [PATCH] [HLSCustomOp] fix param mismatches between hlslib and
 generated code

---
 .../custom_op/fpgadataflow/channelwise_op_batch.py    | 11 ++++-------
 .../custom_op/fpgadataflow/streamingmaxpool_batch.py  |  5 ++---
 2 files changed, 6 insertions(+), 10 deletions(-)

diff --git a/src/finn/custom_op/fpgadataflow/channelwise_op_batch.py b/src/finn/custom_op/fpgadataflow/channelwise_op_batch.py
index 4961f6148..b40389808 100644
--- a/src/finn/custom_op/fpgadataflow/channelwise_op_batch.py
+++ b/src/finn/custom_op/fpgadataflow/channelwise_op_batch.py
@@ -514,18 +514,15 @@ class ChannelwiseOp_Batch(HLSCustomOp):
         # should ImgDim be defined or just filled in here like we do now?
         ishape = self.get_folded_input_shape()
         if len(ishape) == 3:
-            imgdim_h = 1
-            imgdim_w = 1
+            spatial_dim = 1
         elif len(ishape) == 5:
-            imgdim_h = ishape[1]
-            imgdim_w = ishape[2]
+            spatial_dim = ishape[1] * ishape[2]
         else:
             raise Exception("""Unexpeted input shape""")
         self.code_gen_dict["$DOCOMPUTE$"] = [
-            """Thresholding_Batch<{}, {}, NumChannels1, PE1, {}, {}>
+            """Thresholding_Batch<{}, NumChannels1, PE1, {}, {}>
             (in0, out, threshs, numReps);""".format(
-                imgdim_h,
-                imgdim_w,
+                spatial_dim,
                 tmpl_args["TSrcI"],
                 tmpl_args["TDstI"],
             )
diff --git a/src/finn/custom_op/fpgadataflow/streamingmaxpool_batch.py b/src/finn/custom_op/fpgadataflow/streamingmaxpool_batch.py
index 6012cc7cd..1e66a5c20 100644
--- a/src/finn/custom_op/fpgadataflow/streamingmaxpool_batch.py
+++ b/src/finn/custom_op/fpgadataflow/streamingmaxpool_batch.py
@@ -228,15 +228,14 @@ class StreamingMaxPool_Batch(HLSCustomOp):
             ]
         else:
             if self.is_1d():
-                # FIXME handle this for vitis_hls hlslib branch
                 op = "StreamingMaxPool_Precision_Batch_1d"
             else:
-                op = "StreamingMaxPool_Precision"
+                op = "StreamingMaxPool_Precision_Batch"
             dtype = self.get_input_datatype()
             dtype_hls = dtype.get_hls_datatype_str()
             minval_str = str(int(dtype.min()))
             self.code_gen_dict["$DOCOMPUTE$"] = [
-                "%s<ImgDim, PoolDim, NumChannels, %s, %s>(in0, out);"
+                "%s<ImgDim, PoolDim, NumChannels, %s, %s>(in0, out, numReps);"
                 % (op, dtype_hls, minval_str)
             ]
 
-- 
GitLab