Skip to content
Snippets Groups Projects
Commit eb2652af authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Pool] change Pool_Batch to use kernel size as list, not scalar

parent 18d17bfd
No related branches found
No related tags found
No related merge requests found
......@@ -38,7 +38,7 @@ class Pool_Batch(HLSCustomOp):
"""Class that corresponds to finn-hlslib Pool_batch function.
Requires ConvolutionInputGenerator(depthwise == 1) to format its input
Input shape (BatchSize,OutImgDim,OutImgDim,KernelSize^2*Channels)
Input shape (BatchSize,OutImgDim,OutImgDim,TotalKernelSize*Channels)
Output shape (BatchSize,OutImgDim,OutImgDim,Channels)
Notes:
......@@ -56,7 +56,7 @@ class Pool_Batch(HLSCustomOp):
my_attrs = {
"Channels": ("i", True, 0),
"PE": ("i", True, 1),
"KernelSize": ("i", True, 0),
"KernelSize": ("ints", True, []),
# Function:
# - MaxPool
# - QuantAvgPool
......@@ -103,7 +103,8 @@ class Pool_Batch(HLSCustomOp):
odim = self.get_nodeattr("OutImgDim")
batch_size = self.get_nodeattr("BatchSize")
k = self.get_nodeattr("KernelSize")
ishape = (batch_size, odim, odim, k * k * ifm_ch)
k_prod = int(np.prod(k))
ishape = (batch_size, odim, odim, k_prod * ifm_ch)
return ishape
def get_folded_input_shape(self):
......@@ -140,9 +141,10 @@ class Pool_Batch(HLSCustomOp):
ifm_ch = self.get_nodeattr("Channels")
pe = self.get_nodeattr("PE")
k = self.get_nodeattr("KernelSize")
k_prod = int(np.prod(k))
odim = self.get_nodeattr("OutImgDim")
batch_size = self.get_nodeattr("BatchSize")
exp_cycles = ((ifm_ch * k * k) / pe) * odim * odim * batch_size
exp_cycles = ((ifm_ch * k_prod) / pe) * odim * odim * batch_size
return int(exp_cycles)
def get_instream_width(self):
......@@ -211,7 +213,8 @@ class Pool_Batch(HLSCustomOp):
self.code_gen_dict["$DEFINES$"] += ["#define PE {}".format(pe)]
k = self.get_nodeattr("KernelSize")
self.code_gen_dict["$DEFINES$"] += ["#define KernelSize {}".format(k * k)]
k_prod = int(np.prod(k))
self.code_gen_dict["$DEFINES$"] += ["#define KernelSize {}".format(k_prod)]
odim = self.get_nodeattr("OutImgDim")
self.code_gen_dict["$DEFINES$"] += ["#define OFMDim {}".format(odim)]
......
......@@ -534,7 +534,7 @@ class InferPool_Batch(Transformation):
OutputDataType=odt.name,
Channels=ifm_ch,
PE=ifm_ch,
KernelSize=k,
KernelSize=[k, k],
Function=pool_fxn,
OutImgDim=ofm_dim,
AccumBits=accum_bits,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment