diff --git a/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py b/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py
index ce21ad38c842bf967e96c06ad39525d4b7690297..be9b51e6a7b1b3e255cd2ee8baf10937b95f8665 100644
--- a/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py
+++ b/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py
@@ -208,10 +208,21 @@ class StreamingFCLayer_Batch(HLSCustomOp):
         o_bits = self.get_output_datatype().bitwidth()
         return o_bits * self.get_nodeattr("PE")
 
-    def get_number_output_values(self):
+    def get_folded_input_shape(self):
+        mw = self.get_nodeattr("MW")
+        simd = self.get_nodeattr("SIMD")
+        sf = mw // simd
+        return (1, sf, simd)
+
+    def get_folded_output_shape(self):
         mh = self.get_nodeattr("MH")
         pe = self.get_nodeattr("PE")
-        return mh // pe
+        nf = mh // pe
+        return (1, nf, pe)
+
+    def get_number_output_values(self):
+        nf = self.get_folded_output_shape()[1]
+        return nf
 
     def get_template_param_values(self):
         ret = dict()