Skip to content
Snippets Groups Projects
Commit 6ea013d5 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[StreamingFC] fix STREAM_DEPTH for decoupled mode

parent 8e22e2c9
No related branches found
No related tags found
No related merge requests found
......@@ -987,24 +987,17 @@ class StreamingFCLayer_Batch(HLSCustomOp):
self.code_gen_dict["$LAYER_NAME$"] = [
"{}_{}".format(self.onnx_node.name, self.onnx_node.name)
]
# make instream width a multiple of 8 for axi interface
in_width = self.get_instream_width()
if in_width % 8 != 0:
in_width = math.floor(in_width / 8) + 8
# make instream width a multiple of 8 for AXI stream interface
in_width = roundup_to_integer_multiple(self.get_instream_width(), 8)
self.code_gen_dict["$IN_RANGE$"] = ["[{}:0]".format(in_width - 1)]
self.code_gen_dict["$OUT_RANGE$"] = [
"[{}:0]".format(self.get_outstream_width() - 1)
]
# make weight stream width a multiple of 8 for axi interface
weight_width = self.get_weightstream_width()
if weight_width % 8 != 0:
weight_width = math.floor(weight_width / 8) + 8
# make weight stream width a multiple of 8 for AXI stream interface
weight_width = roundup_to_integer_multiple(self.get_weightstream_width(), 8)
self.code_gen_dict["$WEIGHT_RANGE$"] = ["[{}:0]".format(weight_width - 1)]
self.code_gen_dict["$WEIGHT_WIDTH$"] = [str(weight_width)]
mw = self.get_nodeattr("MW")
mh = self.get_nodeattr("MH")
depth = int(mw * mh)
self.code_gen_dict["$WEIGHT_DEPTH$"] = [str(depth)]
self.code_gen_dict["$WSTREAM_DEPTH$"] = [str(self.calc_wmem())]
self.code_gen_dict["$MEM_DEPTH$"] = [
str(roundup_to_integer_multiple(self.calc_wmem(), 1024))
]
......
......@@ -196,7 +196,7 @@ memstream
.STRM5_WIDTH($WEIGHT_WIDTH$),
//depths per stream
.STRM0_DEPTH($WEIGHT_DEPTH$),
.STRM0_DEPTH($WSTREAM_DEPTH$),
.STRM1_DEPTH(1),
.STRM2_DEPTH(1),
.STRM3_DEPTH(1),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment