diff --git a/src/finn/custom_op/quantavgpool2d.py b/src/finn/custom_op/quantavgpool2d.py index 5f918821fa848a5a432319ba905e07a6e99d4635..44e7d52a3a1af373e06058fef55f7aa3aba51e0f 100644 --- a/src/finn/custom_op/quantavgpool2d.py +++ b/src/finn/custom_op/quantavgpool2d.py @@ -81,12 +81,17 @@ class QuantAvgPool2d(CustomOp): sess = rt.InferenceSession(model_avgpool.SerializeToString()) result_temp = sess.run(None, idict) # remove scaling introduced by average - result_temp = (result_temp[0] * (k * k)).astype(int) - max_value = np.max(result_temp) + result_temp = result_temp[0] * (k * k) + scale = context[node.input[1]] + result_temp = np.round(result_temp / scale) * scale + ibits = self.get_nodeattr("ibits") + max_value = 2 ** ibits - 1 + max_value = max_value * k * k max_bit_width = int(max_value).bit_length() - shift_bits = max_bit_width - self.get_nodeattr("obits") + 1 - shift_array = np.ones(result_temp.shape, dtype=np.int) * shift_bits - result = np.right_shift(result_temp, shift_array) + shift_bits = max_bit_width - self.get_nodeattr("obits") + trunc_scale = 2.0 ** shift_bits + output_scale = trunc_scale * scale + result = np.floor(result_temp / output_scale) * scale context[node.output[0]] = result.astype(np.float32) def verify_node(self):