diff --git a/src/finn/custom_op/fpgadataflow/streamingmaxpool_batch.py b/src/finn/custom_op/fpgadataflow/streamingmaxpool_batch.py index 83bc19030ebba66907e08c5b1e52d7c0ff9207a6..7334c913b6f85cad4835b6e65eb14c488432af6b 100644 --- a/src/finn/custom_op/fpgadataflow/streamingmaxpool_batch.py +++ b/src/finn/custom_op/fpgadataflow/streamingmaxpool_batch.py @@ -65,7 +65,12 @@ class StreamingMaxPool_Batch(HLSCustomOp): return ishape def get_folded_input_shape(self): - return self.get_normal_input_shape() + # even though there is no folding in the current hlslib op, + # insert a time multiplexing axis to remain compatible with the + # shapes produced by the rest of the dataflow pipeline + ret = list(self.get_normal_input_shape()) + ret.insert(-1, 1) + return tuple(ret) def get_normal_output_shape(self): k = self.get_nodeattr("PoolDim") @@ -79,9 +84,12 @@ class StreamingMaxPool_Batch(HLSCustomOp): return oshape def get_folded_output_shape(self): - # no folding for StreamingMaxPool - oshape = self.get_normal_output_shape() - return oshape + # even though there is no folding in the current hlslib op, + # insert a time multiplexing axis to remain compatible with the + # shapes produced by the rest of the dataflow pipeline + ret = list(self.get_normal_output_shape()) + ret.insert(-1, 1) + return tuple(ret) def get_number_output_values(self): folded_oshape = self.get_folded_output_shape()