From ca8b97d682e89ce9829952497249e924182dcdc7 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <yamanu@xilinx.com>
Date: Thu, 28 Oct 2021 22:06:53 +0200
Subject: [PATCH] [Concat] shape and stitching bugfixes

---
 src/finn/custom_op/fpgadataflow/concat.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/src/finn/custom_op/fpgadataflow/concat.py b/src/finn/custom_op/fpgadataflow/concat.py
index 482ec4eeb..431c06af1 100644
--- a/src/finn/custom_op/fpgadataflow/concat.py
+++ b/src/finn/custom_op/fpgadataflow/concat.py
@@ -31,6 +31,7 @@ import os
 
 from finn.core.datatype import DataType
 from finn.custom_op.fpgadataflow.hlscustomop import HLSCustomOp
+from finn.util.basic import roundup_to_integer_multiple
 from finn.util.data_packing import npy_to_rtlsim_input, rtlsim_output_to_npy
 
 
@@ -61,7 +62,7 @@ class StreamingConcat(HLSCustomOp):
 
     def get_total_elems(self):
         elems_per_stream = self.get_nodeattr("ElemsPerStream")
-        return np.sum(elems_per_stream)
+        return int(np.sum(elems_per_stream))
 
     def get_normal_input_shape(self, ind=0):
         elems_per_stream = self.get_nodeattr("ElemsPerStream")
@@ -346,6 +347,10 @@ class StreamingConcat(HLSCustomOp):
             "#pragma HLS INTERFACE ap_ctrl_none port=return"
         )
 
+    def get_instream_width_padded(self, ind=0):
+        in_width = self.get_instream_width(ind)
+        return roundup_to_integer_multiple(in_width, 8)
+
     def get_verilog_top_module_intf_names(self):
         intf_names = super().get_verilog_top_module_intf_names()
         n_inputs = self.get_n_inputs()
-- 
GitLab