Commit 9a99b311 authored by icolbert

Adding threshold data types for accumulator width unit test

parent fcfeb026
@@ -51,7 +51,7 @@ def make_unit_test_model(wdt: DataType, idt: DataType, tdt: Optional[DataType] =
    outp = helper.make_tensor_value_info("outp", TensorProto.FLOAT, [1, 32, 32, 64])
    layer1 = helper.make_node(
        "VectorVectorActivation",
        ["inp", "params0", "thresh"] if tdt is not None else ["inp", "params0"],
        ["inp", "params0", "thresh0"] if tdt is not None else ["inp", "params0"],
        ["hid"],
        domain="finn.custom_op.fpgadataflow",
        backend="fpgadataflow",
@@ -67,7 +67,7 @@ def make_unit_test_model(wdt: DataType, idt: DataType, tdt: Optional[DataType] =
    )
    layer2 = helper.make_node(
        "MatrixVectorActivation",
        ["hid", "params1", "thresh"] if tdt is not None else ["hid", "params1"],
        ["hid", "params1", "thresh1"] if tdt is not None else ["hid", "params1"],
        ["outp"],
        domain="finn.custom_op.fpgadataflow",
        backend="fpgadataflow",
@@ -100,9 +100,21 @@ def make_unit_test_model(wdt: DataType, idt: DataType, tdt: Optional[DataType] =
    model.set_initializer("params1",
        gen_finn_dt_tensor(wdt, (32, 64))
    )
    # if the threshold data type is specified, then we need to generate
    # some dummy threshold values
    if tdt is not None:
        model.set_tensor_datatype("thresh", tdt)
        # model.set_initializer("thresh", thresholds)
        model.set_tensor_datatype("thresh0", tdt)
        model.set_tensor_datatype("thresh1", tdt)
        # Create threshold tensors
        n_steps: int = idt.get_num_possible_values() - 1
        thresholds: Optional[np.ndarray] = np.random.randint(tdt.min(), tdt.max() - 1, \
            (32, n_steps)).astype(np.float32)  # generate thresholds for the activations
        thresholds = np.sort(thresholds, axis=1)  # provide non-decreasing thresholds
        model.set_initializer("thresh0", thresholds)
        thresholds: Optional[np.ndarray] = np.random.randint(tdt.min(), tdt.max() - 1, \
            (64, n_steps)).astype(np.float32)  # generate thresholds for the activations
        thresholds = np.sort(thresholds, axis=1)  # provide non-decreasing thresholds
        model.set_initializer("thresh1", thresholds)
    return model
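
As a quick sanity check on the threshold tensors generated above, here is a standalone NumPy sketch, not code from this diff; the 4-bit input type, 32 output channels, and the [-128, 126) value range are assumptions for illustration:

import numpy as np

n_steps = 16 - 1  # a 4-bit input type has 16 possible values, hence 15 threshold steps
thresholds = np.random.randint(-128, 126, (32, n_steps)).astype(np.float32)  # one row per output channel
thresholds = np.sort(thresholds, axis=1)  # sort each row so the steps are non-decreasing
assert thresholds.shape == (32, 15)
assert (np.diff(thresholds, axis=1) >= 0).all()  # every row is non-decreasing
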
@@ -170,7 +182,7 @@ def calculate_accumulator_bit_width(
    inst: Union[MatrixVectorActivation, VectorVectorActivation],
    model: ModelWrapper
) -> Union[DataType, IntType]:
    """Calculate the accumulator bit width use the closed-form expressions
    """Calculate the accumulator bit width using the closed-form expressions
    derived in `Quantized Neural Networks for Low-Precision Accumulation
    with Guaranteed Overflow Avoidance` (2023) by I.Colbert, A. Pappalardo,
    J. Petri-Koenig
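
For readers without the paper at hand, the bound this helper refers to can be sketched for a signed P-bit accumulator as follows; the alpha/phi formulation below follows the paper's derivation only in broad strokes, so treat it as an illustration under those assumptions rather than the exact expression used later in this file:

import math
import numpy as np

def phi(alpha: float) -> float:
    # correction term log2(1 + 2^(-alpha)), so that alpha + phi(alpha) = log2(2^alpha + 1)
    return math.log2(1.0 + pow(2.0, -alpha))

def accumulator_bits(weights: np.ndarray, input_bits: int, input_signed: bool) -> int:
    # worst-case |dot(w, x)| is bounded by ||w||_1 times the largest input magnitude
    w_l1 = float(np.abs(weights).sum())  # assumes at least one nonzero weight
    max_in = 2.0 ** (input_bits - 1) if input_signed else 2.0 ** input_bits - 1.0
    alpha = math.log2(w_l1) + math.log2(max_in)
    # a signed P-bit accumulator holds magnitudes up to 2^(P-1) - 1, so require
    # 2^(P-1) >= ||w||_1 * max_in + 1, i.e. P >= alpha + phi(alpha) + 1
    return int(math.ceil(alpha + phi(alpha) + 1.0))

For example, with ||w||_1 = 100 and signed 4-bit inputs (largest magnitude 8), this sketch returns ceil(log2(801) + 1) = 11 bits.
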
@@ -217,21 +229,30 @@ def calculate_accumulator_bit_width(
    return DataType[f"INT{int(P)}"]

thresh_data_types = [
    None,
    DataType['INT32'],
    DataType['INT24'],
    DataType['INT16'],
]

@pytest.mark.parametrize("wdt", weight_data_types)
@pytest.mark.parametrize("idt", input_data_types)
@pytest.mark.parametrize("tdt", thresh_data_types)
@pytest.mark.parametrize("rww", [True, False])
def test_minimize_accumulator_width(wdt: DataType, idt:DataType, rww: bool):
def test_minimize_accumulator_width(wdt: DataType, idt: DataType, tdt: DataType, rww: bool):
    """Testing MinimizeAccumulatorWidth for VVAU and MVAU.
    :param wdt: (DataType) The data type that we are testing for the weights
    :param idt: (DataType) The data type that we are testing for the activations
    :param tdt: (DataType) The data type that we are testing for the thresholds
    :param rww: (bool) Whether or not to use runtime-writeable weights"""
    if not wdt.signed():
        pytest.skip("Closed-form accumulator calculation is designed to consider only signed weights")
    # Create uniform-precision model
    # TODO: add thresholds (tdt) to unit tests
    model = make_unit_test_model(wdt, idt)
    model = make_unit_test_model(wdt, idt, tdt)
    def_adt = DataType["INT32"]
    # If runtime-writeable weights, specify as a node attribute
......
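
As background for the final comment above, runtime-writeable weights are typically toggled through a node attribute on the fpgadataflow custom op; a rough sketch of that pattern (the attribute name and wrapper helper are assumptions, not shown in this diff):

from qonnx.custom_op.registry import getCustomOp

for node in model.graph.node:
    inst = getCustomOp(node)
    # assumed attribute name; enables runtime-writeable weights when rww is True
    inst.set_nodeattr("runtime_writeable_weights", 1 if rww else 0)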