Skip to content
Snippets Groups Projects
Commit c02b80a8 authored by Tobi-Alonso's avatar Tobi-Alonso
Browse files

[HLSCustomOp] Add support for QuantAvgPool in Pool_Batch

parent 255af213
No related branches found
No related tags found
No related merge requests found
......@@ -39,16 +39,18 @@ class Pool_Batch(HLSCustomOp):
"""Class that corresponds to finn-hlslib Pool_batch function.
Requires ConvolutionInputGenerator(depthwise == 1) to format its input
TODO: explain input shape (to reuse im2col code)
Input shape (BatchSize,OutImgDim,OutImgDim,KernelSize^2*Channels)
Output shape (BatchSize,OutImgDim,OutImgDim,Channels)
# note: the actual data layout produced by the hlslib kernels is different
# for depthwise ops.
# * depthwise SWG: (1, OFMDim, OFMDim, IFMChannels/PE, K, K, PE)
Notes:
# The input shape was chosen to be compatible with im2col (only true when there
is not folding).
# The actual data layout produced by the hlslib kernels is different
for depthwise ops.
* depthwise SWG: (1, OFMDim, OFMDim, IFMChannels/PE, K, K, PE)
Channels can be folded using PE (SIMD from the input perspective)
TODO: doc
"""
def get_nodeattr_types(self):
......@@ -63,7 +65,10 @@ class Pool_Batch(HLSCustomOp):
"Function": ("s", True, ""),
"OutImgDim": ("i", True, 0),
# FINN DataTypes for inputs/outputs
"dataType": ("s", True, ""),
"InputDataType": ("s", True, ""),
"OutputDataType": ("s", True, ""),
"AccumBits": ("i", False, 0),
"Size": ("i", False, 1),
"BatchSize": ("i", False, 1),
}
......@@ -72,17 +77,28 @@ class Pool_Batch(HLSCustomOp):
def get_input_datatype(self):
"""Returns FINN DataType of input."""
return DataType[self.get_nodeattr("dataType")]
return DataType[self.get_nodeattr("InputDataType")]
def get_output_datatype(self):
"""Returns FINN DataType of output."""
fxn = self.get_nodeattr("Function")
odt = DataType[self.get_nodeattr("OutputDataType")]
if fxn == "MaxPool":
# Same as input
return DataType[self.get_nodeattr("dataType")]
idt = DataType[self.get_nodeattr("InputDataType")]
assert odt == idt, "In datatype must be equal to out datatype for Maxpool"
elif fxn == "QuantAvgPool":
idt = DataType[self.get_nodeattr("InputDataType")]
assert (
idt.signed() == odt.signed()
), """QuantAvgPool: Can't mix signed
and unsigned datatypes"""
else:
raise Exception("Pool_Batch doesn't currently support " + fxn)
return odt
def get_normal_input_shape(self):
ifm_ch = self.get_nodeattr("Channels")
odim = self.get_nodeattr("OutImgDim")
......@@ -123,19 +139,14 @@ class Pool_Batch(HLSCustomOp):
def get_instream_width(self):
dt_bits = self.get_input_datatype().bitwidth()
pe = self.get_nodeattr("PE")
# ofm_ch = self.get_nodeattr("Channels")
# k = self.get_nodeattr("KernelSize")
# assert ifm_ch % pe == 0, "PE must divide input channels"
# simd = int(ifm_ch/pe)
in_width = int(dt_bits * pe)
return in_width
def get_outstream_width(self):
fxn = self.get_nodeattr("Function")
if fxn == "MaxPool":
return self.get_instream_width()
else:
raise Exception("Pool_Batch doesn't currently support " + fxn)
dt_bits = self.get_output_datatype().bitwidth()
pe = self.get_nodeattr("PE")
out_width = int(dt_bits * pe)
return out_width
def make_shape_compatible_op(self, model):
exp_ishape = self.get_normal_input_shape()
......@@ -187,7 +198,7 @@ class Pool_Batch(HLSCustomOp):
# check supported function
fnx = self.get_nodeattr("Function")
if fnx == "MaxPool":
if fnx in ["MaxPool", "QuantAvgPool"]:
info_messages.append(
"Attribute Function contains a supported pool function"
)
......@@ -251,7 +262,8 @@ class Pool_Batch(HLSCustomOp):
i_hls_dt = idt.get_hls_datatype_str()
odt = self.get_output_datatype()
o_hls_dt = odt.get_hls_datatype_str()
size = self.get_nodeattr("Size")
accum_bits = self.get_nodeattr("AccumBits")
self.code_gen_dict["$DOCOMPUTE$"] = []
fxn = self.get_nodeattr("Function")
......@@ -259,6 +271,16 @@ class Pool_Batch(HLSCustomOp):
self.code_gen_dict["$DOCOMPUTE$"] += [
"MaxPoolFunction<{},KernelSize> pool_fxn;".format(i_hls_dt)
]
elif fxn == "QuantAvgPool":
if idt.signed():
act_hls_dt = "ap_int<{}>".format(accum_bits)
else:
act_hls_dt = "ap_uint<{}>".format(accum_bits)
self.code_gen_dict["$DOCOMPUTE$"] += [
"QuantAvgPoolFunction<{},{},{}> pool_fxn;".format(
act_hls_dt, o_hls_dt, size
)
]
else:
raise Exception("Pool_Batch doesn't currently support " + fxn)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment