From bb5a4a425d860740535affc043d4191751fa6eea Mon Sep 17 00:00:00 2001
From: Lucian Petrica <lucianp@xilinx.com>
Date: Mon, 27 Jul 2020 08:57:19 +0000
Subject: [PATCH] Bitwidth optimizations to accumulators and thresholds,
 through clipping of too-large threshold values and explicit calculation of
 the bitwidth

---
 .../fpgadataflow/streamingfclayer_batch.py    | 37 +++++++++++++++++--
 .../fpgadataflow/thresholding_batch.py        | 17 ++++++++-
 .../streamline/round_thresholds.py            | 12 +++++-
 3 files changed, 61 insertions(+), 5 deletions(-)

diff --git a/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py b/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py
index 9c3bd3ac8..686975c8c 100644
--- a/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py
+++ b/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py
@@ -594,7 +594,6 @@ class StreamingFCLayer_Batch(HLSCustomOp):
             thresholds = model.get_initializer(self.onnx_node.input[2])
             if thresholds is not None:
                 threshold_tensor = self.get_hls_compatible_threshold_tensor(thresholds)
-                tdt = DataType.INT32
                 # use UINT32 threshold export for bipolar times bipolar
                 inp_is_bipolar = self.get_input_datatype() == DataType.BIPOLAR
                 wt_is_bipolar = self.get_weight_datatype() == DataType.BIPOLAR
@@ -604,8 +603,40 @@ class StreamingFCLayer_Batch(HLSCustomOp):
                 bin_xnor_mode = self.get_nodeattr("binaryXnorMode") == 1
                 inp_is_bipolar = inp_is_bipolar or (inp_is_binary and bin_xnor_mode)
                 wt_is_bipolar = wt_is_bipolar or (wt_is_binary and bin_xnor_mode)
-                if inp_is_bipolar and wt_is_bipolar:
-                    tdt = DataType.UINT32
+                # set threshold datatype (and accumulator datatype implicitly)
+                min_threshold = thresholds.min()
+                max_threshold = thresholds.max()
+                min_weight = weights.min()
+                max_weight = weights.max()
+                perceptive_field_elems = self.get_nodeattr("MW")
+                min_input = self.get_input_datatype().min()
+                max_input = self.get_input_datatype().max()
+                # calculate minimum and maximum values of accumulator
+                # assume inputs span the whole range of the input datatype
+                acc_min = perceptive_field_elems * min(
+                    min_weight * max_input,
+                    min_weight * min_input,
+                    max_weight * max_input,
+                    max_weight * min_input,
+                )
+                acc_max = perceptive_field_elems * max(
+                    min_weight * max_input,
+                    min_weight * min_input,
+                    max_weight * max_input,
+                    max_weight * min_input,
+                )
+
+                # get range required by threshold values
+                tdt_min = min(acc_min, min_threshold)
+                tdt_max = max(acc_max, max_threshold)
+                if tdt_min < 0:
+                    if abs(tdt_min) > tdt_max:
+                        tdt = DataType.get_smallest_possible(tdt_min)
+                    else:
+                        tdt = DataType.get_smallest_possible(0 - tdt_max)
+                else:
+                    tdt = DataType.get_smallest_possible(tdt_max)
+
                 thresholds_hls_code = numpy_to_hls_code(
                     threshold_tensor, tdt, "thresholds", False, True
                 )
diff --git a/src/finn/custom_op/fpgadataflow/thresholding_batch.py b/src/finn/custom_op/fpgadataflow/thresholding_batch.py
index fa33c7021..cf19a2534 100644
--- a/src/finn/custom_op/fpgadataflow/thresholding_batch.py
+++ b/src/finn/custom_op/fpgadataflow/thresholding_batch.py
@@ -279,7 +279,22 @@ class Thresholding_Batch(HLSCustomOp):
         thresholds = model.get_initializer(self.onnx_node.input[1])
 
         threshold_tensor = self.get_hls_compatible_threshold_tensor(thresholds)
-        tdt = DataType.INT32
+
+        min_threshold = thresholds.min()
+        max_threshold = thresholds.max()
+        min_input = self.get_input_datatype().min()
+        max_input = self.get_input_datatype().max()
+        # get range required by threshold values
+        tdt_min = min(min_input, min_threshold)
+        tdt_max = max(max_input, max_threshold)
+        if tdt_min < 0:
+            if abs(tdt_min) > tdt_max:
+                tdt = DataType.get_smallest_possible(tdt_min)
+            else:
+                tdt = DataType.get_smallest_possible(0 - tdt_max - 1)
+        else:
+            tdt = DataType.get_smallest_possible(tdt_max)
+
         thresholds_hls_code = numpy_to_hls_code(
             threshold_tensor, tdt, "thresholds", False, True
         )
diff --git a/src/finn/transformation/streamline/round_thresholds.py b/src/finn/transformation/streamline/round_thresholds.py
index c33281d85..8626ef406 100644
--- a/src/finn/transformation/streamline/round_thresholds.py
+++ b/src/finn/transformation/streamline/round_thresholds.py
@@ -51,10 +51,20 @@ class RoundAndClipThresholds(Transformation):
                     model.set_tensor_datatype(n.input[1], idtype)
                     graph_modified = True
                 if idtype.is_integer() and not idtype.signed() and (Tnew < 0).any():
-                    # clip any negative thresholds
+                    # clip any negative thresholds if input is unsigned
                     Tnew = np.clip(Tnew, 0, None)
                     model.set_initializer(n.input[1], Tnew)
                     # use same datatype as inputs for thresholds
                     model.set_tensor_datatype(n.input[1], idtype)
                     graph_modified = True
+                if idtype.is_integer() and (
+                    (Tnew < (idtype.min() - 1)).any()
+                    or (Tnew > (idtype.max() + 1)).any()
+                ):
+                    # clip any large thresholds to input range + 1
+                    Tnew = np.clip(Tnew, idtype.min() - 1, idtype.max() + 1)
+                    model.set_initializer(n.input[1], Tnew)
+                    # use same datatype as inputs for thresholds
+                    model.set_tensor_datatype(n.input[1], idtype)
+                    graph_modified = True
         return (model, graph_modified)
-- 
GitLab