From f9f6f60fad59d061f9075c7ae2de4c757b121522 Mon Sep 17 00:00:00 2001
From: Hendrik Borras <hendrikborras@web.de>
Date: Fri, 8 Oct 2021 15:14:15 +0100
Subject: [PATCH] Small refactoring for QONNX activation handlers.

---
 .../qonnx/qonnx_activation_handlers.py        | 20 ++++++++++---------
 1 file changed, 11 insertions(+), 9 deletions(-)

diff --git a/src/finn/transformation/qonnx/qonnx_activation_handlers.py b/src/finn/transformation/qonnx/qonnx_activation_handlers.py
index 1e1ad7184..02c156e14 100644
--- a/src/finn/transformation/qonnx/qonnx_activation_handlers.py
+++ b/src/finn/transformation/qonnx/qonnx_activation_handlers.py
@@ -96,7 +96,7 @@ class QuantActBaseHandler(ABC):
     def _extract_output_datatype(self):
         """Get the output datatype for the MultiThreshold node."""
         dtype = self._model.get_tensor_datatype(self._q_node.output[0]).name
-        if "SCALED" in dtype:
+        if dtype is not None:
             dtype = dtype.replace("SCALED", "")
         return dtype
 
@@ -104,9 +104,8 @@ class QuantActBaseHandler(ABC):
         """Calculate all parameters required for replacing the QONNX style activation
         with a FINN style one.
         """
-        out_dtype = self._extract_output_datatype()
         return {
-            "out_dtype": out_dtype,
+            "out_dtype": self._extract_output_datatype(),
             "thresholds": self._calculate_thresholds(),
             "adder_bias": self._calculate_act_bias(),
             "mul_scale": self._calculate_act_scale(),
@@ -159,9 +158,9 @@ class QuantActBaseHandler(ABC):
         # If these values are scalar then they can be set as attributes
         # of the MultiThreshold node, if not they get inserted as adder and mul nodes
         # behind the MultiTrheshold nodes.
-        bias_compatible = adder_bias.shape == (1,) or len(adder_bias.shape) == 0
-        scale_compatible = mul_scale.shape == (1,) or len(mul_scale.shape) == 0
-        if scale_compatible and bias_compatible and True:
+        bias_scalar = adder_bias.shape == (1,) or len(adder_bias.shape) == 0
+        scale_scalar = mul_scale.shape == (1,) or len(mul_scale.shape) == 0
+        if scale_scalar and bias_scalar and True:
             # Get Quant parameters
             mul_scale = np.atleast_1d(mul_scale)
             # ONNX only accepts 64bit floats as attributes
@@ -184,7 +183,7 @@ class QuantActBaseHandler(ABC):
 
             # Set bias
             zero_bias = False
-            if bias_compatible:
+            if bias_scalar:
                 adder_bias = np.atleast_1d(adder_bias)
                 # ONNX only accepts 64bit floats as attributes
                 adder_bias = adder_bias.astype(dtype=np.float64)[0]
@@ -228,7 +227,7 @@ class QuantActBaseHandler(ABC):
             # Set scale
             # Insert Mul node
             unity_scale = False
-            if scale_compatible:
+            if scale_scalar:
                 mul_scale = np.atleast_1d(mul_scale)
                 mul_scale = mul_scale.astype(dtype=np.float64)[0]
                 mul_shape = tuple()
@@ -369,7 +368,10 @@ class QuantReluHandler(QuantActBaseHandler):
 
 class QuantIdentityHandler(QuantActBaseHandler):
     """Class for converting a quantized identity operation expressed in the QONNX
-    dialect to the FINN ONNX dialect."""
+    dialect to the FINN ONNX dialect.
+    This handler also takes care of quantized HardTanh activations, because
+    these are equivalent to quantized identity activations.
+    """
 
     valid_predecessor_op_types = [
         "BatchNormalization",
-- 
GitLab