From 7efbb088f4673f98cda8e6a1239b367dd041c0dd Mon Sep 17 00:00:00 2001 From: auphelia <jakobapk@web.de> Date: Fri, 5 Jun 2020 16:09:06 +0100 Subject: [PATCH] [CustomOp] Change calculation for execution slightly in QuantAvgPool2d --- src/finn/custom_op/quantavgpool2d.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/finn/custom_op/quantavgpool2d.py b/src/finn/custom_op/quantavgpool2d.py index b490e48b7..5f918821f 100644 --- a/src/finn/custom_op/quantavgpool2d.py +++ b/src/finn/custom_op/quantavgpool2d.py @@ -14,9 +14,9 @@ class QuantAvgPool2d(CustomOp): return { "stride": ("i", True, 1), "kernel": ("i", True, 1), - "ibits": ("s", True, ""), - "obits": ("i", False, 0), - "signed": ("i", False, 0), + "ibits": ("i", True, 1), + "obits": ("i", True, 1), + "signed": ("i", True, 0), } def make_shape_compatible_op(self, model): @@ -84,11 +84,10 @@ class QuantAvgPool2d(CustomOp): result_temp = (result_temp[0] * (k * k)).astype(int) max_value = np.max(result_temp) max_bit_width = int(max_value).bit_length() - shift_bits = max_bit_width - self.get_nodeattr("obits") + shift_bits = max_bit_width - self.get_nodeattr("obits") + 1 shift_array = np.ones(result_temp.shape, dtype=np.int) * shift_bits result = np.right_shift(result_temp, shift_array) - - context[node.output[0]] = result + context[node.output[0]] = result.astype(np.float32) def verify_node(self): pass -- GitLab