From 7afc09862707074982f3e18da6e019c3614a9442 Mon Sep 17 00:00:00 2001
From: icolbert <Ian.Colbert@amd.com>
Date: Mon, 27 Feb 2023 15:09:37 -0800
Subject: [PATCH] Update minimize_accumulator_width for VVAU

---
 .../fpgadataflow/vectorvectoractivation.py    | 146 ++++++++++--------
 1 file changed, 83 insertions(+), 63 deletions(-)
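
Note for reviewers: the core of this change is that, when runtime_writeable_weights is set,
the accumulator range can no longer be derived from the stored weight initializer and must
instead be bounded by the extremes of the weight datatype. The sketch below illustrates that
worst-case bound in isolation; it does not use FINN's calculate_matvec_accumulator_range
helper, and all names in it are illustrative only.

    import numpy as np

    def worst_case_accumulator_range(k, w_min, w_max, i_min, i_max):
        """Return (acc_min, acc_max) for a length-k dot product whose weights
        lie in [w_min, w_max] and whose inputs lie in [i_min, i_max]."""
        # each weight*input product is bounded by the four corner combinations
        corners = [w_min * i_min, w_min * i_max, w_max * i_min, w_max * i_max]
        return k * min(corners), k * max(corners)

    def smallest_signed_bitwidth(acc_min, acc_max):
        """Smallest two's-complement width that holds [acc_min, acc_max]."""
        bound = max(-acc_min if acc_min < 0 else 0, acc_max + 1)
        return int(np.ceil(np.log2(bound))) + 1 if bound > 1 else 1

    # example: 3x3 kernel (k = 9), INT4 weights [-8, 7], UINT8 inputs [0, 255]
    acc_min, acc_max = worst_case_accumulator_range(9, -8, 7, 0, 255)
    print(acc_min, acc_max, smallest_signed_bitwidth(acc_min, acc_max))

This mirrors the intent of the new runtime-writeable branch: the range is taken over both the
all-minimum and all-maximum weight matrices so that sign interactions with the input range are
covered before picking the smallest accumulator datatype.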

diff --git a/src/finn/custom_op/fpgadataflow/vectorvectoractivation.py b/src/finn/custom_op/fpgadataflow/vectorvectoractivation.py
index 10ee30f89..a58067483 100644
--- a/src/finn/custom_op/fpgadataflow/vectorvectoractivation.py
+++ b/src/finn/custom_op/fpgadataflow/vectorvectoractivation.py
@@ -106,75 +106,95 @@ class VectorVectorActivation(HLSCustomOp):
     def minimize_accumulator_width(self, model):
         """Minimize the accumulator bit width according to the weight values,
         input data types, and size of dot product"""
-        if not self.get_nodeattr("runtime_writeable_weights"):
-            weights = model.get_initializer(self.onnx_node.input[1])
-            k_h, k_w = self.get_nodeattr("Kernel")
-            fm = self.get_nodeattr("Channels")
-            # put weights into the shape expected by calculate_matvec_accumulator_range
-            weights = weights.reshape(fm, k_h * k_w).transpose()
-            if len(self.onnx_node.input) > 2:
-                thresholds = model.get_initializer(self.onnx_node.input[2])
-            else:
-                thresholds = None
-            idt = self.get_input_datatype()
-            # calculate minimum and maximum values of accumulator according to the
-            # weight values using the bounds derived in https://arxiv.org/abs/2301.13376
+        weights = model.get_initializer(self.onnx_node.input[1])
+        k_h, k_w = self.get_nodeattr("Kernel")
+        fm = self.get_nodeattr("Channels")
+        # put weights into the shape expected by calculate_matvec_accumulator_range
+        weights = weights.reshape(fm, k_h * k_w).transpose()
+        # the range calculation uses the actual weight values, so convert
+        # binary {0, 1} weights to bipolar {-1, +1} in binaryXnorMode
+        if self.get_nodeattr("binaryXnorMode"):
+            weights = 2 * weights - 1
+        if len(self.onnx_node.input) > 2:
+            thresholds = model.get_initializer(self.onnx_node.input[2])
+        else:
+            thresholds = None
+        idt = self.get_input_datatype()
+        # if the weights are runtime-writeable, their values can change after
+        # synthesis, so use the worst-case values allowed by the weight datatype
+        if self.get_nodeattr("runtime_writeable_weights"):
+            wdt = self.get_weight_datatype()
+            lower_worst = wdt.min() * np.ones_like(weights)
+            lower_range = calculate_matvec_accumulator_range(lower_worst, idt)
+            upper_worst = wdt.max() * np.ones_like(weights)
+            upper_range = calculate_matvec_accumulator_range(upper_worst, idt)
+            acc_min = min(min(lower_range), min(upper_range))
+            acc_max = max(max(lower_range), max(upper_range))
+            thresholds = None  # thresholds are also runtime-writeable
+        # otherwise the weights are fixed, so the exact accumulator range
+        # can be calculated from the known weight values and the input
+        # data type
+        else:
             (acc_min, acc_max) = calculate_matvec_accumulator_range(weights, idt)
-            if thresholds is not None:
-                threshold_tensor = self.get_hls_compatible_threshold_tensor(thresholds)
-                # set threshold datatype (and accumulator datatype implicitly)
+        # if thresholds are known, clip them to the reachable accumulator
+        # range and size the datatype to cover both accumulator and thresholds
+        if thresholds is not None:
+            threshold_tensor = self.get_hls_compatible_threshold_tensor(thresholds)
+            # set threshold datatype (and accumulator datatype implicitly)
+            min_threshold = thresholds.min()
+            max_threshold = thresholds.max()
+            # clip threshold values
+            clip_upper = None
+            clip_lower = None
+            if max_threshold > acc_max + 1:
+                clip_upper = acc_max + 1
+            if min_threshold < acc_min:
+                clip_lower = acc_min
+            if (clip_lower is not None) or (clip_upper is not None):
+                warnings.warn(
+                    "Clipping some thresholds in %s" % self.onnx_node.name
+                )
+                thresholds = np.clip(thresholds, clip_lower, clip_upper)
+                model.set_initializer(self.onnx_node.input[2], thresholds)
+                threshold_tensor = self.get_hls_compatible_threshold_tensor(
+                    thresholds
+                )
                 min_threshold = thresholds.min()
                 max_threshold = thresholds.max()
-                # clip threshold values
-                clip_upper = None
-                clip_lower = None
-                if max_threshold > acc_max + 1:
-                    clip_upper = acc_max + 1
-                if min_threshold < acc_min:
-                    clip_lower = acc_min
-                if (clip_lower is not None) or (clip_upper is not None):
-                    warnings.warn(
-                        "Clipping some thresholds in %s" % self.onnx_node.name
-                    )
-                    thresholds = np.clip(thresholds, clip_lower, clip_upper)
-                    model.set_initializer(self.onnx_node.input[2], thresholds)
-                    threshold_tensor = self.get_hls_compatible_threshold_tensor(
-                        thresholds
-                    )
-                    min_threshold = thresholds.min()
-                    max_threshold = thresholds.max()
-                # get range required by threshold values
-                tdt_min = min(acc_min, min_threshold)
-                tdt_max = max(acc_max, max_threshold)
-                if tdt_min < 0:
-                    if abs(tdt_min) > tdt_max:
-                        tdt = DataType.get_smallest_possible(tdt_min)
-                    else:
-                        tdt = DataType.get_smallest_possible(-tdt_max - 1)
+            # get range required by threshold values
+            tdt_min = min(acc_min, min_threshold)
+            tdt_max = max(acc_max, max_threshold)
+            if tdt_min < 0:
+                if abs(tdt_min) > tdt_max:
+                    tdt = DataType.get_smallest_possible(tdt_min)
                 else:
-                    tdt = DataType.get_smallest_possible(tdt_max)
-                assert np.vectorize(tdt.allowed)(
-                    threshold_tensor
-                ).all(), "Thresholds in %s can't be expressed with type %s" % (
-                    self.onnx_node.name,
-                    str(tdt),
-                )
-                self.set_nodeattr("accDataType", tdt.name)
+                    tdt = DataType.get_smallest_possible(-tdt_max - 1)
             else:
-                if acc_min < 0:
-                    if abs(acc_min) > acc_max:
-                        adt = DataType.get_smallest_possible(acc_min)
-                    else:
-                        adt = DataType.get_smallest_possible(-acc_max - 1)
+                tdt = DataType.get_smallest_possible(tdt_max)
+            assert np.vectorize(tdt.allowed)(
+                threshold_tensor
+            ).all(), "Thresholds in %s can't be expressed with type %s" % (
+                self.onnx_node.name,
+                str(tdt),
+            )
+            adt = tdt  # set accumulator datatype to the threshold datatype
+        else:
+            if acc_min < 0:
+                if abs(acc_min) > acc_max:
+                    adt = DataType.get_smallest_possible(acc_min)
                 else:
-                    adt = DataType.get_smallest_possible(acc_max)
-                # ensure a datatype divisible by 8-bits in case this is the last node
-                bw = roundup_to_integer_multiple(adt.bitwidth(), 8)
-                new_adt_name = adt.name.replace(str(adt.bitwidth()), str(bw))
-                adt = DataType[new_adt_name]
-                self.set_nodeattr("accDataType", adt.name)
-                # for no-activation nodes, output dt = acc dt
-                self.set_nodeattr("outputDataType", adt.name)
+                    adt = DataType.get_smallest_possible(-acc_max - 1)
+            else:
+                adt = DataType.get_smallest_possible(acc_max)
+        # if this is the last node in the graph, round the accumulator bit
+        # width up to the nearest multiple of 8 bits
+        if model.find_direct_successors(self.onnx_node) is None:
+            bw = roundup_to_integer_multiple(adt.bitwidth(), 8)
+            new_adt_name = adt.name.replace(str(adt.bitwidth()), str(bw))
+            adt = DataType[new_adt_name]
+            # for no-activation nodes, output dt = acc dt
+            self.set_nodeattr("outputDataType", adt.name)
+        self.set_nodeattr("accDataType", adt.name)
         return DataType[self.get_nodeattr("accDataType")]
 
     def minimize_weight_bit_width(self, model):
-- 
GitLab
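
Appendix (not part of the applied patch): a simplified illustration of the threshold handling
introduced above, i.e. clipping thresholds to the reachable accumulator range before sizing
the shared accumulator/threshold datatype, plus the byte-multiple rounding applied to the last
node. The helper names are hypothetical, and unlike the patch this version clips
unconditionally rather than only when a threshold falls outside the range.

    import numpy as np

    def clip_thresholds(thresholds, acc_min, acc_max):
        """Clip thresholds to [acc_min, acc_max + 1] and return the clipped
        array plus the (min, max) the threshold/accumulator type must cover."""
        # a threshold below acc_min or above acc_max + 1 can never change the
        # comparison result, so clipping it is lossless
        clipped = np.clip(thresholds, acc_min, acc_max + 1)
        return clipped, min(acc_min, clipped.min()), max(acc_max, clipped.max())

    def round_bitwidth_to_bytes(bitwidth):
        """Round a bit width up to the next multiple of 8 (e.g. 13 -> 16)."""
        return -(-bitwidth // 8) * 8

    # example: accumulator range [-100, 100] with one out-of-range threshold
    thr = np.array([[-150.0, -10.0, 50.0, 120.0]])
    print(clip_thresholds(thr, -100, 100))  # 120 is clipped to 101
    print(round_bitwidth_to_bytes(13))      # -> 16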