diff --git a/notebooks/end2end_example/bnn-pynq/cnv_end2end_example.ipynb b/notebooks/end2end_example/bnn-pynq/cnv_end2end_example.ipynb index 0018bb27caf101bbff93154f2bd193b78c7b4ccf..73e9f4e6e1f6f01f6d3dc0e934615cb25c70278f 100644 --- a/notebooks/end2end_example/bnn-pynq/cnv_end2end_example.ipynb +++ b/notebooks/end2end_example/bnn-pynq/cnv_end2end_example.ipynb @@ -241,7 +241,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We won't go into too much detail about what happens in each transformation and why they are called in the particular order they are (feel free to visualize the intermediate steps using Netron yourself if you are curious) but here is a brief summmmary:\n", + "We won't go into too much detail about what happens in each transformation and why they are called in the particular order they are (feel free to visualize the intermediate steps using Netron yourself if you are curious) but here is a brief summary:\n", "\n", "* `Streamline` moves floating point scaling and addition operations closer to the input of the nearest thresholding activation and absorbs them into thresholds\n", "* `LowerConvsToMatMul` converts ONNX `Conv` nodes into sequences of `Im2Col, MatMul` nodes as discussed above. 
`Im2Col` is a custom FINN ONNX high-level node type that implements the sliding window operator.\n", diff --git a/src/finn/builder/build_dataflow_steps.py b/src/finn/builder/build_dataflow_steps.py index 546564ca4a2469f3485b4f7615584bbc4b41d13c..b4a0374fb84e6322e436d445369a58252f9285f9 100644 --- a/src/finn/builder/build_dataflow_steps.py +++ b/src/finn/builder/build_dataflow_steps.py @@ -489,6 +489,8 @@ def step_minimize_bit_width(model: ModelWrapper, cfg: DataflowBuildConfig): if cfg.minimize_bit_width: model = model.transform(MinimizeWeightBitWidth()) model = model.transform(MinimizeAccumulatorWidth()) + # make sure the changed datatypes are propagated through the network + model = model.transform(InferDataTypes()) return model diff --git a/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py b/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py index d6285a6f699c774d18998c4c7426ac52362e9dec..aa987384ddf80bd13d417a762c0291f5917b39bf 100644 --- a/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py +++ b/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py @@ -591,75 +591,86 @@ class MatrixVectorActivation(HLSCustomOp): def minimize_accumulator_width(self, model): """Minimize the accumulator bit width according to the weight values, input data types, and size of dot product""" - if not self.get_nodeattr("runtime_writeable_weights"): - weights = model.get_initializer(self.onnx_node.input[1]) - # since in the calculation the values of the weight matrix are used, - # for the bipolar case they need to be converted to bipolar - if self.get_nodeattr("binaryXnorMode"): - weights = 2 * weights - 1 - if len(self.onnx_node.input) > 2: - thresholds = model.get_initializer(self.onnx_node.input[2]) - else: - thresholds = None - idt = self.get_input_datatype() - # calculate minimum and maximum values of accumulator according to the - # weight values using the bounds derived in https://arxiv.org/abs/2301.13376 + weights = 
model.get_initializer(self.onnx_node.input[1]) + # since in the calculation the values of the weight matrix are used, + # for the bipolar case they need to be converted to bipolar + if self.get_nodeattr("binaryXnorMode"): + weights = 2 * weights - 1 + if len(self.onnx_node.input) > 2: + thresholds = model.get_initializer(self.onnx_node.input[2]) + else: + thresholds = None + idt = self.get_input_datatype() + # if runtime-writeable weights, then the values of the weights can + # change and we need to use the worst-case values from the datatypes + if self.get_nodeattr("runtime_writeable_weights"): + wdt = self.get_weight_datatype() + lower_worst = wdt.min() * np.ones_like(weights) + lower_range = calculate_matvec_accumulator_range(lower_worst, idt) + upper_worst = wdt.max() * np.ones_like(weights) + upper_range = calculate_matvec_accumulator_range(upper_worst, idt) + acc_min = min(min(lower_range), min(upper_range)) + acc_max = max(max(lower_range), max(upper_range)) + # if not runtime-writeable weights, then we can calculate the min + # and max values of the accumulation range using knowledge of the + # weights and input data types since they are fixed + else: (acc_min, acc_max) = calculate_matvec_accumulator_range(weights, idt) - if thresholds is not None: + # if the thresholds can be used to determine range, then adjust the range + # according to the known values of the thresholds + if thresholds is not None: + threshold_tensor = self.get_hls_compatible_threshold_tensor(thresholds) + # set threshold datatype (and accumulator datatype implicitly) + min_threshold = thresholds.min() + max_threshold = thresholds.max() + # clip threshold values + clip_upper = None + clip_lower = None + if max_threshold > acc_max + 1: + clip_upper = acc_max + 1 + if min_threshold < acc_min: + clip_lower = acc_min + if (clip_lower is not None) or (clip_upper is not None): + warnings.warn("Clipping some thresholds in %s" % self.onnx_node.name) + thresholds = np.clip(thresholds, 
clip_lower, clip_upper) + model.set_initializer(self.onnx_node.input[2], thresholds) threshold_tensor = self.get_hls_compatible_threshold_tensor(thresholds) - # set threshold datatype (and accumulator datatype implicitly) min_threshold = thresholds.min() max_threshold = thresholds.max() - # clip threshold values - clip_upper = None - clip_lower = None - if max_threshold > acc_max + 1: - clip_upper = acc_max + 1 - if min_threshold < acc_min: - clip_lower = acc_min - if (clip_lower is not None) or (clip_upper is not None): - warnings.warn( - "Clipping some thresholds in %s" % self.onnx_node.name - ) - thresholds = np.clip(thresholds, clip_lower, clip_upper) - model.set_initializer(self.onnx_node.input[2], thresholds) - threshold_tensor = self.get_hls_compatible_threshold_tensor( - thresholds - ) - min_threshold = thresholds.min() - max_threshold = thresholds.max() - # get range required by threshold values - tdt_min = min(acc_min, min_threshold) - tdt_max = max(acc_max, max_threshold) - if tdt_min < 0: - if abs(tdt_min) > tdt_max: - tdt = DataType.get_smallest_possible(tdt_min) - else: - tdt = DataType.get_smallest_possible(-tdt_max - 1) + # get range required by threshold values + tdt_min = min(acc_min, min_threshold) + tdt_max = max(acc_max, max_threshold) + if tdt_min < 0: + if abs(tdt_min) > tdt_max: + tdt = DataType.get_smallest_possible(tdt_min) else: - tdt = DataType.get_smallest_possible(tdt_max) - assert np.vectorize(tdt.allowed)( - threshold_tensor - ).all(), "Thresholds in %s can't be expressed with type %s" % ( - self.onnx_node.name, - str(tdt), - ) - self.set_nodeattr("accDataType", tdt.name) + tdt = DataType.get_smallest_possible(-tdt_max - 1) else: - if acc_min < 0: - if abs(acc_min) > acc_max: - adt = DataType.get_smallest_possible(acc_min) - else: - adt = DataType.get_smallest_possible(-acc_max - 1) + tdt = DataType.get_smallest_possible(tdt_max) + assert np.vectorize(tdt.allowed)( + threshold_tensor + ).all(), "Thresholds in %s can't be expressed 
with type %s" % ( + self.onnx_node.name, + str(tdt), + ) + adt = tdt # Set activation datatype to the threshold datatype + else: + if acc_min < 0: + if abs(acc_min) > acc_max: + adt = DataType.get_smallest_possible(acc_min) else: - adt = DataType.get_smallest_possible(acc_max) - # ensure a datatype divisible by 8-bits in case this is the last node - bw = roundup_to_integer_multiple(adt.bitwidth(), 8) - new_adt_name = adt.name.replace(str(adt.bitwidth()), str(bw)) - adt = DataType[new_adt_name] - self.set_nodeattr("accDataType", adt.name) - # for no-activation nodes, output dt = acc dt - self.set_nodeattr("outputDataType", adt.name) + adt = DataType.get_smallest_possible(-acc_max - 1) + else: + adt = DataType.get_smallest_possible(acc_max) + # if this is the last node in the graph, then ensure the datatype is + # divisibly by 8 bits + if model.find_direct_successors(self.onnx_node) is None: + bw = roundup_to_integer_multiple(adt.bitwidth(), 8) + new_adt_name = adt.name.replace(str(adt.bitwidth()), str(bw)) + adt = DataType[new_adt_name] + # for no-activation nodes, output dt = acc dt + self.set_nodeattr("outputDataType", adt.name) + self.set_nodeattr("accDataType", adt.name) return DataType[self.get_nodeattr("accDataType")] def minimize_weight_bit_width(self, model): diff --git a/src/finn/custom_op/fpgadataflow/vectorvectoractivation.py b/src/finn/custom_op/fpgadataflow/vectorvectoractivation.py index a2dd3c75dc3c02faab2465e8ac5c70474560bba5..da79933f26b27d3b84946cacad7a9ededb16003c 100644 --- a/src/finn/custom_op/fpgadataflow/vectorvectoractivation.py +++ b/src/finn/custom_op/fpgadataflow/vectorvectoractivation.py @@ -107,75 +107,90 @@ class VectorVectorActivation(HLSCustomOp): def minimize_accumulator_width(self, model): """Minimize the accumulator bit width according to the weight values, input data types, and size of dot product""" - if not self.get_nodeattr("runtime_writeable_weights"): - weights = model.get_initializer(self.onnx_node.input[1]) - k_h, k_w = 
self.get_nodeattr("Kernel") - fm = self.get_nodeattr("Channels") - # put weights into the shape expected by calculate_matvec_accumulator_range - weights = weights.reshape(fm, k_h * k_w).transpose() - if len(self.onnx_node.input) > 2: - thresholds = model.get_initializer(self.onnx_node.input[2]) - else: - thresholds = None - idt = self.get_input_datatype() - # calculate minimum and maximum values of accumulator according to the - # weight values using the bounds derived in https://arxiv.org/abs/2301.13376 + weights = model.get_initializer(self.onnx_node.input[1]) + k_h, k_w = self.get_nodeattr("Kernel") + fm = self.get_nodeattr("Channels") + # put weights into the shape expected by calculate_matvec_accumulator_range + weights = weights.reshape(fm, k_h * k_w).transpose() + # since in the calculation the values of the weight matrix are used, + # for the bipolar case they need to be converted to bipolar + if self.get_nodeattr("binaryXnorMode"): + weights = 2 * weights - 1 + if len(self.onnx_node.input) > 2: + thresholds = model.get_initializer(self.onnx_node.input[2]) + else: + thresholds = None + idt = self.get_input_datatype() + # if runtime-writeable weights, then the values of the weights can + # change and we need to use the worst-case values from the datatypes + if self.get_nodeattr("runtime_writeable_weights"): + wdt = self.get_weight_datatype() + lower_worst = wdt.min() * np.ones_like(weights) + lower_range = calculate_matvec_accumulator_range(lower_worst, idt) + upper_worst = wdt.max() * np.ones_like(weights) + upper_range = calculate_matvec_accumulator_range(upper_worst, idt) + acc_min = min(min(lower_range), min(upper_range)) + acc_max = max(max(lower_range), max(upper_range)) + # if not runtime-writeable weights, then we can calculate the min + # and max values of the accumulation range using knowledge of the + # weights and input data types since they are fixed + else: (acc_min, acc_max) = calculate_matvec_accumulator_range(weights, idt) - if thresholds is 
not None: + # if the thresholds can be used to determine range, then adjust the range + # according to the known values of the thresholds + if thresholds is not None: + threshold_tensor = self.get_hls_compatible_threshold_tensor(thresholds) + # set threshold datatype (and accumulator datatype implicitly) + min_threshold = thresholds.min() + max_threshold = thresholds.max() + # clip threshold values + clip_upper = None + clip_lower = None + if max_threshold > acc_max + 1: + clip_upper = acc_max + 1 + if min_threshold < acc_min: + clip_lower = acc_min + if (clip_lower is not None) or (clip_upper is not None): + warnings.warn("Clipping some thresholds in %s" % self.onnx_node.name) + thresholds = np.clip(thresholds, clip_lower, clip_upper) + model.set_initializer(self.onnx_node.input[2], thresholds) threshold_tensor = self.get_hls_compatible_threshold_tensor(thresholds) - # set threshold datatype (and accumulator datatype implicitly) min_threshold = thresholds.min() max_threshold = thresholds.max() - # clip threshold values - clip_upper = None - clip_lower = None - if max_threshold > acc_max + 1: - clip_upper = acc_max + 1 - if min_threshold < acc_min: - clip_lower = acc_min - if (clip_lower is not None) or (clip_upper is not None): - warnings.warn( - "Clipping some thresholds in %s" % self.onnx_node.name - ) - thresholds = np.clip(thresholds, clip_lower, clip_upper) - model.set_initializer(self.onnx_node.input[2], thresholds) - threshold_tensor = self.get_hls_compatible_threshold_tensor( - thresholds - ) - min_threshold = thresholds.min() - max_threshold = thresholds.max() - # get range required by threshold values - tdt_min = min(acc_min, min_threshold) - tdt_max = max(acc_max, max_threshold) - if tdt_min < 0: - if abs(tdt_min) > tdt_max: - tdt = DataType.get_smallest_possible(tdt_min) - else: - tdt = DataType.get_smallest_possible(-tdt_max - 1) + # get range required by threshold values + tdt_min = min(acc_min, min_threshold) + tdt_max = max(acc_max, max_threshold) 
+ if tdt_min < 0: + if abs(tdt_min) > tdt_max: + tdt = DataType.get_smallest_possible(tdt_min) else: - tdt = DataType.get_smallest_possible(tdt_max) - assert np.vectorize(tdt.allowed)( - threshold_tensor - ).all(), "Thresholds in %s can't be expressed with type %s" % ( - self.onnx_node.name, - str(tdt), - ) - self.set_nodeattr("accDataType", tdt.name) + tdt = DataType.get_smallest_possible(-tdt_max - 1) else: - if acc_min < 0: - if abs(acc_min) > acc_max: - adt = DataType.get_smallest_possible(acc_min) - else: - adt = DataType.get_smallest_possible(-acc_max - 1) + tdt = DataType.get_smallest_possible(tdt_max) + assert np.vectorize(tdt.allowed)( + threshold_tensor + ).all(), "Thresholds in %s can't be expressed with type %s" % ( + self.onnx_node.name, + str(tdt), + ) + adt = tdt # Set activation datatype to the threshold datatype + else: + if acc_min < 0: + if abs(acc_min) > acc_max: + adt = DataType.get_smallest_possible(acc_min) else: - adt = DataType.get_smallest_possible(acc_max) - # ensure a datatype divisible by 8-bits in case this is the last node - bw = roundup_to_integer_multiple(adt.bitwidth(), 8) - new_adt_name = adt.name.replace(str(adt.bitwidth()), str(bw)) - adt = DataType[new_adt_name] - self.set_nodeattr("accDataType", adt.name) - # for no-activation nodes, output dt = acc dt - self.set_nodeattr("outputDataType", adt.name) + adt = DataType.get_smallest_possible(-acc_max - 1) + else: + adt = DataType.get_smallest_possible(acc_max) + # if this is the last node in the graph, then ensure the datatype is + # divisibly by 8 bits + if model.find_direct_successors(self.onnx_node) is None: + bw = roundup_to_integer_multiple(adt.bitwidth(), 8) + new_adt_name = adt.name.replace(str(adt.bitwidth()), str(bw)) + adt = DataType[new_adt_name] + # for no-activation nodes, output dt = acc dt + self.set_nodeattr("outputDataType", adt.name) + self.set_nodeattr("accDataType", adt.name) return DataType[self.get_nodeattr("accDataType")] def minimize_weight_bit_width(self, 
model): diff --git a/src/finn/util/data_packing.py b/src/finn/util/data_packing.py index 3602b1bdd5d013ee8ce2f6cf156490478f0cc74e..a41fe882e543db3f6809f0bc269b81c1e8a22ab5 100644 --- a/src/finn/util/data_packing.py +++ b/src/finn/util/data_packing.py @@ -220,7 +220,7 @@ def unpack_innermost_dim_from_hex_string( if conv_dtype == DataType["BIPOLAR"]: ar_list = [2 * x - 1 for x in ar_list] # interpret values as signed values - elif dtype.signed(): + elif conv_dtype.signed() and conv_dtype.is_integer(): mask = 2 ** (conv_dtype.bitwidth() - 1) ar_list = [-(x & mask) + (x & ~mask) for x in ar_list] diff --git a/tests/end2end/test_end2end_cybsec_mlp.py b/tests/end2end/test_end2end_cybsec_mlp.py index 86942415b9307654e6afaaa82dc05b009954a710..d2a4d0287fc16d6bf4281be07a6a7ed5027150f1 100644 --- a/tests/end2end/test_end2end_cybsec_mlp.py +++ b/tests/end2end/test_end2end_cybsec_mlp.py @@ -222,7 +222,7 @@ def test_end2end_cybsec_mlp_build(QONNX_export): assert est_cycles_dict["MatrixVectorActivation_1"] == 64 with open(est_res_report, "r") as f: est_res_dict = json.load(f) - assert est_res_dict["total"]["LUT"] == 11360.0 + assert est_res_dict["total"]["LUT"] == 7904.0 assert est_res_dict["total"]["BRAM_18K"] == 36.0 shutil.copytree(output_dir + "/deploy", get_checkpoint_name("build", QONNX_export)) diff --git a/tests/fpgadataflow/test_fpgadataflow_vvau.py b/tests/fpgadataflow/test_fpgadataflow_vvau.py index be1ada59a1ef50cf8d1c9ea26b31ce4956d9f3db..95501078d6dcf82f7aa3f0eb887436e7640dfeae 100644 --- a/tests/fpgadataflow/test_fpgadataflow_vvau.py +++ b/tests/fpgadataflow/test_fpgadataflow_vvau.py @@ -43,9 +43,6 @@ import finn.core.onnx_exec as oxe from finn.analysis.fpgadataflow.exp_cycles_per_layer import exp_cycles_per_layer from finn.transformation.fpgadataflow.compile_cppsim import CompileCppSim from finn.transformation.fpgadataflow.hlssynth_ip import HLSSynthIP -from finn.transformation.fpgadataflow.minimize_accumulator_width import ( - MinimizeAccumulatorWidth, -) from 
finn.transformation.fpgadataflow.prepare_cppsim import PrepareCppSim from finn.transformation.fpgadataflow.prepare_ip import PrepareIP from finn.transformation.fpgadataflow.prepare_rtlsim import PrepareRTLSim @@ -156,8 +153,6 @@ def _make_single_vvau_modelwrapper( model.set_tensor_datatype("thresh", tdt) model.set_initializer("thresh", T) - # Minimize accumulator width to obtain realistic HLS reports - model = model.transform(MinimizeAccumulatorWidth()) model = model.transform(InferShapes()) model = model.transform(InferDataTypes()) diff --git a/tests/fpgadataflow/test_minimize_bit_width.py b/tests/fpgadataflow/test_minimize_bit_width.py new file mode 100644 index 0000000000000000000000000000000000000000..dc4a076a1808932b027f87d10b4d31a400ac1ad5 --- /dev/null +++ b/tests/fpgadataflow/test_minimize_bit_width.py @@ -0,0 +1,320 @@ +# Copyright (C) 2023, Advanced Micro Devices, Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of FINN nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import pytest + +import numpy as np +from onnx import TensorProto, helper +from qonnx.core.datatype import BipolarType, DataType, IntType +from qonnx.core.modelwrapper import ModelWrapper +from qonnx.custom_op.registry import getCustomOp +from qonnx.util.basic import gen_finn_dt_tensor, roundup_to_integer_multiple +from typing import Optional, Union + +from finn.custom_op.fpgadataflow.matrixvectoractivation import MatrixVectorActivation +from finn.custom_op.fpgadataflow.vectorvectoractivation import VectorVectorActivation +from finn.transformation.fpgadataflow.minimize_accumulator_width import ( + MinimizeAccumulatorWidth, +) +from finn.transformation.fpgadataflow.minimize_weight_bit_width import ( + MinimizeWeightBitWidth, +) + + +def make_unit_test_model(wdt: DataType, idt: DataType, tdt: Optional[DataType] = None): + """Creates a toy finn-onnx model for unit testing. 
The VVAU-MVAU pair is based + on the first pair of MobileNetV1""" + inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, [1, 32, 32, 288]) + outp = helper.make_tensor_value_info("outp", TensorProto.FLOAT, [1, 32, 32, 64]) + layer1 = helper.make_node( + "VectorVectorActivation", + ["inp", "params0", "thresh0"] if tdt is not None else ["inp", "params0"], + ["hid"], + domain="finn.custom_op.fpgadataflow", + backend="fpgadataflow", + PE=1, + Channels=32, + Dim=(32, 32), + Kernel=(3, 3), + inputDataType=idt.name, + outputDataType=idt.name, + weightDataType=wdt.name, + ActVal=tdt.min() if tdt is not None else 0, + noActivation=0 if tdt is not None else 1, + ) + layer2 = helper.make_node( + "MatrixVectorActivation", + ["hid", "params1", "thresh1"] if tdt is not None else ["hid", "params1"], + ["outp"], + domain="finn.custom_op.fpgadataflow", + backend="fpgadataflow", + MW=32, # matrix_width (num_inputs) + MH=64, # matrix_height (num_outputs) + SIMD=1, + PE=1, + inputDataType=idt.name, + outputDataType=idt.name, + weightDataType=wdt.name, + ActVal=tdt.min() if tdt is not None else 0, + noActivation=0 if tdt is not None else 1, + binaryXnorMode=0, + ) + graph = helper.make_graph( + nodes=[layer1, layer2], name="fclayer_graph", inputs=[inp], outputs=[outp] + ) + + model = helper.make_model(graph, producer_name="fclayer-model") + model = ModelWrapper(model) + + model.set_tensor_datatype("inp", idt) + model.set_tensor_datatype("outp", idt) + model.set_tensor_datatype("hid", idt) + model.set_tensor_datatype("params0", wdt) + model.set_tensor_datatype("params1", wdt) + model.set_initializer("params0", gen_finn_dt_tensor(wdt, (32, 1, 3, 3))) + model.set_initializer("params1", gen_finn_dt_tensor(wdt, (32, 64))) + # if the threshold data type is specified, then we need to generate + # some dummy threshold values + if tdt is not None: + model.set_tensor_datatype("thresh0", tdt) + model.set_tensor_datatype("thresh1", tdt) + # Create threshold tensors + n_steps: int = 
idt.get_num_possible_values() - 1 + thresholds: Optional[np.ndarray] = np.random.randint( + tdt.min(), tdt.max() - 1, (32, n_steps) + ).astype( + np.float32 + ) # generate thresholds for the activations + thresholds = np.sort(thresholds, axis=1) # provide non-decreasing thresholds + model.set_initializer("thresh0", thresholds) + thresholds: Optional[np.ndarray] = np.random.randint( + tdt.min(), tdt.max() - 1, (64, n_steps) + ).astype( + np.float32 + ) # generate thresholds for the activations + thresholds = np.sort(thresholds, axis=1) # provide non-decreasing thresholds + model.set_initializer("thresh1", thresholds) + return model + + +weight_data_types = [ + DataType["INT8"], + DataType["UINT8"], + DataType["INT7"], + DataType["UINT7"], + DataType["INT3"], + DataType["UINT3"], + # DataType["BIPOLAR"], # TODO - add support for bipolar weights + DataType["TERNARY"], +] + + +input_data_types = [ + DataType["INT8"], + DataType["UINT8"], + DataType["INT3"], + DataType["UINT3"], + DataType["BIPOLAR"], + DataType["TERNARY"], +] + + +@pytest.mark.parametrize("wdt", weight_data_types) +@pytest.mark.parametrize("rww", [True, False]) +@pytest.mark.fpgadataflow +def test_minimize_weight_bit_width(wdt: DataType, rww: bool): + """Testing MinimizeWeightBitWidth for VVAU and MVAU. 
+ + :param wdt: (DataType) The data type that we are testing for the weights + :param rww: (bool) Whether or not to use runtime-writeable weights""" + if isinstance(wdt, BipolarType): + # current MinimizeWeightBitWidth sets {-1,1} to INT2, need to check + # for 0 in weights to minimize weight bit width to bipolar + pytest.skip("Not well-supported for this optimization") + + # Create a w8a8 model + def_wdt = DataType["UINT8"] + model = make_unit_test_model(def_wdt, DataType["INT8"]) + + # Create new weights for the model based on wdt + params0 = gen_finn_dt_tensor(wdt, (32, 1, 3, 3)) + params1 = gen_finn_dt_tensor(wdt, (32, 64)) + model.set_initializer("params0", params0) + model.set_initializer("params1", params1) + + # If runtime-writeable weights, specify as a node attribute + for node in model.graph.node: + inst = getCustomOp(node) + if isinstance(inst, (MatrixVectorActivation, VectorVectorActivation)): + inst.set_nodeattr("runtime_writeable_weights", int(rww)) + + # Apply the optimization + model = model.transform(MinimizeWeightBitWidth()) + + # Iterate through each node to make sure it functioned properly + for node in model.graph.node: + inst = getCustomOp(node) + if isinstance(inst, (MatrixVectorActivation, VectorVectorActivation)): + cur_wdt = DataType[inst.get_nodeattr("weightDataType")] + exp_wdt = def_wdt if rww else wdt + assert cur_wdt.bitwidth() == exp_wdt.bitwidth(), "Mismatched data types" + + +def calculate_accumulator_bit_width( + inst: Union[MatrixVectorActivation, VectorVectorActivation], model: ModelWrapper +) -> Union[DataType, IntType]: + """Calculate the accumulator bit width using the closed-form expressions + derived in `Quantized Neural Networks for Low-Precision Accumulation + with Guaranteed Overflow Avoidance` (2023) by I.Colbert, A. Pappalardo, + J. 
Petri-Koenig + + :param inst: (HLSCustomOp) The instance of the MVAU or VVAU + :param model: (ModelWrapper) The instance of the whole model + """ + + def phi(x: float) -> float: + return np.log2(1 + pow(2, -x)) + + weights = model.get_initializer(inst.onnx_node.input[1]) + # since in the calculation the values of the weight matrix are used, + # for the bipolar case they need to be converted to bipolar + if inst.get_nodeattr("binaryXnorMode"): + weights = 2 * weights - 1 + # modify the weights based on if the node is a VVAU or MVAU + if isinstance(inst, MatrixVectorActivation): + K = inst.get_nodeattr("MW") # matrix_width = num_inputs + elif isinstance(inst, VectorVectorActivation): + k_h, k_w = inst.get_nodeattr("Kernel") + K = k_h * k_w # size of kernels = num_inputs + fm = inst.get_nodeattr("Channels") + # put weights into the shape expected by calculate_matvec_accumulator_range + weights = weights.reshape(fm, k_h * k_w).transpose() + else: + raise Exception("Considering only MVAU and VVAU currently") + # collect attributes used to determine the accumulator bit width bound + wdt = inst.get_weight_datatype() + idt = inst.get_input_datatype() + rww = inst.get_nodeattr("runtime_writeable_weights") + # if runtime-writeable weights, then use the lower bound on the accumulator bit + # width as determined by the input and weight data types and size of dot product + if rww: + alpha = np.log2(K) + idt.bitwidth() + wdt.bitwidth() - 1.0 - float(idt.signed()) + P = np.ceil(alpha + phi(alpha) + 1.0) + # if not runtime-writable weights, then use the tighter bound on the accumulator + # bit width as determined by the weight values themselves + else: + beta = ( + np.log2(abs(weights).sum(axis=0).max()) + + idt.bitwidth() + - float(idt.signed()) + ) + P = np.ceil(beta + phi(beta) + 1.0) + # if the node is the last in the graph, then round up to the nearest 8 bits + if model.find_direct_successors(inst.onnx_node) is None: + P = roundup_to_integer_multiple(P, 8) + return 
DataType[f"INT{int(P)}"] + + +thresh_data_types = [ + None, + DataType["INT32"], + DataType["INT24"], + DataType["INT16"], +] + +# Removing unsigned data types from weights +weight_data_types = [ + DataType["INT8"], + DataType["INT7"], + DataType["INT3"], + # DataType["BIPOLAR"], # TODO - add support for bipolar weights + DataType["TERNARY"], +] + + +@pytest.mark.parametrize("wdt", weight_data_types) +@pytest.mark.parametrize("idt", input_data_types) +@pytest.mark.parametrize("tdt", thresh_data_types) +@pytest.mark.parametrize("rww", [True, False]) +@pytest.mark.fpgadataflow +def test_minimize_accumulator_width( + wdt: DataType, idt: DataType, tdt: DataType, rww: bool +): + """Testing MinimizeAccumulatorWidth for VVAU and MVAU. + + :param wdt: (DataType) The data type that we are testing for the weights + :param idt: (DataType) The data type that we are testing for the activations + :param tdt: (DataType) The data type that we are testing for the thresholds + :param rww: (bool) Whether or not to use runtime-writeable weights""" + if (not wdt.signed()) or isinstance(wdt, BipolarType): + pytest.skip( + "Closed-form accumulator calculation is designed to consider signed weights" + ) + + # Create uniform-precision model + model = make_unit_test_model(wdt, idt, tdt) + def_adt = DataType["INT32"] + + # If runtime-writeable weights, specify as a node attribute + for node in model.graph.node: + inst = getCustomOp(node) + if isinstance(inst, (MatrixVectorActivation, VectorVectorActivation)): + inst.set_nodeattr("runtime_writeable_weights", int(rww)) + cur_adt = DataType[inst.get_nodeattr("accDataType")] + assert ( + cur_adt.bitwidth() == def_adt.bitwidth() + ), "Default data type is incorrect" + + # Apply the optimization + model = model.transform(MinimizeAccumulatorWidth()) + + # Iterate through each node to make sure it functioned properly + for node in model.graph.node: + inst = getCustomOp(node) + if isinstance(inst, (MatrixVectorActivation, VectorVectorActivation)): + 
cur_adt = DataType[inst.get_nodeattr("accDataType")] + cur_odt = DataType[inst.get_nodeattr("outputDataType")] + # Calculating expected accumulator bit width using a closed-form expression + # that is a slight over-approximation of the lower bound. The accumulator + # bit width minimization logic in the MVAU and VVAU is exact and should be + # less than or equal to this calculation + exp_adt = calculate_accumulator_bit_width(inst, model) + assert ( + cur_adt.bitwidth() <= exp_adt.bitwidth() + ), "Mismatched accumulation data types" + if model.find_direct_successors(inst.onnx_node) is None: + assert ( + cur_adt.bitwidth() % 8 + ) == 0, "bit width of last node needs to be divisible by 8" + assert ( + cur_adt.bitwidth() == cur_odt.bitwidth() + ), "outputDataType and accDataType should be equal" + else: + assert ( + cur_odt.bitwidth() == idt.bitwidth() + ), "outputDataType should not be changed"