From f353feffc71a7918c75b1b91e11e111dd7ced539 Mon Sep 17 00:00:00 2001 From: icolbert <Ian.Colbert@amd.com> Date: Fri, 6 Jan 2023 13:36:28 -0800 Subject: [PATCH] Adding check for runtime_writeable_weights --- .../fpgadataflow/matrixvectoractivation.py | 23 +++++++++++-------- .../fpgadataflow/vectorvectoractivation.py | 23 +++++++++++-------- 2 files changed, 26 insertions(+), 20 deletions(-) diff --git a/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py b/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py index 6244bbc8e..a1dff7a0a 100644 --- a/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py +++ b/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py @@ -651,17 +651,20 @@ class MatrixVectorActivation(HLSCustomOp): return DataType[self.get_nodeattr("accDataType")] def minimize_weight_bit_width(self, model): - weights = model.get_initializer(self.onnx_node.input[1]) - w_min = weights.min() - w_max = weights.max() - if w_min < 0: - if abs(w_min) > w_max: - wdt = DataType.get_smallest_possible(w_min) + """Minimize the bit width based on the values of the weights""" + runtime_writable = self.get_nodeattr("runtime_writeable_weights") == 0 + if runtime_writable: + weights = model.get_initializer(self.onnx_node.input[1]) + w_min = weights.min() + w_max = weights.max() + if w_min < 0: + if abs(w_min) > w_max: + wdt = DataType.get_smallest_possible(w_min) + else: + wdt = DataType.get_smallest_possible(-w_max - 1) else: - wdt = DataType.get_smallest_possible(-w_max - 1) - else: - wdt = DataType.get_smallest_possible(w_max) - self.set_nodeattr("weightDataType", wdt.name) + wdt = DataType.get_smallest_possible(w_max) + self.set_nodeattr("weightDataType", wdt.name) return DataType[self.get_nodeattr("weightDataType")] def get_hls_compatible_threshold_tensor(self, orig_thres_matrix): diff --git a/src/finn/custom_op/fpgadataflow/vectorvectoractivation.py b/src/finn/custom_op/fpgadataflow/vectorvectoractivation.py index 665ff7181..5d97244e5 100644 --- a/src/finn/custom_op/fpgadataflow/vectorvectoractivation.py +++ b/src/finn/custom_op/fpgadataflow/vectorvectoractivation.py @@ -170,17 +170,20 @@ class VectorVectorActivation(HLSCustomOp): return DataType[self.get_nodeattr("accDataType")] def minimize_weight_bit_width(self, model): - weights = model.get_initializer(self.onnx_node.input[1]) - w_min = weights.min() - w_max = weights.max() - if w_min < 0: - if abs(w_min) > w_max: - wdt = DataType.get_smallest_possible(w_min) + """Minimize the bit width based on the values of the weights""" + runtime_writable = self.get_nodeattr("runtime_writeable_weights") == 0 + if runtime_writable: + weights = model.get_initializer(self.onnx_node.input[1]) + w_min = weights.min() + w_max = weights.max() + if w_min < 0: + if abs(w_min) > w_max: + wdt = DataType.get_smallest_possible(w_min) + else: + wdt = DataType.get_smallest_possible(-w_max - 1) else: - wdt = DataType.get_smallest_possible(-w_max - 1) - else: - wdt = DataType.get_smallest_possible(w_max) - self.set_nodeattr("weightDataType", wdt.name) + wdt = DataType.get_smallest_possible(w_max) + self.set_nodeattr("weightDataType", wdt.name) return DataType[self.get_nodeattr("weightDataType")] def calc_wmem(self): -- GitLab