From df2d15e2210683e993e9b1ce34480f55de85b9fb Mon Sep 17 00:00:00 2001 From: auphelia <jakobapk@web.de> Date: Thu, 11 Jun 2020 17:05:00 +0100 Subject: [PATCH] [CustomOp] Update execute node of QuantAvgPool2d --- src/finn/custom_op/quantavgpool2d.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/finn/custom_op/quantavgpool2d.py b/src/finn/custom_op/quantavgpool2d.py index 4841ce77b..075d807c0 100644 --- a/src/finn/custom_op/quantavgpool2d.py +++ b/src/finn/custom_op/quantavgpool2d.py @@ -59,21 +59,17 @@ class QuantAvgPool2d(CustomOp): outputs=[outp], ) model_avgpool = helper.make_model(graph_avgpool) - idict = {node.input[0]: context[node.input[0]]} + idict = {node.input[0]: np.round(context[node.input[0]])} sess = rt.InferenceSession(model_avgpool.SerializeToString()) result_temp = sess.run(None, idict) # remove scaling introduced by average result_temp = result_temp[0] * (k * k) - scale = context[node.input[1]] - result_temp = np.round(result_temp / scale) * scale ibits = self.get_nodeattr("ibits") max_value = 2 ** ibits - 1 max_value = max_value * k * k max_bit_width = int(max_value).bit_length() shift_bits = max_bit_width - self.get_nodeattr("obits") - trunc_scale = 2.0 ** shift_bits - output_scale = trunc_scale * scale - result = np.floor(result_temp / output_scale) * scale + result = np.right_shift(result_temp.astype(int), shift_bits) context[node.output[0]] = result.astype(np.float32) def verify_node(self): -- GitLab