From 9a99b31141f372417914ed47b73d1ede1b287a02 Mon Sep 17 00:00:00 2001
From: icolbert <Ian.Colbert@amd.com>
Date: Tue, 28 Feb 2023 18:06:52 -0800
Subject: [PATCH] Adding threshold data types for accumulator width unit test

---
 .../streamline/test_minimize_bit_width.py     | 37 +++++++++++++++----
 1 file changed, 29 insertions(+), 8 deletions(-)

diff --git a/tests/transformation/streamline/test_minimize_bit_width.py b/tests/transformation/streamline/test_minimize_bit_width.py
index 221be75da..7cb866c6e 100644
--- a/tests/transformation/streamline/test_minimize_bit_width.py
+++ b/tests/transformation/streamline/test_minimize_bit_width.py
@@ -51,7 +51,7 @@ def make_unit_test_model(wdt: DataType, idt: DataType, tdt: Optional[DataType] =
     outp = helper.make_tensor_value_info("outp", TensorProto.FLOAT, [1, 32, 32, 64])
     layer1 = helper.make_node(
         "VectorVectorActivation",
-        ["inp", "params0", "thresh"] if tdt is not None else ["inp", "params0"],
+        ["inp", "params0", "thresh0"] if tdt is not None else ["inp", "params0"],
         ["hid"],
         domain="finn.custom_op.fpgadataflow",
         backend="fpgadataflow",
@@ -67,7 +67,7 @@ def make_unit_test_model(wdt: DataType, idt: DataType, tdt: Optional[DataType] =
     )
     layer2 = helper.make_node(
         "MatrixVectorActivation",
-        ["hid", "params1", "thresh"] if tdt is not None else ["hid", "params1"],
+        ["hid", "params1", "thresh1"] if tdt is not None else ["hid", "params1"],
         ["outp"],
         domain="finn.custom_op.fpgadataflow",
         backend="fpgadataflow",
@@ -100,9 +100,21 @@ def make_unit_test_model(wdt: DataType, idt: DataType, tdt: Optional[DataType] =
     model.set_initializer("params1",
                           gen_finn_dt_tensor(wdt, (32, 64))
     )
+    # if the threshold data type is specified, then we need to generate
+    # some dummy threshold values
     if tdt is not None:
-        model.set_tensor_datatype("thresh", tdt)
-        # model.set_initializer("thresh", thresholds)
+        model.set_tensor_datatype("thresh0", tdt)
+        model.set_tensor_datatype("thresh1", tdt)
+        # Create threshold tensors
+        n_steps: int = idt.get_num_possible_values() - 1
+        thresholds: Optional[np.ndarray] = np.random.randint(tdt.min(), tdt.max() - 1, \
+            (32, n_steps)).astype(np.float32) # generate thresholds for the activations
+        thresholds = np.sort(thresholds, axis=1) # provide non-decreasing thresholds
+        model.set_initializer("thresh0", thresholds)
+        thresholds: Optional[np.ndarray] = np.random.randint(tdt.min(), tdt.max() - 1, \
+            (64, n_steps)).astype(np.float32) # generate thresholds for the activations
+        thresholds = np.sort(thresholds, axis=1) # provide non-decreasing thresholds
+        model.set_initializer("thresh1", thresholds)
     return model
 
 
@@ -170,7 +182,7 @@ def calculate_accumulator_bit_width(
         inst: Union[MatrixVectorActivation, VectorVectorActivation],
         model: ModelWrapper
     ) -> Union[DataType, IntType]:
-    """Calculate the accumulator bit width use the closed-form expressions
+    """Calculate the accumulator bit width using the closed-form expressions
     derived in `Quantized Neural Networks for Low-Precision Accumulation
     with Guaranteed Overflow Avoidance` (2023) by I.Colbert, A. Pappalardo,
     J. Petri-Koenig
@@ -217,21 +229,30 @@ def calculate_accumulator_bit_width(
     return DataType[f"INT{int(P)}"]
 
 
+thresh_data_types = [
+    None,
+    DataType['INT32'],
+    DataType['INT24'],
+    DataType['INT16'],
+]
+
+
 @pytest.mark.parametrize("wdt", weight_data_types)
 @pytest.mark.parametrize("idt", input_data_types)
+@pytest.mark.parametrize("tdt", thresh_data_types)
 @pytest.mark.parametrize("rww", [True, False])
-def test_minimize_accumulator_width(wdt: DataType, idt:DataType, rww: bool):
+def test_minimize_accumulator_width(wdt: DataType, idt: DataType, tdt: DataType, rww: bool):
     """Testing MinimizeAccumulatorWidth for VVAU and MVAU.
     
     :param wdt: (DataType) The data type that we are testing for the weights
     :param idt: (DataType) The data type that we are testing for the activations
+    :param tdt: (DataType) The data type that we are testing for the thresholds
     :param rww: (bool) Whether or not to use runtime-writeable weights"""
     if not wdt.signed():
         pytest.skip("Closed-form accumulator calculation is designed to consider only signed weights")
 
     # Create uniform-precision model
-    # TODO: add thresholds (tdt) to unit tests
-    model = make_unit_test_model(wdt, idt)
+    model = make_unit_test_model(wdt, idt, tdt)
     def_adt = DataType["INT32"]
 
     # If runtime-writeable weights, specify as a node attribute
-- 
GitLab