Commit 7dbfaae9 authored by Tobi-Alonso

[FPGADataflow] Add support for QuantAvgPool2d in InferPool_Batch. Change restriction from k <= stride to k < stride
parent c02b80a8
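For context, a minimal sketch of how this conversion pass is typically driven from a FINN flow. Only the transformation names come from this diff; the ModelWrapper import path, the to_hls alias, and the "model.onnx" file name are assumptions based on the FINN layout at the time of this commit.

from finn.core.modelwrapper import ModelWrapper
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.infer_datatypes import InferDataTypes
import finn.transformation.fpgadataflow.convert_to_hls_layers as to_hls

# Placeholder model file; shape and datatype annotations must be present
# before the conversion pass can pick up the pooling nodes.
model = ModelWrapper("model.onnx")
model = model.transform(InferShapes())
model = model.transform(InferDataTypes())
# MaxPool and (after this commit) QuantAvgPool2d nodes with k >= stride are
# converted into Pool_Batch HLS layers plus the surrounding layout transposes.
model = model.transform(to_hls.InferPool_Batch())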
@@ -35,6 +35,7 @@ from finn.transformation.infer_shapes import InferShapes
 from finn.transformation.infer_datatypes import InferDataTypes
 import finn.core.data_layout as DataLayout
 from finn.util.basic import get_by_name
+import warnings
 class InferConvInpGen(Transformation):
@@ -187,13 +188,26 @@ class InferPool_Batch(Transformation):
         graph_modified = False
         for n in graph.node:
             node_ind += 1
-            if n.op_type in ["MaxPool"]:
+            if n.op_type in ["MaxPool", "QuantAvgPool2d"]:
                 # extract pool parameters
-                k = get_by_name(n.attribute, "kernel_shape").ints[-1]
-                stride = get_by_name(n.attribute, "strides").ints[-1]
-                if k <= stride:
+                if n.op_type == "MaxPool":
+                    k = get_by_name(n.attribute, "kernel_shape").ints[-1]
+                    stride = get_by_name(n.attribute, "strides").ints[-1]
+                elif n.op_type == "QuantAvgPool2d":
+                    inst = getCustomOp(n)
+                    k = inst.get_nodeattr("kernel")
+                    stride = inst.get_nodeattr("stride")
+                if k < stride:
                     continue
+                elif k == stride:
+                    warnings.warn(
+                        """Inferring Pool_Batch node for k == stride.
+                        This case can be optimized.
+                        For example, for MaxPool run InferStreamingMaxPool before
+                        InferPool_Batch """
+                    )
                 try:
                     pad = get_by_name(n.attribute, "pads").ints[-1]
@@ -203,10 +217,16 @@
                 node_input = n.input[0]
                 node_output = n.output[0]
                 idt = model.get_tensor_datatype(node_input)
                 if not idt.is_integer():
                     continue
                 # odt = model.get_tensor_datatype(node_output)
+                # if idt.signed() and n.op_type == "MaxPool":
+                #     # No support for signed input (see accu initialization
+                #     # in Pool_batch HLSLIB function)
+                #     continue
+                odt = model.get_tensor_datatype(node_output)
                 ifm_ch = model.get_tensor_shape(n.input[0])[1]  # assume NCHW
                 ofm_ch = ifm_ch
@@ -246,9 +266,22 @@
                     "Transpose", [node_input], [inp_trans_out], perm=[0, 2, 3, 1]
                 )
+                accum_bits = 0
+                pool_size_param = k
+                pad_value = 0
                 if n.op_type == "MaxPool":
                     pool_fxn = "MaxPool"
+                    odt = idt
                     pad_value = idt.min()
+                elif n.op_type == "QuantAvgPool2d":
+                    assert odt.is_integer(), """Output data type for QuantAvgPool2d
+                    needs to be integer"""
+                    assert pad == 0, "Padding is not supported for QuantAvgPool2d"
+                    inst = getCustomOp(n)
+                    pool_fxn = "QuantAvgPool"
+                    pool_size_param = inst.get_shifts()
+                    accum_bits = inst.get_accum_size()
                 else:
                     raise Exception(
                         "pad_value and pool_fxn not configured for {}".format(n.op_type)
@@ -278,12 +311,15 @@
                     [pool_output],
                     domain="finn",
                     backend="fpgadataflow",
-                    dataType=idt.name,
+                    InputDataType=idt.name,
+                    OutputDataType=odt.name,
                     Channels=ifm_ch,
                     PE=ifm_ch,
                     KernelSize=k,
                     Function=pool_fxn,
                     OutImgDim=ofm_dim,
+                    AccumBits=accum_bits,
+                    Size=pool_size_param,
                     BatchSize=1,
                 )
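The new attributes above replace the single dataType field and carry the QuantAvgPool2d parameters (shift amount and accumulator width) into the HLS layer. A small inspection sketch, not part of the commit: it assumes a `model` that has already been through InferPool_Batch and uses the same getCustomOp registry helper the transformation itself relies on.

from finn.custom_op.registry import getCustomOp

# assumes `model` has already been transformed by InferPool_Batch
for node in model.graph.node:
    if node.op_type != "Pool_Batch":
        continue
    inst = getCustomOp(node)
    print(
        inst.get_nodeattr("Function"),        # "MaxPool" or "QuantAvgPool"
        inst.get_nodeattr("Size"),            # k for MaxPool, shift amount for QuantAvgPool
        inst.get_nodeattr("AccumBits"),       # accumulator width, 0 for MaxPool
        inst.get_nodeattr("InputDataType"),   # replaces the old dataType attribute
        inst.get_nodeattr("OutputDataType"),  # equals InputDataType for MaxPool
    )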
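The warning added for k == stride points at a preferred ordering of the conversion passes. A sketch of that ordering, assuming the InferStreamingMaxPool transformation from the same convert_to_hls_layers module; the wrapper function name is hypothetical.

import finn.transformation.fpgadataflow.convert_to_hls_layers as to_hls

def convert_pooling_layers(model):
    # MaxPool nodes with k == stride become StreamingMaxPool_Batch layers here,
    # so InferPool_Batch never hits its k == stride warning for them.
    model = model.transform(to_hls.InferStreamingMaxPool())
    # Remaining MaxPool / QuantAvgPool2d nodes (k >= stride) become Pool_Batch.
    model = model.transform(to_hls.InferPool_Batch())
    return model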