diff --git a/src/finn/custom_op/fpgadataflow/fmpadding_rtl.py b/src/finn/custom_op/fpgadataflow/fmpadding_rtl.py index df6f40ab6f5be6bce79f22d320f12b4625db1efe..d0302540bd55a427bfe8ab56de7adc8c9d37cb5f 100644 --- a/src/finn/custom_op/fpgadataflow/fmpadding_rtl.py +++ b/src/finn/custom_op/fpgadataflow/fmpadding_rtl.py @@ -101,20 +101,20 @@ class FMPadding_rtl(HLSCustomOp): exp_cycles = (channels / simd) * batch_size * odim_h * odim_w return int(exp_cycles) - def get_normal_input_shape(self): + def get_normal_input_shape(self, ind=0): idim_h, idim_w = self.get_nodeattr("ImgDim") num_ch = self.get_nodeattr("NumChannels") ishape = (1, idim_h, idim_w, num_ch) return ishape - def get_normal_output_shape(self): + def get_normal_output_shape(self, ind=0): odim_h, odim_w = self.get_padded_odim() num_ch = self.get_nodeattr("NumChannels") oshape = (1, odim_h, odim_w, num_ch) return oshape - def get_folded_input_shape(self): + def get_folded_input_shape(self, ind=0): normal_ishape = list(self.get_normal_input_shape()) ifm_ch = self.get_nodeattr("NumChannels") simd = self.get_nodeattr("SIMD") @@ -123,7 +123,7 @@ class FMPadding_rtl(HLSCustomOp): folded_ishape = normal_ishape[:-1] + [fold, simd] return tuple(folded_ishape) - def get_folded_output_shape(self): + def get_folded_output_shape(self, ind=0): normal_oshape = list(self.get_normal_output_shape()) ifm_ch = self.get_nodeattr("NumChannels") simd = self.get_nodeattr("SIMD") @@ -155,7 +155,7 @@ class FMPadding_rtl(HLSCustomOp): def verify_node(self): pass - def get_input_datatype(self): + def get_input_datatype(self, ind=0): """Returns FINN DataType of input.""" ret = DataType[self.get_nodeattr("inputDataType")] # the hlslib op always pads with zeros, so ensure that the DataType @@ -163,16 +163,16 @@ class FMPadding_rtl(HLSCustomOp): assert ret.allowed(0), "FMPadding_Batch DataType must support zero" return ret - def get_output_datatype(self): + def get_output_datatype(self, ind=0): """Returns FINN DataType of output. (Same as input datatype)""" return self.get_input_datatype() - def get_instream_width(self): + def get_instream_width(self, ind=0): ibits = self.get_input_datatype().bitwidth() simd = self.get_nodeattr("SIMD") return ibits * simd - def get_outstream_width(self): + def get_outstream_width(self, ind=0): obits = self.get_output_datatype().bitwidth() simd = self.get_nodeattr("SIMD") return obits * simd