From 61ac5b62e4da00837542d814c656798930deb6bf Mon Sep 17 00:00:00 2001
From: Felix Jentzsch <felix.jentzsch@upb.de>
Date: Thu, 26 Jan 2023 14:01:41 +0100
Subject: [PATCH] [VVAU] update resource estimates

---
 .../fpgadataflow/vectorvectoractivation.py    | 61 +++++++++++++------
 1 file changed, 43 insertions(+), 18 deletions(-)

diff --git a/src/finn/custom_op/fpgadataflow/vectorvectoractivation.py b/src/finn/custom_op/fpgadataflow/vectorvectoractivation.py
index 72158ffcd..2e86d72d0 100644
--- a/src/finn/custom_op/fpgadataflow/vectorvectoractivation.py
+++ b/src/finn/custom_op/fpgadataflow/vectorvectoractivation.py
@@ -218,6 +218,10 @@ class VectorVectorActivation(HLSCustomOp):
         """Returns FINN DataType of weights."""
         return DataType[self.get_nodeattr("weightDataType")]
 
+    def get_accumulator_datatype(self):
+        """Returns FINN DataType of accumulator."""
+        return DataType[self.get_nodeattr("accDataType")]
+
     def get_output_datatype(self, ind=0):
         """Returns FINN DataType of output."""
         return DataType[self.get_nodeattr("outputDataType")]
@@ -1115,7 +1119,7 @@ class VectorVectorActivation(HLSCustomOp):
 
     def uram_estimation(self):
         P = self.get_nodeattr("PE")
-        Q = 1
+        Q = self.get_nodeattr("SIMD")
         wdt = self.get_weight_datatype()
         W = wdt.bitwidth()
         omega = self.calc_wmem()
@@ -1124,7 +1128,7 @@ class VectorVectorActivation(HLSCustomOp):
         mstyle = self.get_nodeattr("ram_style")
         if (
             (mmode == "decoupled" and mstyle != "ultra")
-            or (mmode == "const" and self.calc_wmem() <= 128)
+            or (mmode == "const")
             or (mmode == "external")
         ):
             return 0
@@ -1136,9 +1140,11 @@ class VectorVectorActivation(HLSCustomOp):
         """Calculates resource estimation for BRAM"""
         # TODO add in/out FIFO contributions
         P = self.get_nodeattr("PE")
+        Q = self.get_nodeattr("SIMD")
         wdt = self.get_weight_datatype()
         W = wdt.bitwidth()
         omega = self.calc_wmem()
+        mem_width = Q * W * P
         # assuming SDP mode RAMB18s (see UG573 Table 1-10)
         # since this is HLS memory, not using the full width of a BRAM
         # assuming memories up to 128 deep get implemented in LUTs
@@ -1146,23 +1152,24 @@ class VectorVectorActivation(HLSCustomOp):
         mstyle = self.get_nodeattr("ram_style")
         if (
             (mmode == "decoupled" and mstyle in ["distributed", "ultra"])
+            or (mstyle == "auto" and self.calc_wmem() <= 128)
             or (mmode == "const" and self.calc_wmem() <= 128)
             or (mmode == "external")
         ):
             return 0
 
-        if W == 1:
-            return math.ceil(omega / 16384) * P
-        elif W == 2:
-            return math.ceil(omega / 8192) * P
-        elif W <= 4:
-            return (math.ceil(omega / 4096)) * (math.ceil(W / 4)) * P
-        elif W <= 9:
-            return (math.ceil(omega / 2048)) * (math.ceil(W / 8)) * P
-        elif W <= 18 or omega > 512:
-            return (math.ceil(omega / 1024)) * (math.ceil(W / 16)) * P
+        if mem_width == 1:
+            return math.ceil(omega / 16384)
+        elif mem_width == 2:
+            return math.ceil(omega / 8192)
+        elif mem_width <= 4:
+            return (math.ceil(omega / 4096)) * (math.ceil(mem_width / 4))
+        elif mem_width <= 9:
+            return (math.ceil(omega / 2048)) * (math.ceil(mem_width / 8))
+        elif mem_width <= 18 or omega > 512:
+            return (math.ceil(omega / 1024)) * (math.ceil(mem_width / 16))
         else:
-            return (math.ceil(omega / 512)) * (math.ceil(W / 32)) * P
+            return (math.ceil(omega / 512)) * (math.ceil(mem_width / 32))
 
     def bram_efficiency_estimation(self):
         P = self.get_nodeattr("PE")
@@ -1186,6 +1193,7 @@ class VectorVectorActivation(HLSCustomOp):
         """
         # TODO add in/out FIFO contributions
         P = self.get_nodeattr("PE")
+        Q = self.get_nodeattr("SIMD")
         wdt = self.get_weight_datatype()
         W = wdt.bitwidth()
         # determine tdt with input and weight data types
@@ -1200,29 +1208,46 @@ class VectorVectorActivation(HLSCustomOp):
         if (mmode == "decoupled" and mstyle == "distributed") or (
             mmode == "const" and self.calc_wmem() <= 128
         ):
-            c2 = (P * W) * math.ceil(self.calc_wmem() / 64)
+            c2 = (P * Q * W) * math.ceil(self.calc_wmem() / 64)
 
         # multiplication
         res_type = self.get_nodeattr("resType")
         if res_type == "dsp":
             mult_luts = 0
         else:
-            mult_luts = (2 * math.ceil((W + A) / 6) - 1) * (W + A)
+            mult_luts = Q * (2 * math.ceil((W + A) / 6) - 1) * (W + A)
+        # adder tree
+        addertree_luts = (W + A) * (2 * Q - 1)
         # accumulator
+        acc_datatype = self.get_accumulator_datatype()
+        acc_bits = acc_datatype.bitwidth()
         k_h, k_w = self.get_nodeattr("Kernel")
-        acc_bits = W + A + math.ceil(math.log(k_h * k_w, 2))
+        # if accDataType is not set, then it will default to INT32, which would
+        # be a large overestimate in most (if not all) cases. In this scenario,
+        # we would use the minimum accumulator as determined by the data types.
+        alpha = math.log(k_h * k_w, 2) + W + A - 1 - int(idt.signed())
+
+        def phi(x_):
+            return math.log(1 + pow(2, -x_), 2)
+
+        acc_bits = min(acc_datatype.bitwidth(), np.ceil(alpha + phi(alpha) + 1))
         acc_luts = acc_bits
         # thresholds and threshold comparators
         thr_luts = 0
         comp_luts = 0
         noact = self.get_nodeattr("noActivation")
+        # TODO - add 'ram_style_threshold' node attribute
         if noact == 0:
             odt = self.get_output_datatype()
             B = odt.bitwidth()
-            thr_luts = (2**B - 1) * acc_bits * math.ceil(self.calc_tmem() / 64)
+            thr_luts = (2**B - 1) * acc_bits * self.calc_tmem() / 64
             comp_luts = (2**B - 1) * acc_bits
 
-        return int(c0 + c1 * (P * (mult_luts + acc_luts + thr_luts + comp_luts)) + c2)
+        return int(
+            c0
+            + c1 * (P * (mult_luts + addertree_luts + acc_luts + thr_luts + comp_luts))
+            + c2
+        )
 
     def dsp_estimation(self):
         # multiplication
-- 
GitLab