From f353feffc71a7918c75b1b91e11e111dd7ced539 Mon Sep 17 00:00:00 2001
From: icolbert <Ian.Colbert@amd.com>
Date: Fri, 6 Jan 2023 13:36:28 -0800
Subject: [PATCH] Adding check for runtime_writeable_weights

---
 .../fpgadataflow/matrixvectoractivation.py    | 23 +++++++++++--------
 .../fpgadataflow/vectorvectoractivation.py    | 23 +++++++++++--------
 2 files changed, 26 insertions(+), 20 deletions(-)

diff --git a/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py b/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py
index 6244bbc8e..a1dff7a0a 100644
--- a/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py
+++ b/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py
@@ -651,17 +651,20 @@ class MatrixVectorActivation(HLSCustomOp):
         return DataType[self.get_nodeattr("accDataType")]
 
     def minimize_weight_bit_width(self, model):
-        weights = model.get_initializer(self.onnx_node.input[1])
-        w_min = weights.min()
-        w_max = weights.max()
-        if w_min < 0:
-            if abs(w_min) > w_max:
-                wdt = DataType.get_smallest_possible(w_min)
+        """Minimize the bit width based on the values of the weights"""
+        runtime_writable = self.get_nodeattr("runtime_writeable_weights") == 0
+        if runtime_writable:
+            weights = model.get_initializer(self.onnx_node.input[1])
+            w_min = weights.min()
+            w_max = weights.max()
+            if w_min < 0:
+                if abs(w_min) > w_max:
+                    wdt = DataType.get_smallest_possible(w_min)
+                else:
+                    wdt = DataType.get_smallest_possible(-w_max - 1)
             else:
-                wdt = DataType.get_smallest_possible(-w_max - 1)
-        else:
-            wdt = DataType.get_smallest_possible(w_max)
-        self.set_nodeattr("weightDataType", wdt.name)
+                wdt = DataType.get_smallest_possible(w_max)
+            self.set_nodeattr("weightDataType", wdt.name)
         return DataType[self.get_nodeattr("weightDataType")]
 
     def get_hls_compatible_threshold_tensor(self, orig_thres_matrix):
diff --git a/src/finn/custom_op/fpgadataflow/vectorvectoractivation.py b/src/finn/custom_op/fpgadataflow/vectorvectoractivation.py
index 665ff7181..5d97244e5 100644
--- a/src/finn/custom_op/fpgadataflow/vectorvectoractivation.py
+++ b/src/finn/custom_op/fpgadataflow/vectorvectoractivation.py
@@ -170,17 +170,20 @@ class VectorVectorActivation(HLSCustomOp):
         return DataType[self.get_nodeattr("accDataType")]
 
     def minimize_weight_bit_width(self, model):
-        weights = model.get_initializer(self.onnx_node.input[1])
-        w_min = weights.min()
-        w_max = weights.max()
-        if w_min < 0:
-            if abs(w_min) > w_max:
-                wdt = DataType.get_smallest_possible(w_min)
+        """Minimize the bit width based on the values of the weights"""
+        runtime_writable = self.get_nodeattr("runtime_writeable_weights") == 0
+        if runtime_writable:
+            weights = model.get_initializer(self.onnx_node.input[1])
+            w_min = weights.min()
+            w_max = weights.max()
+            if w_min < 0:
+                if abs(w_min) > w_max:
+                    wdt = DataType.get_smallest_possible(w_min)
+                else:
+                    wdt = DataType.get_smallest_possible(-w_max - 1)
             else:
-                wdt = DataType.get_smallest_possible(-w_max - 1)
-        else:
-            wdt = DataType.get_smallest_possible(w_max)
-        self.set_nodeattr("weightDataType", wdt.name)
+                wdt = DataType.get_smallest_possible(w_max)
+            self.set_nodeattr("weightDataType", wdt.name)
         return DataType[self.get_nodeattr("weightDataType")]
 
     def calc_wmem(self):
-- 
GitLab