def minimize_accumulator_width(self, model):
    """Minimize the accumulator bit width according to the weight values,
    input data types, and size of dot product.

    The accumulator range is derived from the weight initializer and the
    input datatype. If the node has runtime-writeable weights, the range is
    instead derived from the extreme values of the weight datatype, and any
    threshold initializer is ignored for sizing.

    Side effects:
      * thresholds lying outside ``[acc_min, acc_max + 1]`` are clipped and
        written back to the model initializer (a warning is emitted);
      * the ``accDataType`` node attribute is always updated;
      * for nodes without usable thresholds, ``outputDataType`` is set to
        the accumulator datatype as well.

    :param model: model wrapper providing ``get_initializer``,
        ``set_initializer`` and ``find_direct_successors``.
    :return: the ``DataType`` stored in ``accDataType`` after minimization.
    :raises AssertionError: if the thresholds cannot be represented with the
        selected datatype.
    """

    def _smallest_dtype(range_min, range_max):
        # Pick the smallest datatype able to hold [range_min, range_max].
        # For ranges that include negative values, size for whichever end
        # of the range requires more bits in a signed representation.
        if range_min < 0:
            if abs(range_min) > range_max:
                return DataType.get_smallest_possible(range_min)
            return DataType.get_smallest_possible(-range_max - 1)
        return DataType.get_smallest_possible(range_max)

    weights = model.get_initializer(self.onnx_node.input[1])
    k_h, k_w = self.get_nodeattr("Kernel")
    fm = self.get_nodeattr("Channels")
    # put weights into the shape expected by calculate_matvec_accumulator_range
    weights = weights.reshape(fm, k_h * k_w).transpose()
    # since in the calculation the values of the weight matrix are used,
    # for the bipolar case they need to be converted to bipolar
    if self.get_nodeattr("binaryXnorMode"):
        weights = 2 * weights - 1
    if len(self.onnx_node.input) > 2:
        thresholds = model.get_initializer(self.onnx_node.input[2])
    else:
        thresholds = None
    idt = self.get_input_datatype()

    if self.get_nodeattr("runtime_writeable_weights"):
        # the weight values can change at runtime, so use the worst-case
        # values permitted by the weight datatype instead of the
        # current initializer contents
        wdt = self.get_weight_datatype()
        lower_worst = wdt.min() * np.ones_like(weights)
        lower_range = calculate_matvec_accumulator_range(lower_worst, idt)
        # the upper worst case must use the datatype *maximum*; using the
        # minimum here would only repeat the lower-bound computation
        upper_worst = wdt.max() * np.ones_like(weights)
        upper_range = calculate_matvec_accumulator_range(upper_worst, idt)
        # either extreme weight matrix may produce the overall minimum or
        # maximum (e.g. negative weights with negative inputs), so both
        # ranges have to be considered for each bound
        acc_min = min(min(lower_range), min(upper_range))
        acc_max = max(max(lower_range), max(upper_range))
        thresholds = None  # range of thresholds are also runtime-writeable
    else:
        # weights are fixed, so the accumulator range can be computed from
        # the actual weight values and the input datatype using the bounds
        # derived in https://arxiv.org/abs/2301.13376
        (acc_min, acc_max) = calculate_matvec_accumulator_range(weights, idt)

    if thresholds is not None:
        # thresholds are known: adjust the range according to their values
        # (sets threshold datatype and accumulator datatype implicitly)
        threshold_tensor = self.get_hls_compatible_threshold_tensor(thresholds)
        min_threshold = thresholds.min()
        max_threshold = thresholds.max()
        # clip threshold values that can never be reached by the accumulator
        clip_upper = None
        clip_lower = None
        if max_threshold > acc_max + 1:
            clip_upper = acc_max + 1
        if min_threshold < acc_min:
            clip_lower = acc_min
        if (clip_lower is not None) or (clip_upper is not None):
            warnings.warn("Clipping some thresholds in %s" % self.onnx_node.name)
            thresholds = np.clip(thresholds, clip_lower, clip_upper)
            model.set_initializer(self.onnx_node.input[2], thresholds)
            threshold_tensor = self.get_hls_compatible_threshold_tensor(thresholds)
            min_threshold = thresholds.min()
            max_threshold = thresholds.max()
        # get range required by threshold values
        tdt_min = min(acc_min, min_threshold)
        tdt_max = max(acc_max, max_threshold)
        tdt = _smallest_dtype(tdt_min, tdt_max)
        assert np.vectorize(tdt.allowed)(
            threshold_tensor
        ).all(), "Thresholds in %s can't be expressed with type %s" % (
            self.onnx_node.name,
            str(tdt),
        )
        adt = tdt  # accumulator datatype follows the threshold datatype
    else:
        adt = _smallest_dtype(acc_min, acc_max)
        # if this is the last node in the graph, then ensure the datatype
        # is divisible by 8 bits
        if model.find_direct_successors(self.onnx_node) is None:
            bw = roundup_to_integer_multiple(adt.bitwidth(), 8)
            new_adt_name = adt.name.replace(str(adt.bitwidth()), str(bw))
            adt = DataType[new_adt_name]
        # for no-activation nodes, output dt = acc dt
        self.set_nodeattr("outputDataType", adt.name)
    self.set_nodeattr("accDataType", adt.name)
    return DataType[self.get_nodeattr("accDataType")]