Skip to content
Snippets Groups Projects
Commit fcfeb026 authored by icolbert's avatar icolbert
Browse files

Updating checks in minimize_accumulator_width

parent 72798e10
No related branches found
No related tags found
No related merge requests found
...@@ -27,12 +27,16 @@ ...@@ -27,12 +27,16 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import pytest import pytest
from typing import Optional import numpy as np
from typing import Optional, Union
from onnx import TensorProto, helper from onnx import TensorProto, helper
from qonnx.core.datatype import DataType, IntType from qonnx.core.datatype import DataType, IntType
from qonnx.core.modelwrapper import ModelWrapper from qonnx.core.modelwrapper import ModelWrapper
from qonnx.custom_op.registry import getCustomOp from qonnx.custom_op.registry import getCustomOp
from qonnx.util.basic import gen_finn_dt_tensor from qonnx.util.basic import (
gen_finn_dt_tensor,
roundup_to_integer_multiple
)
from finn.custom_op.fpgadataflow.vectorvectoractivation import VectorVectorActivation from finn.custom_op.fpgadataflow.vectorvectoractivation import VectorVectorActivation
from finn.custom_op.fpgadataflow.matrixvectoractivation import MatrixVectorActivation from finn.custom_op.fpgadataflow.matrixvectoractivation import MatrixVectorActivation
...@@ -109,7 +113,9 @@ weight_data_types = [ ...@@ -109,7 +113,9 @@ weight_data_types = [
DataType['UINT7'], DataType['UINT7'],
DataType['INT3'], DataType['INT3'],
DataType['UINT3'], DataType['UINT3'],
# DataType["BIPOLAR"], # TODO - investigate bipolar weights # TODO - current MinimizeWeightBitWidth sets {-1,1} to INT2, need to check
# for 0 in weights to minimize weight bit width to bipolar
# DataType["BIPOLAR"],
DataType["TERNARY"], DataType["TERNARY"],
] ]
...@@ -160,6 +166,57 @@ def test_minimize_weight_bit_width(wdt: DataType, rww: bool): ...@@ -160,6 +166,57 @@ def test_minimize_weight_bit_width(wdt: DataType, rww: bool):
assert cur_wdt.bitwidth() == exp_wdt.bitwidth(), "Mismatched data types" assert cur_wdt.bitwidth() == exp_wdt.bitwidth(), "Mismatched data types"
def calculate_accumulator_bit_width(
inst: Union[MatrixVectorActivation, VectorVectorActivation],
model: ModelWrapper
) -> Union[DataType, IntType]:
"""Calculate the accumulator bit width use the closed-form expressions
derived in `Quantized Neural Networks for Low-Precision Accumulation
with Guaranteed Overflow Avoidance` (2023) by I.Colbert, A. Pappalardo,
J. Petri-Koenig
:param inst: (HLSCustomOp) The instance of the MVAU or VVAU
:param model: (ModelWrapper) The instance of the whole model
"""
def phi(x: float) -> float:
return np.log2(1 + pow(2, -x))
weights = model.get_initializer(inst.onnx_node.input[1])
# since in the calculation the values of the weight matrix are used,
# for the bipolar case they need to be converted to bipolar
if inst.get_nodeattr("binaryXnorMode"):
weights = 2 * weights - 1
# modify the weights based on if the node is a VVAU or MVAU
if isinstance(inst, MatrixVectorActivation):
K = inst.get_nodeattr("MW") # matrix_width = num_inputs
elif isinstance(inst, VectorVectorActivation):
k_h, k_w = inst.get_nodeattr("Kernel")
K = k_h * k_w # size of kernels = num_inputs
fm = inst.get_nodeattr("Channels")
# put weights into the shape expected by calculate_matvec_accumulator_range
weights = weights.reshape(fm, k_h * k_w).transpose()
else:
raise Exception("Considering only MVAU and VVAU currently")
# collect attributes used to determine the accumulator bit width bound
wdt = inst.get_weight_datatype()
idt = inst.get_input_datatype()
rww = inst.get_nodeattr("runtime_writeable_weights")
# if runtime-writeable weights, then use the lower bound on the accumulator bit
# width as determined by the input and weight data types and size of dot product
if rww:
alpha = np.log2(K) + idt.bitwidth() + wdt.bitwidth() - 1. - float(idt.signed())
P = np.ceil(alpha + phi(alpha) + 1.)
# if not runtime-writable weights, then use the tighter bound on the accumulator
# bit width as determined by the weight values themselves
else:
beta = np.log2(abs(weights).sum(axis=0).max()) + idt.bitwidth() - float(idt.signed())
P = np.ceil(beta + phi(beta) + 1.)
# if the node is the last in the graph, then round up to the nearest 8 bits
if model.find_direct_successors(inst.onnx_node) is None:
P = roundup_to_integer_multiple(P, 8)
return DataType[f"INT{int(P)}"]
@pytest.mark.parametrize("wdt", weight_data_types) @pytest.mark.parametrize("wdt", weight_data_types)
@pytest.mark.parametrize("idt", input_data_types) @pytest.mark.parametrize("idt", input_data_types)
@pytest.mark.parametrize("rww", [True, False]) @pytest.mark.parametrize("rww", [True, False])
...@@ -169,6 +226,8 @@ def test_minimize_accumulator_width(wdt: DataType, idt:DataType, rww: bool): ...@@ -169,6 +226,8 @@ def test_minimize_accumulator_width(wdt: DataType, idt:DataType, rww: bool):
:param wdt: (DataType) The data type that we are testing for the weights :param wdt: (DataType) The data type that we are testing for the weights
:param idt: (DataType) The data type that we are testing for the activations :param idt: (DataType) The data type that we are testing for the activations
:param rww: (bool) Whether or not to use runtime-writeable weights""" :param rww: (bool) Whether or not to use runtime-writeable weights"""
if not wdt.signed():
pytest.skip("Closed-form accumulator calculation is designed to consider only signed weights")
# Create uniform-precision model # Create uniform-precision model
# TODO: add thresholds (tdt) to unit tests # TODO: add thresholds (tdt) to unit tests
...@@ -192,9 +251,12 @@ def test_minimize_accumulator_width(wdt: DataType, idt:DataType, rww: bool): ...@@ -192,9 +251,12 @@ def test_minimize_accumulator_width(wdt: DataType, idt:DataType, rww: bool):
if isinstance(inst, (MatrixVectorActivation, VectorVectorActivation)): if isinstance(inst, (MatrixVectorActivation, VectorVectorActivation)):
cur_adt = DataType[inst.get_nodeattr("accDataType")] cur_adt = DataType[inst.get_nodeattr("accDataType")]
cur_odt = DataType[inst.get_nodeattr("outputDataType")] cur_odt = DataType[inst.get_nodeattr("outputDataType")]
# TODO - figure out how to calculate expected accDataType # Calculating expected accumulator bit width using a closed-form expression
# exp_wdt = def_wdt if rww else wdt # that is a slight over-approximation of the lower bound. The accumulator
# assert cur_adt.bitwidth() == exp_adt.bitwidth(), "Mismatched data types" # bit width minimization logic in the MVAU and VVAU is exact and should be
# less than or equal to this calculation
exp_adt = calculate_accumulator_bit_width(inst, model)
assert cur_adt.bitwidth() <= exp_adt.bitwidth(), "Mismatched accumulation data types"
if model.find_direct_successors(inst.onnx_node) is None: if model.find_direct_successors(inst.onnx_node) is None:
assert (cur_adt.bitwidth() % 8) == 0, "bit width of last node needs to be divisible by 8" assert (cur_adt.bitwidth() % 8) == 0, "bit width of last node needs to be divisible by 8"
assert cur_adt.bitwidth() == cur_odt.bitwidth(), "outputDataType and accDataType should be equal" assert cur_adt.bitwidth() == cur_odt.bitwidth(), "outputDataType and accDataType should be equal"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment