diff --git a/src/finn/custom_op/quantavgpool2d.py b/src/finn/custom_op/quantavgpool2d.py
index 4841ce77baca5bd8c69c8acf6b1b102d17fb1d69..075d807c0a7686d452ba57140e1fec2115954e01 100644
--- a/src/finn/custom_op/quantavgpool2d.py
+++ b/src/finn/custom_op/quantavgpool2d.py
@@ -59,21 +59,17 @@ class QuantAvgPool2d(CustomOp):
             outputs=[outp],
         )
         model_avgpool = helper.make_model(graph_avgpool)
-        idict = {node.input[0]: context[node.input[0]]}
+        idict = {node.input[0]: np.round(context[node.input[0]])}
         sess = rt.InferenceSession(model_avgpool.SerializeToString())
         result_temp = sess.run(None, idict)
         # remove scaling introduced by average
         result_temp = result_temp[0] * (k * k)
-        scale = context[node.input[1]]
-        result_temp = np.round(result_temp / scale) * scale
         ibits = self.get_nodeattr("ibits")
         max_value = 2 ** ibits - 1
         max_value = max_value * k * k
         max_bit_width = int(max_value).bit_length()
         shift_bits = max_bit_width - self.get_nodeattr("obits")
-        trunc_scale = 2.0 ** shift_bits
-        output_scale = trunc_scale * scale
-        result = np.floor(result_temp / output_scale) * scale
+        result = np.right_shift(result_temp.astype(int), shift_bits)
         context[node.output[0]] = result.astype(np.float32)
 
     def verify_node(self):