From f17d5f8d1b335d36f38eaa64f14d9364c35fe84d Mon Sep 17 00:00:00 2001 From: auphelia <jakobapk@web.de> Date: Fri, 5 Jun 2020 15:44:59 +0100 Subject: [PATCH] [CustomOp] Add execute_node to QuantAvgPool2d custom op --- src/finn/custom_op/quantavgpool2d.py | 37 +++++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/src/finn/custom_op/quantavgpool2d.py b/src/finn/custom_op/quantavgpool2d.py index 359b11e7c..4d2c44b4e 100644 --- a/src/finn/custom_op/quantavgpool2d.py +++ b/src/finn/custom_op/quantavgpool2d.py @@ -1,8 +1,10 @@ import numpy as np from onnx import TensorProto, helper +import onnxruntime as rt from finn.custom_op import CustomOp from finn.custom_op.im2col import compute_conv_output_dim +from finn.core.datatype import DataType class QuantAvgPool2d(CustomOp): @@ -54,7 +56,40 @@ class QuantAvgPool2d(CustomOp): model.set_tensor_datatype(node.output[0], dtype) def execute_node(self, context, graph): - pass + # create a standard average pooling node to help calculate the result + node = self.onnx_node + k = self.get_nodeattr("kernel") + s = self.get_nodeattr("stride") + ishape = context[node.input[0]].shape + oshape = context[node.output[0]].shape + inp = helper.make_tensor_value_info(node.input[0], TensorProto.FLOAT, ishape) + outp = helper.make_tensor_value_info(node.output[0], TensorProto.FLOAT, oshape) + node_avgpool = helper.make_node( + "AveragePool", + inputs=[node.input[0]], + outputs=[node.output[0]], + kernel_shape=[k, k], + strides=[s, s] + ) + graph_avgpool = helper.make_graph( + nodes=[node_avgpool], + name="single-avgpool-exec", + inputs=[inp], + outputs=[outp], + ) + model_avgpool = helper.make_model(graph_avgpool) + idict = {node.input[0] : context[node.input[0]]} + sess = rt.InferenceSession(model_avgpool.SerializeToString()) + result_temp = sess.run(None, idict) + # remove scaling introduced by average + result_temp = (result_temp[0] * (k * k)).astype(int) + max_value = np.max(result_temp) + max_bit_width = int(max_value).bit_length() + shift_bits = max_bit_width - self.get_nodeattr("obits") + shift_array = np.ones(result_temp.shape, dtype=np.int) * shift_bits + result = np.right_shift(result_temp, shift_array) + + context[node.output[0]] = result def verify_node(self): pass -- GitLab