From a7f4b4a0f06d55079d6ebeb8b29230a6e6ecc79c Mon Sep 17 00:00:00 2001
From: auphelia <jakobapk@web.de>
Date: Tue, 9 Jun 2020 13:55:25 +0100
Subject: [PATCH] [CustomOp] Add node execution functionality to QuantAvgPool2d

---
 src/finn/custom_op/quantavgpool2d.py | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/src/finn/custom_op/quantavgpool2d.py b/src/finn/custom_op/quantavgpool2d.py
index 5f918821f..44e7d52a3 100644
--- a/src/finn/custom_op/quantavgpool2d.py
+++ b/src/finn/custom_op/quantavgpool2d.py
@@ -81,12 +81,17 @@ class QuantAvgPool2d(CustomOp):
         sess = rt.InferenceSession(model_avgpool.SerializeToString())
         result_temp = sess.run(None, idict)
         # remove scaling introduced by average
-        result_temp = (result_temp[0] * (k * k)).astype(int)
-        max_value = np.max(result_temp)
+        result_temp = result_temp[0] * (k * k)
+        scale = context[node.input[1]]
+        result_temp = np.round(result_temp / scale) * scale
+        ibits = self.get_nodeattr("ibits")
+        max_value = 2 ** ibits - 1
+        max_value = max_value * k * k
         max_bit_width = int(max_value).bit_length()
-        shift_bits = max_bit_width - self.get_nodeattr("obits") + 1
-        shift_array = np.ones(result_temp.shape, dtype=np.int) * shift_bits
-        result = np.right_shift(result_temp, shift_array)
+        shift_bits = max_bit_width - self.get_nodeattr("obits")
+        trunc_scale = 2.0 ** shift_bits
+        output_scale = trunc_scale * scale
+        result = np.floor(result_temp / output_scale) * scale
         context[node.output[0]] = result.astype(np.float32)
 
     def verify_node(self):
-- 
GitLab