diff --git a/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py b/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py index db9fc05bd694375ad81854eb5cd16cd0af9f3262..7408b119cb6f694ba5dcd056e25a1ec49764f5df 100644 --- a/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py +++ b/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py @@ -987,24 +987,17 @@ class StreamingFCLayer_Batch(HLSCustomOp): self.code_gen_dict["$LAYER_NAME$"] = [ "{}_{}".format(self.onnx_node.name, self.onnx_node.name) ] - # make instream width a multiple of 8 for axi interface - in_width = self.get_instream_width() - if in_width % 8 != 0: - in_width = math.floor(in_width / 8) + 8 + # make instream width a multiple of 8 for AXI stream interface + in_width = roundup_to_integer_multiple(self.get_instream_width(), 8) self.code_gen_dict["$IN_RANGE$"] = ["[{}:0]".format(in_width - 1)] self.code_gen_dict["$OUT_RANGE$"] = [ "[{}:0]".format(self.get_outstream_width() - 1) ] - # make weight stream width a multiple of 8 for axi interface - weight_width = self.get_weightstream_width() - if weight_width % 8 != 0: - weight_width = math.floor(weight_width / 8) + 8 + # make weight stream width a multiple of 8 for AXI stream interface + weight_width = roundup_to_integer_multiple(self.get_weightstream_width(), 8) self.code_gen_dict["$WEIGHT_RANGE$"] = ["[{}:0]".format(weight_width - 1)] self.code_gen_dict["$WEIGHT_WIDTH$"] = [str(weight_width)] - mw = self.get_nodeattr("MW") - mh = self.get_nodeattr("MH") - depth = int(mw * mh) - self.code_gen_dict["$WEIGHT_DEPTH$"] = [str(depth)] + self.code_gen_dict["$WSTREAM_DEPTH$"] = [str(self.calc_wmem())] self.code_gen_dict["$MEM_DEPTH$"] = [ str(roundup_to_integer_multiple(self.calc_wmem(), 1024)) ] diff --git a/src/finn/custom_op/fpgadataflow/templates.py b/src/finn/custom_op/fpgadataflow/templates.py index 3e5205d9e8fc1abd5938f8a3dc4df489f81b9eb7..f6114ea028817b4248b37572067ed80de8364712 100644 --- a/src/finn/custom_op/fpgadataflow/templates.py +++ b/src/finn/custom_op/fpgadataflow/templates.py @@ -196,7 +196,7 @@ memstream .STRM5_WIDTH($WEIGHT_WIDTH$), //depths per stream -.STRM0_DEPTH($WEIGHT_DEPTH$), +.STRM0_DEPTH($WSTREAM_DEPTH$), .STRM1_DEPTH(1), .STRM2_DEPTH(1), .STRM3_DEPTH(1),