diff --git a/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py b/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py
index 6244bbc8e72c744be387ae5fcd2849092ae85b4b..a1dff7a0ad798189e732a4da3554769bce0a503e 100644
--- a/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py
+++ b/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py
@@ -651,17 +651,21 @@ class MatrixVectorActivation(HLSCustomOp):
         return DataType[self.get_nodeattr("accDataType")]
 
     def minimize_weight_bit_width(self, model):
-        weights = model.get_initializer(self.onnx_node.input[1])
-        w_min = weights.min()
-        w_max = weights.max()
-        if w_min < 0:
-            if abs(w_min) > w_max:
-                wdt = DataType.get_smallest_possible(w_min)
+        """Minimize the bit width based on the values of the weights"""
+        # Weights are only minimized when runtime_writeable_weights is 0;
+        # otherwise the stored weightDataType is returned unchanged.
+        if self.get_nodeattr("runtime_writeable_weights") == 0:
+            weights = model.get_initializer(self.onnx_node.input[1])
+            w_min = weights.min()
+            w_max = weights.max()
+            if w_min < 0:
+                if abs(w_min) > w_max:
+                    wdt = DataType.get_smallest_possible(w_min)
+                else:
+                    wdt = DataType.get_smallest_possible(-w_max - 1)
             else:
-                wdt = DataType.get_smallest_possible(-w_max - 1)
-        else:
-            wdt = DataType.get_smallest_possible(w_max)
-        self.set_nodeattr("weightDataType", wdt.name)
+                wdt = DataType.get_smallest_possible(w_max)
+            self.set_nodeattr("weightDataType", wdt.name)
         return DataType[self.get_nodeattr("weightDataType")]
 
     def get_hls_compatible_threshold_tensor(self, orig_thres_matrix):
diff --git a/src/finn/custom_op/fpgadataflow/vectorvectoractivation.py b/src/finn/custom_op/fpgadataflow/vectorvectoractivation.py
index 665ff71810168e38aa172ce3bfe01d7a2186fc84..5d97244e5b59b8bf17b42aa57b2a789772f6adb5 100644
--- a/src/finn/custom_op/fpgadataflow/vectorvectoractivation.py
+++ b/src/finn/custom_op/fpgadataflow/vectorvectoractivation.py
@@ -170,17 +170,21 @@ class VectorVectorActivation(HLSCustomOp):
         return DataType[self.get_nodeattr("accDataType")]
 
     def minimize_weight_bit_width(self, model):
-        weights = model.get_initializer(self.onnx_node.input[1])
-        w_min = weights.min()
-        w_max = weights.max()
-        if w_min < 0:
-            if abs(w_min) > w_max:
-                wdt = DataType.get_smallest_possible(w_min)
+        """Minimize the bit width based on the values of the weights"""
+        # Weights are only minimized when runtime_writeable_weights is 0;
+        # otherwise the stored weightDataType is returned unchanged.
+        if self.get_nodeattr("runtime_writeable_weights") == 0:
+            weights = model.get_initializer(self.onnx_node.input[1])
+            w_min = weights.min()
+            w_max = weights.max()
+            if w_min < 0:
+                if abs(w_min) > w_max:
+                    wdt = DataType.get_smallest_possible(w_min)
+                else:
+                    wdt = DataType.get_smallest_possible(-w_max - 1)
             else:
-                wdt = DataType.get_smallest_possible(-w_max - 1)
-        else:
-            wdt = DataType.get_smallest_possible(w_max)
-        self.set_nodeattr("weightDataType", wdt.name)
+                wdt = DataType.get_smallest_possible(w_max)
+            self.set_nodeattr("weightDataType", wdt.name)
         return DataType[self.get_nodeattr("weightDataType")]
 
     def calc_wmem(self):