Skip to content
Snippets Groups Projects
Commit b291eb06 authored by auphelia's avatar auphelia
Browse files

[CustomOp] Integrate NHWC functionality into QUantAvgPool2d node

parent d1c7a151
No related branches found
No related tags found
No related merge requests found
...@@ -4,6 +4,7 @@ import onnxruntime as rt ...@@ -4,6 +4,7 @@ import onnxruntime as rt
from finn.custom_op import CustomOp from finn.custom_op import CustomOp
from finn.core.datatype import DataType from finn.core.datatype import DataType
from finn.custom_op.maxpoolnhwc import compute_pool_output_dim
class QuantAvgPool2d(CustomOp): class QuantAvgPool2d(CustomOp):
...@@ -17,18 +18,44 @@ class QuantAvgPool2d(CustomOp): ...@@ -17,18 +18,44 @@ class QuantAvgPool2d(CustomOp):
"ibits": ("i", True, 1), "ibits": ("i", True, 1),
"obits": ("i", True, 1), "obits": ("i", True, 1),
"signed": ("i", True, 0), "signed": ("i", True, 0),
"data_layout": ("s", False, "NCHW"),
} }
def make_shape_compatible_op(self, model): def make_shape_compatible_op(self, model):
node = self.onnx_node node = self.onnx_node
iname = node.input[0]
ishape = model.get_tensor_shape(iname)
k = self.get_nodeattr("kernel") k = self.get_nodeattr("kernel")
s = self.get_nodeattr("stride") s = self.get_nodeattr("stride")
data_layout = self.get_nodeattr("data_layout")
if data_layout == "NCHW":
(n, c, hi, wi) = ishape
ho = compute_pool_output_dim(hi, k, s)
wo = compute_pool_output_dim(wi, k, s)
oshape = (n, c, ho, wo)
elif data_layout == "NHWC":
(n, hi, wi, c) = ishape
ho = compute_pool_output_dim(hi, k, s)
wo = compute_pool_output_dim(wi, k, s)
oshape = (n, ho, wo, c)
else:
raise Exception(
"""Datalayout for QuantAvgPool2d is set to an unvalid value.
Has to be set to "NCHW" or "NHWC"."""
)
# implement tensor with correct shape
values = np.random.randn(*oshape).astype(np.float32)
return helper.make_node( return helper.make_node(
"AveragePool", "Constant",
inputs=[node.input[0]], inputs=[],
outputs=[node.output[0]], outputs=[node.output[0]],
kernel_shape=[k, k], value=helper.make_tensor(
strides=[s, s], name="const_tensor",
data_type=TensorProto.FLOAT,
dims=values.shape,
vals=values.flatten().astype(float),
),
) )
def infer_node_datatype(self, model): def infer_node_datatype(self, model):
...@@ -48,8 +75,12 @@ class QuantAvgPool2d(CustomOp): ...@@ -48,8 +75,12 @@ class QuantAvgPool2d(CustomOp):
node = self.onnx_node node = self.onnx_node
k = self.get_nodeattr("kernel") k = self.get_nodeattr("kernel")
s = self.get_nodeattr("stride") s = self.get_nodeattr("stride")
ishape = context[node.input[0]].shape inp_values = context[node.input[0]]
oshape = context[node.output[0]].shape oshape = context[node.output[0]].shape
if self.get_nodeattr("data_layout") == "NHWC":
inp_values = inp_values.transpose(0, 3, 1, 2)
oshape = (context[node.output[0]]).transpose(0, 3, 1, 2).shape
ishape = inp_values.shape
inp = helper.make_tensor_value_info(node.input[0], TensorProto.FLOAT, ishape) inp = helper.make_tensor_value_info(node.input[0], TensorProto.FLOAT, ishape)
outp = helper.make_tensor_value_info(node.output[0], TensorProto.FLOAT, oshape) outp = helper.make_tensor_value_info(node.output[0], TensorProto.FLOAT, oshape)
node_avgpool = helper.make_node( node_avgpool = helper.make_node(
...@@ -66,7 +97,7 @@ class QuantAvgPool2d(CustomOp): ...@@ -66,7 +97,7 @@ class QuantAvgPool2d(CustomOp):
outputs=[outp], outputs=[outp],
) )
model_avgpool = helper.make_model(graph_avgpool) model_avgpool = helper.make_model(graph_avgpool)
idict = {node.input[0]: context[node.input[0]]} idict = {node.input[0]: inp_values}
sess = rt.InferenceSession(model_avgpool.SerializeToString()) sess = rt.InferenceSession(model_avgpool.SerializeToString())
result_temp = sess.run(None, idict) result_temp = sess.run(None, idict)
# remove scaling introduced by average # remove scaling introduced by average
...@@ -77,7 +108,16 @@ class QuantAvgPool2d(CustomOp): ...@@ -77,7 +108,16 @@ class QuantAvgPool2d(CustomOp):
max_bit_width = int(max_value).bit_length() max_bit_width = int(max_value).bit_length()
shift_bits = max_bit_width - self.get_nodeattr("obits") shift_bits = max_bit_width - self.get_nodeattr("obits")
result = np.right_shift(result_temp.astype(int), shift_bits) result = np.right_shift(result_temp.astype(int), shift_bits)
if self.get_nodeattr("data_layout") == "NHWC":
result = result.transpose(0, 2, 3, 1)
context[node.output[0]] = result.astype(np.float32) context[node.output[0]] = result.astype(np.float32)
def verify_node(self): def verify_node(self):
pass info_messages = []
# verify that "domain" is set to "finn"
domain_value = self.onnx_node.domain
if domain_value == "finn":
info_messages.append("Attribute domain is set correctly")
else:
info_messages.append('Attribute domain should be set to "finn"')
return info_messages
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment