From f9f6f60fad59d061f9075c7ae2de4c757b121522 Mon Sep 17 00:00:00 2001
From: Hendrik Borras <hendrikborras@web.de>
Date: Fri, 8 Oct 2021 15:14:15 +0100
Subject: [PATCH] Small refactoring for QONNX activation handlers.

---
 .../qonnx/qonnx_activation_handlers.py | 20 +++++++++++---------
 1 file changed, 11 insertions(+), 9 deletions(-)

diff --git a/src/finn/transformation/qonnx/qonnx_activation_handlers.py b/src/finn/transformation/qonnx/qonnx_activation_handlers.py
index 1e1ad7184..02c156e14 100644
--- a/src/finn/transformation/qonnx/qonnx_activation_handlers.py
+++ b/src/finn/transformation/qonnx/qonnx_activation_handlers.py
@@ -96,7 +96,7 @@ class QuantActBaseHandler(ABC):
     def _extract_output_datatype(self):
         """Get the output datatype for the MultiThreshold node."""
         dtype = self._model.get_tensor_datatype(self._q_node.output[0]).name
-        if "SCALED" in dtype:
+        if dtype is not None:
             dtype = dtype.replace("SCALED", "")
         return dtype
 
@@ -104,9 +104,8 @@ class QuantActBaseHandler(ABC):
         """Calculate all parameters required for replacing the QONNX style activation
         with a FINN style one.
         """
-        out_dtype = self._extract_output_datatype()
         return {
-            "out_dtype": out_dtype,
+            "out_dtype": self._extract_output_datatype(),
             "thresholds": self._calculate_thresholds(),
             "adder_bias": self._calculate_act_bias(),
             "mul_scale": self._calculate_act_scale(),
@@ -159,9 +158,9 @@ class QuantActBaseHandler(ABC):
         # If these values are scalar then they can be set as attributes
         # of the MultiThreshold node, if not they get inserted as adder and mul nodes
         # behind the MultiTrheshold nodes.
-        bias_compatible = adder_bias.shape == (1,) or len(adder_bias.shape) == 0
-        scale_compatible = mul_scale.shape == (1,) or len(mul_scale.shape) == 0
-        if scale_compatible and bias_compatible and True:
+        bias_scalar = adder_bias.shape == (1,) or len(adder_bias.shape) == 0
+        scale_scalar = mul_scale.shape == (1,) or len(mul_scale.shape) == 0
+        if scale_scalar and bias_scalar and True:
             # Get Quant parameters
             mul_scale = np.atleast_1d(mul_scale)
             # ONNX only accepts 64bit floats as attributes
@@ -184,7 +183,7 @@ class QuantActBaseHandler(ABC):
 
         # Set bias
         zero_bias = False
-        if bias_compatible:
+        if bias_scalar:
             adder_bias = np.atleast_1d(adder_bias)
             # ONNX only accepts 64bit floats as attributes
             adder_bias = adder_bias.astype(dtype=np.float64)[0]
@@ -228,7 +227,7 @@ class QuantActBaseHandler(ABC):
         # Set scale
         # Insert Mul node
         unity_scale = False
-        if scale_compatible:
+        if scale_scalar:
             mul_scale = np.atleast_1d(mul_scale)
             mul_scale = mul_scale.astype(dtype=np.float64)[0]
             mul_shape = tuple()
@@ -369,7 +368,10 @@ class QuantReluHandler(QuantActBaseHandler):
 
 class QuantIdentityHandler(QuantActBaseHandler):
     """Class for converting a quantized identity operation expressed in the QONNX
-    dialect to the FINN ONNX dialect."""
+    dialect to the FINN ONNX dialect.
+    This handler also takes care of quantized HardTanh activations, because
+    these are equivalent to quantized identity activations.
+    """
 
     valid_predecessor_op_types = [
         "BatchNormalization",
-- 
GitLab
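
For reference, the renamed bias_scalar / scale_scalar flags test whether the bias and
scale parameters are effectively scalar (a 0-d array or shape (1,)), which decides whether
they can be folded into MultiThreshold node attributes instead of being inserted as separate
Add/Mul nodes. A minimal standalone sketch of that check follows; the helper name
is_scalar_param is hypothetical and does not appear in the patch or in FINN.

import numpy as np


def is_scalar_param(value: np.ndarray) -> bool:
    # True when the parameter can be folded into a node attribute:
    # either a 0-d array or a 1-element array of shape (1,).
    return value.shape == (1,) or len(value.shape) == 0


if __name__ == "__main__":
    print(is_scalar_param(np.array(0.5)))          # True: 0-d array
    print(is_scalar_param(np.array([0.5])))        # True: shape (1,)
    print(is_scalar_param(np.array([0.5, 1.0])))   # False: per-channel parameter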