diff --git a/src/finn/custom_op/fpgadataflow/concat.py b/src/finn/custom_op/fpgadataflow/concat.py index 482ec4eeb0fdd018628e6a4ceee2962a157e7014..431c06af14d6cc112070a4425b1a2d823a2e8f48 100644 --- a/src/finn/custom_op/fpgadataflow/concat.py +++ b/src/finn/custom_op/fpgadataflow/concat.py @@ -31,6 +31,7 @@ import os from finn.core.datatype import DataType from finn.custom_op.fpgadataflow.hlscustomop import HLSCustomOp +from finn.util.basic import roundup_to_integer_multiple from finn.util.data_packing import npy_to_rtlsim_input, rtlsim_output_to_npy @@ -61,7 +62,7 @@ class StreamingConcat(HLSCustomOp): def get_total_elems(self): elems_per_stream = self.get_nodeattr("ElemsPerStream") - return np.sum(elems_per_stream) + return int(np.sum(elems_per_stream)) def get_normal_input_shape(self, ind=0): elems_per_stream = self.get_nodeattr("ElemsPerStream") @@ -346,6 +347,10 @@ class StreamingConcat(HLSCustomOp): "#pragma HLS INTERFACE ap_ctrl_none port=return" ) + def get_instream_width_padded(self, ind=0): + in_width = self.get_instream_width(ind) + return roundup_to_integer_multiple(in_width, 8) + def get_verilog_top_module_intf_names(self): intf_names = super().get_verilog_top_module_intf_names() n_inputs = self.get_n_inputs()