From aaf03f5738d2daada44851f39e2442db1d44f9a2 Mon Sep 17 00:00:00 2001
From: icolbert <Ian.Colbert@amd.com>
Date: Thu, 1 Dec 2022 08:01:25 -0800
Subject: [PATCH] Updating VVAU LUT estimation

- Using accDataType for the accumulator bit width rather than an estimate
- Updating the estimate equation for the case when accDataType is not specified
- Adding a check that thresholds use LUTRAM (distributed RAM) rather than BRAM
  before counting threshold LUTs
---
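Note on the accumulator estimate: the minimal bit width used below follows the
alpha/phi bound from the patched lut_estimation() code. A minimal standalone
sketch of that bound (the helper name minimal_acc_bits and the example
parameters are illustrative, not part of the patch):

    import math

    def minimal_acc_bits(k_h, k_w, W, A, signed_input):
        # same alpha/phi terms as in the patched lut_estimation()
        alpha = math.log(k_h * k_w, 2) + W + A - 1 - int(signed_input)
        phi = math.log(1 + 2 ** -alpha, 2)
        return math.ceil(alpha + phi + 1)

    # e.g. a 3x3 kernel with 4-bit weights and signed 4-bit inputs needs
    # only 11 accumulator bits, vs. 32 for the INT32 default
    print(minimal_acc_bits(3, 3, 4, 4, True))  # -> 11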
 .../fpgadataflow/vectorvectoractivation.py    | 19 +++++++++++++++++--
 1 file changed, 17 insertions(+), 2 deletions(-)
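Note on the threshold estimate: threshold LUTs are now counted only when the
thresholds land in distributed RAM; this assumes VectorVectorActivation
exposes a ram_style_thresholds node attribute (as MatrixVectorActivation
does). A sketch of the count itself with illustrative parameters, where the
64 in the divisor presumably reflects one 6-input LUT serving as a 64x1
distributed-RAM column:

    import math

    def threshold_luts(B, acc_bits, tmem):
        # (2**B - 1) thresholds, each acc_bits wide, tmem entries deep
        return (2 ** B - 1) * acc_bits * math.ceil(tmem / 64)

    # e.g. 4-bit output, 11-bit accumulator, TMEM depth 16 -> 165 LUTs
    print(threshold_luts(4, 11, 16))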

diff --git a/src/finn/custom_op/fpgadataflow/vectorvectoractivation.py b/src/finn/custom_op/fpgadataflow/vectorvectoractivation.py
index a411d245a..a0b926895 100644
--- a/src/finn/custom_op/fpgadataflow/vectorvectoractivation.py
+++ b/src/finn/custom_op/fpgadataflow/vectorvectoractivation.py
@@ -216,6 +216,10 @@ class VectorVectorActivation(HLSCustomOp):
         """Returns FINN DataType of weights."""
         return DataType[self.get_nodeattr("weightDataType")]
 
+    def get_accumulator_datatype(self):
+        """Returns FINN DataType of accumulator"""
+        return DataType[self.get_nodeattr("accDataType")]
+
     def get_output_datatype(self, ind=0):
         """Returns FINN DataType of output."""
         return DataType[self.get_nodeattr("outputDataType")]
@@ -1172,14 +1176,25 @@ class VectorVectorActivation(HLSCustomOp):
         else:
             mult_luts = (2 * math.ceil((W + A) / 6) - 1) * (W + A)
         # accumulator
+        acc_datatype = self.get_accumulator_datatype()
+        acc_bits = acc_datatype.bitwidth()
         k_h, k_w = self.get_nodeattr("Kernel")
-        acc_bits = W + A + math.ceil(math.log(k_h * k_w, 2))
+        # if accDataType is not set, it defaults to INT32, which would be a
+        # large overestimate in most (if not all) cases; guard against this by
+        # taking the minimum with the bit width implied by the data types
+        alpha = math.log(k_h * k_w, 2) + W + A - 1 - int(idt.signed())
+        phi = lambda x_: math.log(1 + pow(2, -x_), 2)
+        acc_bits = min(
+            acc_bits,
+            np.ceil(alpha + phi(alpha) + 1)
+        )
         acc_luts = acc_bits
         # thresholds and threshold comparators
         thr_luts = 0
         comp_luts = 0
         noact = self.get_nodeattr("noActivation")
-        if noact == 0:
+        tmem_style = self.get_nodeattr("ram_style_thresholds")
+        if (noact == 0) and (tmem_style == "distributed"):
             odt = self.get_output_datatype()
             B = odt.bitwidth()
             thr_luts = (2**B - 1) * acc_bits * math.ceil(self.calc_tmem() / 64)
-- 
GitLab