Skip to content
Snippets Groups Projects
Commit b291eb06 authored by auphelia's avatar auphelia
Browse files

[CustomOp] Integrate NHWC functionality into QUantAvgPool2d node

parent d1c7a151
No related branches found
No related tags found
No related merge requests found
......@@ -4,6 +4,7 @@ import onnxruntime as rt
from finn.custom_op import CustomOp
from finn.core.datatype import DataType
from finn.custom_op.maxpoolnhwc import compute_pool_output_dim
class QuantAvgPool2d(CustomOp):
......@@ -17,18 +18,44 @@ class QuantAvgPool2d(CustomOp):
"ibits": ("i", True, 1),
"obits": ("i", True, 1),
"signed": ("i", True, 0),
"data_layout": ("s", False, "NCHW"),
}
def make_shape_compatible_op(self, model):
node = self.onnx_node
iname = node.input[0]
ishape = model.get_tensor_shape(iname)
k = self.get_nodeattr("kernel")
s = self.get_nodeattr("stride")
data_layout = self.get_nodeattr("data_layout")
if data_layout == "NCHW":
(n, c, hi, wi) = ishape
ho = compute_pool_output_dim(hi, k, s)
wo = compute_pool_output_dim(wi, k, s)
oshape = (n, c, ho, wo)
elif data_layout == "NHWC":
(n, hi, wi, c) = ishape
ho = compute_pool_output_dim(hi, k, s)
wo = compute_pool_output_dim(wi, k, s)
oshape = (n, ho, wo, c)
else:
raise Exception(
"""Datalayout for QuantAvgPool2d is set to an unvalid value.
Has to be set to "NCHW" or "NHWC"."""
)
# implement tensor with correct shape
values = np.random.randn(*oshape).astype(np.float32)
return helper.make_node(
"AveragePool",
inputs=[node.input[0]],
"Constant",
inputs=[],
outputs=[node.output[0]],
kernel_shape=[k, k],
strides=[s, s],
value=helper.make_tensor(
name="const_tensor",
data_type=TensorProto.FLOAT,
dims=values.shape,
vals=values.flatten().astype(float),
),
)
def infer_node_datatype(self, model):
......@@ -48,8 +75,12 @@ class QuantAvgPool2d(CustomOp):
node = self.onnx_node
k = self.get_nodeattr("kernel")
s = self.get_nodeattr("stride")
ishape = context[node.input[0]].shape
inp_values = context[node.input[0]]
oshape = context[node.output[0]].shape
if self.get_nodeattr("data_layout") == "NHWC":
inp_values = inp_values.transpose(0, 3, 1, 2)
oshape = (context[node.output[0]]).transpose(0, 3, 1, 2).shape
ishape = inp_values.shape
inp = helper.make_tensor_value_info(node.input[0], TensorProto.FLOAT, ishape)
outp = helper.make_tensor_value_info(node.output[0], TensorProto.FLOAT, oshape)
node_avgpool = helper.make_node(
......@@ -66,7 +97,7 @@ class QuantAvgPool2d(CustomOp):
outputs=[outp],
)
model_avgpool = helper.make_model(graph_avgpool)
idict = {node.input[0]: context[node.input[0]]}
idict = {node.input[0]: inp_values}
sess = rt.InferenceSession(model_avgpool.SerializeToString())
result_temp = sess.run(None, idict)
# remove scaling introduced by average
......@@ -77,7 +108,16 @@ class QuantAvgPool2d(CustomOp):
max_bit_width = int(max_value).bit_length()
shift_bits = max_bit_width - self.get_nodeattr("obits")
result = np.right_shift(result_temp.astype(int), shift_bits)
if self.get_nodeattr("data_layout") == "NHWC":
result = result.transpose(0, 2, 3, 1)
context[node.output[0]] = result.astype(np.float32)
def verify_node(self):
pass
info_messages = []
# verify that "domain" is set to "finn"
domain_value = self.onnx_node.domain
if domain_value == "finn":
info_messages.append("Attribute domain is set correctly")
else:
info_messages.append('Attribute domain should be set to "finn"')
return info_messages
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment