diff --git a/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py b/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py index ce21ad38c842bf967e96c06ad39525d4b7690297..be9b51e6a7b1b3e255cd2ee8baf10937b95f8665 100644 --- a/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py +++ b/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py @@ -208,10 +208,21 @@ class StreamingFCLayer_Batch(HLSCustomOp): o_bits = self.get_output_datatype().bitwidth() return o_bits * self.get_nodeattr("PE") - def get_number_output_values(self): + def get_folded_input_shape(self): + mw = self.get_nodeattr("MW") + simd = self.get_nodeattr("SIMD") + sf = mw // simd + return (1, sf, simd) + + def get_folded_output_shape(self): mh = self.get_nodeattr("MH") pe = self.get_nodeattr("PE") - return mh // pe + nf = mh // pe + return (1, nf, pe) + + def get_number_output_values(self): + nf = self.get_folded_output_shape()[1] + return nf def get_template_param_values(self): ret = dict()