diff --git a/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py b/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py index c440b3675ca4dcb06329a25b532bd1211db29b87..6aa26af453b61aac1784cbb5f094a9cc6cc608fd 100644 --- a/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py +++ b/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py @@ -353,12 +353,12 @@ class MatrixVectorActivation(HLSCustomOp): acc_datatype = self.get_accumulator_datatype() # if accDataType is not set, then it will default to INT32, which would # be a large overestimate in most (if not all) cases. In this scenario, - # we would use the minimum accumulator as determined by the data types. + # we would use the minimum accumulator as determined by the data types + # bound, derived in https://arxiv.org/abs/2301.13376 alpha = math.log(MW, 2) + W + A - 1 - int(idt.signed()) - phi = lambda x_: math.log(1 + pow(2, -x_), 2) acc_bits = min( acc_datatype.bitwidth(), - np.ceil(alpha + phi(alpha) + 1) + np.ceil(alpha + math.log(1 + pow(2, -alpha), 2) + 1) ) acc_luts = acc_bits # thresholds and threshold comparators diff --git a/src/finn/custom_op/fpgadataflow/vectorvectoractivation.py b/src/finn/custom_op/fpgadataflow/vectorvectoractivation.py index 377a62f79f19db57ea3d1639d167e5307e85d127..796225a712203c5cb311590e478367835d5b15f5 100644 --- a/src/finn/custom_op/fpgadataflow/vectorvectoractivation.py +++ b/src/finn/custom_op/fpgadataflow/vectorvectoractivation.py @@ -1199,12 +1199,12 @@ class VectorVectorActivation(HLSCustomOp): k_h, k_w = self.get_nodeattr("Kernel") # if accDataType is not set, then it will default to INT32, which would # be a large overestimate in most (if not all) cases. In this scenario, - # we would use the minimum accumulator as determined by the data types. + # we would use the minimum accumulator as determined by the data types + # bound, derived in https://arxiv.org/abs/2301.13376 alpha = math.log(k_h * k_w, 2) + W + A - 1 - int(idt.signed()) - phi = lambda x_: math.log(1 + pow(2, -x_), 2) acc_bits = min( acc_datatype.bitwidth(), - np.ceil(alpha + phi(alpha) + 1) + np.ceil(alpha + math.log(1 + pow(2, -alpha), 2) + 1) ) acc_luts = acc_bits # thresholds and threshold comparators