Skip to content
Snippets Groups Projects
Commit df2d15e2 authored by auphelia's avatar auphelia
Browse files

[CustomOp] Update execute node of QuantAvgPool2d

parent b8f3d225
No related branches found
No related tags found
No related merge requests found
......@@ -59,21 +59,17 @@ class QuantAvgPool2d(CustomOp):
outputs=[outp],
)
model_avgpool = helper.make_model(graph_avgpool)
idict = {node.input[0]: context[node.input[0]]}
idict = {node.input[0]: np.round(context[node.input[0]])}
sess = rt.InferenceSession(model_avgpool.SerializeToString())
result_temp = sess.run(None, idict)
# remove scaling introduced by average
result_temp = result_temp[0] * (k * k)
scale = context[node.input[1]]
result_temp = np.round(result_temp / scale) * scale
ibits = self.get_nodeattr("ibits")
max_value = 2 ** ibits - 1
max_value = max_value * k * k
max_bit_width = int(max_value).bit_length()
shift_bits = max_bit_width - self.get_nodeattr("obits")
trunc_scale = 2.0 ** shift_bits
output_scale = trunc_scale * scale
result = np.floor(result_temp / output_scale) * scale
result = np.right_shift(result_temp.astype(int), shift_bits)
context[node.output[0]] = result.astype(np.float32)
def verify_node(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment