From 96179ebe99f0444a6df055bbeb3bd6e7180df939 Mon Sep 17 00:00:00 2001
From: Tobi-Alonso <tobi.alonso@gmail.com>
Date: Fri, 4 Sep 2020 18:04:55 +0100
Subject: [PATCH] [fpgadataflow] Fix tensor element bitwidth mismatch for
 labelselect_batch and channelwise_op_batch

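ChannelwiseOp_Batch.infer_node_datatype previously only asserted that the
stored inputDataType matched the model and never updated outputDataType,
so downstream tensors could end up with a mismatched element bitwidth.
For the "add" and "mul" functions the node now refreshes inputDataType
from the model and recomputes outputDataType from the parameter range via
a new get_smallest_possible() helper. LabelSelect_Batch likewise syncs
its inputDataType attribute during datatype inference.

Minimal usage sketch of the new helper (illustrative only; the import
path matches this patch, and the exact DataType returned depends on the
members of the DataType enum):

    from finn.custom_op.fpgadataflow.channelwise_op_batch import (
        get_smallest_possible,
    )

    # non-negative range: the smallest fitting unsigned type is preferred
    print(get_smallest_possible([0, 15]).name)
    # a negative value forces a signed type
    print(get_smallest_possible([-4, 3]).name)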
---
 .../fpgadataflow/channelwise_op_batch.py      | 65 ++++++++++++++++++-
 .../fpgadataflow/labelselect_batch.py         |  5 ++
 2 files changed, 67 insertions(+), 3 deletions(-)

diff --git a/src/finn/custom_op/fpgadataflow/channelwise_op_batch.py b/src/finn/custom_op/fpgadataflow/channelwise_op_batch.py
index d8e74a4d1..88b55aaec 100644
--- a/src/finn/custom_op/fpgadataflow/channelwise_op_batch.py
+++ b/src/finn/custom_op/fpgadataflow/channelwise_op_batch.py
@@ -41,6 +41,8 @@ from finn.util.data_packing import (
 )
 from . import templates
 
+import warnings
+
 # ONNX i/o tensor shape assumptions for channelwise ops:
 # input 0 is the input tensor, shape (..., NumChannels)
 # input 1 is the channelwise parameter tensor, shape (NumChannels, params_per_channel)
@@ -48,6 +50,37 @@ from . import templates
 # the ... here can be any shape (representing groups of vectors)
 
 
+def get_smallest_possible(vals):
+    """Returns smallest (fewest bits) possible DataType that can represent
+    value. Prefers unsigned integers where possible."""
+    vals = np.array(vals)
+    for v in vals:
+        assert int(v) == v, "get_smallest_possible expects integer values"
+
+    for k in DataType.__members__:
+        dt = DataType[k]
+
+        if dt in [DataType.BIPOLAR, DataType.TERNARY, DataType.FLOAT32]:
+            # not currently supported
+            continue
+
+        if (dt.min() <= vals).all() and (vals <= dt.max()).all():
+            return dt
+
+    warnings.warn(
+        """InferChannelwiseLinearLayer: Output values may not be
+    representable with supported data types.
+    Setting maximum width data type available.
+    This will lead to errors if there are no constrains on the input
+    """
+    )
+
+    if (0 <= vals).all():
+        return DataType.UINT64
+    else:
+        return DataType.INT64
+
+
 class ChannelwiseOp_Batch(HLSCustomOp):
     """Class that corresponds to finn-hls Thresholding_Batch function.
     It can implement a variety of channel-wise parametrized operations,
@@ -109,10 +142,36 @@ class ChannelwiseOp_Batch(HLSCustomOp):
     def infer_node_datatype(self, model):
         node = self.onnx_node
         # check input datatype against property
-        idt_name = self.get_input_datatype().name
+        idt = model.get_tensor_datatype(node.input[0])
+
         exp_idt_name = self.get_nodeattr("inputDataType")
-        assert exp_idt_name == idt_name, "Bad input DataType for ChannelwiseOp layer"
-        # TODO: dynamically infer/update odt based on idt as done in ConvertToHLSLayers?
+        if exp_idt_name != idt.name:
+            func = self.get_nodeattr("Func")
+            assert func in ["add", "mul"], "Bad input DataType for ChannelwiseOp layer"
+
+            self.set_nodeattr("inputDataType", idt.name)
+            # for 'add' and 'mul', re-infer the output datatype from the parameters
+
+            # get parameter ranges
+            param = model.get_initializer(node.input[1])
+            param_min = min(param.flatten())
+            param_max = max(param.flatten())
+
+            # set function and determine output data type
+            if func == "add":
+                out_min = idt.min() + param_min
+                out_max = idt.max() + param_max
+                odt = get_smallest_possible([out_min, out_max])
+            elif func == "mul":
+                possible_limits = []
+                possible_limits += [idt.min() * param_min]
+                possible_limits += [idt.min() * param_max]
+                possible_limits += [idt.max() * param_min]
+                possible_limits += [idt.max() * param_max]
+                odt = get_smallest_possible(possible_limits)
+
+            self.set_nodeattr("outputDataType", odt.name)
+
         # set output datatype from property
         odt = self.get_output_datatype()
         model.set_tensor_datatype(node.output[0], odt)
diff --git a/src/finn/custom_op/fpgadataflow/labelselect_batch.py b/src/finn/custom_op/fpgadataflow/labelselect_batch.py
index f61fbf12d..c6598a30e 100644
--- a/src/finn/custom_op/fpgadataflow/labelselect_batch.py
+++ b/src/finn/custom_op/fpgadataflow/labelselect_batch.py
@@ -113,6 +113,11 @@ class LabelSelect_Batch(HLSCustomOp):
         )
 
     def infer_node_datatype(self, model):
+        node = self.onnx_node
+        # update the inputDataType property from the model
+        idt = model.get_tensor_datatype(node.input[0])
+        self.set_nodeattr("inputDataType", idt.name)
+
         odt = self.get_output_datatype()
         model.set_tensor_datatype(self.onnx_node.output[0], odt)
 
-- 
GitLab