From 991c1e3e60b9b6b98a317241da6440fd943104fb Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Mon, 4 May 2020 23:41:01 +0100
Subject: [PATCH] [HLSMaxPool] add dummy time mux dimension to StreamingMaxPool

---
 .../fpgadataflow/streamingmaxpool_batch.py       | 16 ++++++++++++----
 1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/src/finn/custom_op/fpgadataflow/streamingmaxpool_batch.py b/src/finn/custom_op/fpgadataflow/streamingmaxpool_batch.py
index 83bc19030..7334c913b 100644
--- a/src/finn/custom_op/fpgadataflow/streamingmaxpool_batch.py
+++ b/src/finn/custom_op/fpgadataflow/streamingmaxpool_batch.py
@@ -65,7 +65,12 @@ class StreamingMaxPool_Batch(HLSCustomOp):
         return ishape
 
     def get_folded_input_shape(self):
-        return self.get_normal_input_shape()
+        # even though there is no folding in the current hlslib op,
+        # insert a time multiplexing axis to remain compatible with the
+        # shapes produced by the rest of the dataflow pipeline
+        ret = list(self.get_normal_input_shape())
+        ret.insert(-1, 1)
+        return tuple(ret)
 
     def get_normal_output_shape(self):
         k = self.get_nodeattr("PoolDim")
@@ -79,9 +84,12 @@ class StreamingMaxPool_Batch(HLSCustomOp):
         return oshape
 
     def get_folded_output_shape(self):
-        # no folding for StreamingMaxPool
-        oshape = self.get_normal_output_shape()
-        return oshape
+        # even though there is no folding in the current hlslib op,
+        # insert a time multiplexing axis to remain compatible with the
+        # shapes produced by the rest of the dataflow pipeline
+        ret = list(self.get_normal_output_shape())
+        ret.insert(-1, 1)
+        return tuple(ret)
 
     def get_number_output_values(self):
         folded_oshape = self.get_folded_output_shape()
-- 
GitLab