Skip to content
Snippets Groups Projects
Commit 5fc80734 authored by icolbert's avatar icolbert
Browse files

Adding test for MinimizeAccumulatorWidth

parent 68a6f3e1
No related branches found
No related tags found
No related merge requests found
...@@ -37,6 +37,7 @@ from qonnx.util.basic import gen_finn_dt_tensor ...@@ -37,6 +37,7 @@ from qonnx.util.basic import gen_finn_dt_tensor
from finn.custom_op.fpgadataflow.vectorvectoractivation import VectorVectorActivation from finn.custom_op.fpgadataflow.vectorvectoractivation import VectorVectorActivation
from finn.custom_op.fpgadataflow.matrixvectoractivation import MatrixVectorActivation from finn.custom_op.fpgadataflow.matrixvectoractivation import MatrixVectorActivation
from finn.transformation.fpgadataflow.minimize_weight_bit_width import MinimizeWeightBitWidth from finn.transformation.fpgadataflow.minimize_weight_bit_width import MinimizeWeightBitWidth
from finn.transformation.fpgadataflow.minimize_accumulator_width import MinimizeAccumulatorWidth
def make_unit_test_model(wdt: DataType, idt: DataType, tdt: Optional[DataType] = None): def make_unit_test_model(wdt: DataType, idt: DataType, tdt: Optional[DataType] = None):
...@@ -112,6 +113,17 @@ weight_data_types = [ ...@@ -112,6 +113,17 @@ weight_data_types = [
DataType["TERNARY"], DataType["TERNARY"],
] ]
input_data_types = [
DataType['INT8'],
DataType['UINT8'],
DataType['INT3'],
DataType['UINT3'],
DataType["BIPOLAR"],
DataType["TERNARY"],
]
@pytest.mark.parametrize("wdt", weight_data_types) @pytest.mark.parametrize("wdt", weight_data_types)
@pytest.mark.parametrize("rww", [True, False]) @pytest.mark.parametrize("rww", [True, False])
def test_minimize_weight_bit_width(wdt: DataType, rww: bool): def test_minimize_weight_bit_width(wdt: DataType, rww: bool):
...@@ -146,3 +158,45 @@ def test_minimize_weight_bit_width(wdt: DataType, rww: bool): ...@@ -146,3 +158,45 @@ def test_minimize_weight_bit_width(wdt: DataType, rww: bool):
cur_wdt = DataType[inst.get_nodeattr("weightDataType")] cur_wdt = DataType[inst.get_nodeattr("weightDataType")]
exp_wdt = def_wdt if rww else wdt exp_wdt = def_wdt if rww else wdt
assert cur_wdt.bitwidth() == exp_wdt.bitwidth(), "Mismatched data types" assert cur_wdt.bitwidth() == exp_wdt.bitwidth(), "Mismatched data types"
@pytest.mark.parametrize("wdt", weight_data_types)
@pytest.mark.parametrize("adt", input_data_types)
@pytest.mark.parametrize("rww", [True, False])
def test_minimize_weight_bit_width(wdt: DataType, idt:DataType, rww: bool):
"""Testing MinimizeAccumulatorWidth for VVAU and MVAU.
:param wdt: (DataType) The data type that we are testing for the weights
:param idt: (DataType) The data type that we are testing for the activations
:param rww: (bool) Whether or not to use runtime-writeable weights"""
# Create uniform-precision model
# TODO: add thresholds (tdt) to unit tests
model = make_unit_test_model(wdt, idt)
def_adt = DataType["INT32"]
# If runtime-writeable weights, specify as a node attribute
for node in model.graph.node:
inst = getCustomOp(node)
if isinstance(inst, (MatrixVectorActivation, VectorVectorActivation)):
inst.set_nodeattr("runtime_writeable_weights", int(rww))
cur_adt = DataType[inst.get_nodeattr("accDataType")]
assert cur_adt.bitwidth() == def_adt.bitwidth(), "Default data type is incorrect"
# Apply the optimization
model = model.transform(MinimizeAccumulatorWidth())
# Iterate through each node to make sure it functioned properly
for node in model.graph.node:
inst = getCustomOp(node)
if isinstance(inst, (MatrixVectorActivation, VectorVectorActivation)):
cur_adt = DataType[inst.get_nodeattr("accDataType")]
cur_odt = DataType[inst.get_nodeattr("accDataType")]
# TODO - figure out how to calculate expected accDataType
# exp_wdt = def_wdt if rww else wdt
# assert cur_adt.bitwidth() == exp_adt.bitwidth(), "Mismatched data types"
if model.find_direct_successors(inst.onnx_node) is None:
assert (cur_adt.bitwidth() % 8) == 0, "bit width of last node needs to be divisible by 8"
assert cur_adt.bitwidth() == cur_odt.bitwidth(), "outputDataType and accDataType should be equal"
else:
assert cur_adt.bitwidth() == idt.bitwidth(), "outputDataType should not be changed"
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment