Skip to content
Snippets Groups Projects
Commit f9f6f60f authored by Hendrik Borras's avatar Hendrik Borras
Browse files

Small refactoring for QONNX activation handlers.

parent e7eeed0f
No related branches found
No related tags found
No related merge requests found
...@@ -96,7 +96,7 @@ class QuantActBaseHandler(ABC): ...@@ -96,7 +96,7 @@ class QuantActBaseHandler(ABC):
def _extract_output_datatype(self): def _extract_output_datatype(self):
"""Get the output datatype for the MultiThreshold node.""" """Get the output datatype for the MultiThreshold node."""
dtype = self._model.get_tensor_datatype(self._q_node.output[0]).name dtype = self._model.get_tensor_datatype(self._q_node.output[0]).name
if "SCALED" in dtype: if dtype is not None:
dtype = dtype.replace("SCALED", "") dtype = dtype.replace("SCALED", "")
return dtype return dtype
...@@ -104,9 +104,8 @@ class QuantActBaseHandler(ABC): ...@@ -104,9 +104,8 @@ class QuantActBaseHandler(ABC):
"""Calculate all parameters required for replacing the QONNX style activation """Calculate all parameters required for replacing the QONNX style activation
with a FINN style one. with a FINN style one.
""" """
out_dtype = self._extract_output_datatype()
return { return {
"out_dtype": out_dtype, "out_dtype": self._extract_output_datatype(),
"thresholds": self._calculate_thresholds(), "thresholds": self._calculate_thresholds(),
"adder_bias": self._calculate_act_bias(), "adder_bias": self._calculate_act_bias(),
"mul_scale": self._calculate_act_scale(), "mul_scale": self._calculate_act_scale(),
...@@ -159,9 +158,9 @@ class QuantActBaseHandler(ABC): ...@@ -159,9 +158,9 @@ class QuantActBaseHandler(ABC):
# If these values are scalar then they can be set as attributes # If these values are scalar then they can be set as attributes
# of the MultiThreshold node, if not they get inserted as adder and mul nodes # of the MultiThreshold node, if not they get inserted as adder and mul nodes
# behind the MultiTrheshold nodes. # behind the MultiTrheshold nodes.
bias_compatible = adder_bias.shape == (1,) or len(adder_bias.shape) == 0 bias_scalar = adder_bias.shape == (1,) or len(adder_bias.shape) == 0
scale_compatible = mul_scale.shape == (1,) or len(mul_scale.shape) == 0 scale_scalar = mul_scale.shape == (1,) or len(mul_scale.shape) == 0
if scale_compatible and bias_compatible and True: if scale_scalar and bias_scalar and True:
# Get Quant parameters # Get Quant parameters
mul_scale = np.atleast_1d(mul_scale) mul_scale = np.atleast_1d(mul_scale)
# ONNX only accepts 64bit floats as attributes # ONNX only accepts 64bit floats as attributes
...@@ -184,7 +183,7 @@ class QuantActBaseHandler(ABC): ...@@ -184,7 +183,7 @@ class QuantActBaseHandler(ABC):
# Set bias # Set bias
zero_bias = False zero_bias = False
if bias_compatible: if bias_scalar:
adder_bias = np.atleast_1d(adder_bias) adder_bias = np.atleast_1d(adder_bias)
# ONNX only accepts 64bit floats as attributes # ONNX only accepts 64bit floats as attributes
adder_bias = adder_bias.astype(dtype=np.float64)[0] adder_bias = adder_bias.astype(dtype=np.float64)[0]
...@@ -228,7 +227,7 @@ class QuantActBaseHandler(ABC): ...@@ -228,7 +227,7 @@ class QuantActBaseHandler(ABC):
# Set scale # Set scale
# Insert Mul node # Insert Mul node
unity_scale = False unity_scale = False
if scale_compatible: if scale_scalar:
mul_scale = np.atleast_1d(mul_scale) mul_scale = np.atleast_1d(mul_scale)
mul_scale = mul_scale.astype(dtype=np.float64)[0] mul_scale = mul_scale.astype(dtype=np.float64)[0]
mul_shape = tuple() mul_shape = tuple()
...@@ -369,7 +368,10 @@ class QuantReluHandler(QuantActBaseHandler): ...@@ -369,7 +368,10 @@ class QuantReluHandler(QuantActBaseHandler):
class QuantIdentityHandler(QuantActBaseHandler): class QuantIdentityHandler(QuantActBaseHandler):
"""Class for converting a quantized identity operation expressed in the QONNX """Class for converting a quantized identity operation expressed in the QONNX
dialect to the FINN ONNX dialect.""" dialect to the FINN ONNX dialect.
This handler also takes care of quantized HardTanh activations, because
these are equivalent to quantized identity activations.
"""
valid_predecessor_op_types = [ valid_predecessor_op_types = [
"BatchNormalization", "BatchNormalization",
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment