From 7efbb088f4673f98cda8e6a1239b367dd041c0dd Mon Sep 17 00:00:00 2001
From: auphelia <jakobapk@web.de>
Date: Fri, 5 Jun 2020 16:09:06 +0100
Subject: [PATCH] [CustomOp] Change calculation for execution slightly in
 QuantAvgPool2d

---
 src/finn/custom_op/quantavgpool2d.py | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/src/finn/custom_op/quantavgpool2d.py b/src/finn/custom_op/quantavgpool2d.py
index b490e48b7..5f918821f 100644
--- a/src/finn/custom_op/quantavgpool2d.py
+++ b/src/finn/custom_op/quantavgpool2d.py
@@ -14,9 +14,9 @@ class QuantAvgPool2d(CustomOp):
         return {
             "stride": ("i", True, 1),
             "kernel": ("i", True, 1),
-            "ibits": ("s", True, ""),
-            "obits": ("i", False, 0),
-            "signed": ("i", False, 0),
+            "ibits": ("i", True, 1),
+            "obits": ("i", True, 1),
+            "signed": ("i", True, 0),
         }
 
     def make_shape_compatible_op(self, model):
@@ -84,11 +84,10 @@ class QuantAvgPool2d(CustomOp):
         result_temp = (result_temp[0] * (k * k)).astype(int)
         max_value = np.max(result_temp)
         max_bit_width = int(max_value).bit_length()
-        shift_bits = max_bit_width - self.get_nodeattr("obits")
+        shift_bits = max_bit_width - self.get_nodeattr("obits") + 1
         shift_array = np.ones(result_temp.shape, dtype=np.int) * shift_bits
         result = np.right_shift(result_temp, shift_array)
-
-        context[node.output[0]] = result
+        context[node.output[0]] = result.astype(np.float32)
 
     def verify_node(self):
         pass
-- 
GitLab