Commit 7b970fa8 authored by mmrahorovic
[custom_op]: corrected maxpool in/out stream width and shapes

parent 69c67c58
@@ -82,9 +82,6 @@ class StreamingMaxPool_Batch(HLSCustomOp):
         return ishape
 
     def get_folded_input_shape(self):
-        # even though there is no folding in the current hlslib op,
-        # insert a time multiplexing axis to remain compatible with the
-        # shapes produced by the rest of the dataflow pipeline
         ifm_dim_h, ifm_dim_w = self.get_nodeattr("ImgDim")
         ifm_ch = self.get_nodeattr("NumChannels")
         pe = self.get_nodeattr("PE")
@@ -92,7 +89,7 @@ class StreamingMaxPool_Batch(HLSCustomOp):
         if self.is_1d():
             folded_ishape = (1, ifm_dim_h, ifm_dim_w, nf, pe)
         else:
-            folded_ishape = (1, ifm_dim_h, ifm_dim_w, ifm_ch, 1)
+            folded_ishape = (1, ifm_dim_h, ifm_dim_w, 1, ifm_ch)
         return folded_ishape
 
     def get_normal_output_shape(self):
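For intuition on the folded input shape fix above, here is a minimal numpy sketch with made-up dimensions (not taken from the commit). The innermost axis of the folded shape is the stream word, so the corrected (1, H, W, 1, C) layout packs all NumChannels elements into one word per pixel, whereas the old (1, H, W, C, 1) implied one element per word:

    import numpy as np

    # Made-up dimensions for illustration only
    ifm_dim_h, ifm_dim_w, ifm_ch = 4, 4, 8

    # NHWC input as produced by the rest of the dataflow pipeline
    inp = np.arange(ifm_dim_h * ifm_dim_w * ifm_ch).reshape(1, ifm_dim_h, ifm_dim_w, ifm_ch)

    old_folded = inp.reshape(1, ifm_dim_h, ifm_dim_w, ifm_ch, 1)  # one element per stream word
    new_folded = inp.reshape(1, ifm_dim_h, ifm_dim_w, 1, ifm_ch)  # all channels in one word

    print(old_folded.shape)  # (1, 4, 4, 8, 1)
    print(new_folded.shape)  # (1, 4, 4, 1, 8)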
@@ -106,14 +103,8 @@ class StreamingMaxPool_Batch(HLSCustomOp):
         assert (
             ifm_dim_w % k_w == 0
         ), "StreamingMaxPool needs ImgDim_w % PoolDim_w == 0"
-        if ifm_dim_h % k_h == 0:
-            ofm_dim_h = int(ifm_dim_h / k_h)
-        else:
-            ofm_dim_h = int(np.ceil(ifm_dim_h / k_h))
-        if ifm_dim_w % k_w == 0:
-            ofm_dim_w = int(ifm_dim_w / k_w)
-        else:
-            ofm_dim_w = int(np.ceil(ifm_dim_w / k_w))
+        ofm_dim_h = int(np.floor(ifm_dim_h / k_h))
+        ofm_dim_w = int(np.floor(ifm_dim_w / k_w))
         oshape = (1, ofm_dim_h, ofm_dim_w, ifm_ch)
         return oshape
 
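A quick sanity check on the output-shape change, with invented values: the asserts above already require the image dimensions to divide evenly by the pool dimensions, so the removed ceil branches were dead code and floor division is exact:

    import numpy as np

    # Hypothetical values: 32x32 image, 2x2 pooling window
    ifm_dim_h = ifm_dim_w = 32
    k_h = k_w = 2

    assert ifm_dim_h % k_h == 0 and ifm_dim_w % k_w == 0
    ofm_dim_h = int(np.floor(ifm_dim_h / k_h))  # 16, identical to ifm_dim_h // k_h
    ofm_dim_w = int(np.floor(ifm_dim_w / k_w))  # 16
    print((1, ofm_dim_h, ofm_dim_w, 8))  # (1, 16, 16, 8) for an 8-channel input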
@@ -129,7 +120,7 @@ class StreamingMaxPool_Batch(HLSCustomOp):
             ret[-1] = nf
             ret.append(pe)
         else:
-            ret.append(1)
+            ret.insert(-1, 1)
         return tuple(ret)
 
     def get_number_output_values(self):
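The folded output shape change mirrors the folded input shape fix: ret.append(1) placed the singleton fold axis last, while ret.insert(-1, 1) places it before the channel axis, so the innermost axis (the stream word) again carries all channels. A sketch with an invented normal output shape:

    # Hypothetical normal output shape: (1, 16, 16, 8)
    old_ret = [1, 16, 16, 8]
    old_ret.append(1)      # -> (1, 16, 16, 8, 1): fold axis last

    new_ret = [1, 16, 16, 8]
    new_ret.insert(-1, 1)  # -> (1, 16, 16, 1, 8): word of 8 channels last
    print(tuple(old_ret), tuple(new_ret))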
@@ -262,24 +253,24 @@ class StreamingMaxPool_Batch(HLSCustomOp):
             if self.is_1d():
                 raise Exception("Binary 1d MaxPool not implemented on HLS backend")
             else:
-                op = "StreamingMaxPool_Batch"
+                op = "StreamingMaxPool"
                 self.code_gen_dict["$DOCOMPUTE$"] = [
-                    "%s<ImgDim, PoolDim, NumChannels>(in0, out, numReps);" % (op)
+                    "%s<ImgDim, PoolDim, NumChannels>(in0, out);" % (op)
                 ]
         else:
             dtype = self.get_input_datatype()
             dtype_hls = dtype.get_hls_datatype_str()
             minval_str = str(int(dtype.min()))
             if self.is_1d():
-                op = "StreamingMaxPool_Precision_Batch_1d"
+                op = "StreamingMaxPool_Precision_1d"
                 self.code_gen_dict["$DOCOMPUTE$"] = [
-                    "%s<ImgDim, PoolDim, NumChannels, PE, %s, %s>(in0, out, numReps);"
+                    "%s<ImgDim, PoolDim, NumChannels, PE, %s, %s>(in0, out);"
                     % (op, dtype_hls, minval_str)
                 ]
             else:
-                op = "StreamingMaxPool_Precision_Batch"
+                op = "StreamingMaxPool_Precision"
                 self.code_gen_dict["$DOCOMPUTE$"] = [
-                    "%s<ImgDim, PoolDim, NumChannels, %s, %s>(in0, out, numReps);"
+                    "%s<ImgDim, PoolDim, NumChannels, %s, %s>(in0, out);"
                     % (op, dtype_hls, minval_str)
                 ]
 
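The generated HLS call thus loses both the _Batch suffix and the trailing numReps argument. A sketch of what the string substitution produces, with invented parameter values (an INT4 input datatype would map to ap_int<4> with minimum value -8):

    # Invented example values for the template parameters
    op = "StreamingMaxPool_Precision"
    dtype_hls = "ap_int<4>"
    minval_str = "-8"

    docompute = "%s<ImgDim, PoolDim, NumChannels, %s, %s>(in0, out);" % (op, dtype_hls, minval_str)
    print(docompute)
    # StreamingMaxPool_Precision<ImgDim, PoolDim, NumChannels, ap_int<4>, -8>(in0, out);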
@@ -365,10 +356,8 @@ class StreamingMaxPool_Batch(HLSCustomOp):
             export_idt = DataType["BINARY"]
         else:
             export_idt = self.get_input_datatype()
-        # reshape input into folded form
-        inp = inp.reshape(folded_ishape)
-        # make copy before saving array
-        reshaped_input = inp.copy()
+        reshaped_input = inp.reshape(folded_ishape)
         np.save(os.path.join(code_gen_dir, "input_0.npy"), reshaped_input)
 
         if mode == "cppsim":
......
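On the last hunk: numpy's reshape returns a separate array object (a view where possible), and np.save serializes the data at call time, so the defensive copy before saving was unnecessary. A minimal self-contained sketch, with invented stand-ins for the real tensor and path:

    import os
    import numpy as np

    folded_ishape = (1, 4, 4, 1, 8)            # made-up folded shape
    inp = np.zeros(np.prod(folded_ishape))     # stand-in for the real input tensor
    code_gen_dir = "/tmp"                      # stand-in for the real codegen dir

    reshaped_input = inp.reshape(folded_ishape)  # no .copy() needed before saving
    np.save(os.path.join(code_gen_dir, "input_0.npy"), reshaped_input)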