Commit f9f6f60f authored by Hendrik Borras

Small refactoring for QONNX activation handlers.

parent e7eeed0f
@@ -96,7 +96,7 @@ class QuantActBaseHandler(ABC):
     def _extract_output_datatype(self):
         """Get the output datatype for the MultiThreshold node."""
         dtype = self._model.get_tensor_datatype(self._q_node.output[0]).name
-        if "SCALED" in dtype:
+        if dtype is not None:
             dtype = dtype.replace("SCALED", "")
         return dtype
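
For readers less familiar with the datatype names involved: the guard above only strips a "SCALED" marker out of the datatype string before it is reused for the MultiThreshold output. A minimal standalone sketch of that cleanup (the example names are illustrative assumptions, not actual FINN/QONNX DataType names):

def strip_scaled_marker(dtype_name):
    # Drop the "SCALED" marker if present; otherwise return the name untouched.
    if dtype_name is not None and "SCALED" in dtype_name:
        dtype_name = dtype_name.replace("SCALED", "")
    return dtype_name

print(strip_scaled_marker("SCALEDINT8"))  # -> "INT8"
print(strip_scaled_marker("INT4"))        # -> "INT4" (unchanged)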
@@ -104,9 +104,8 @@ class QuantActBaseHandler(ABC):
         """Calculate all parameters required for replacing the QONNX style activation
         with a FINN style one.
         """
-        out_dtype = self._extract_output_datatype()
         return {
-            "out_dtype": out_dtype,
+            "out_dtype": self._extract_output_datatype(),
             "thresholds": self._calculate_thresholds(),
             "adder_bias": self._calculate_act_bias(),
             "mul_scale": self._calculate_act_scale(),
@@ -159,9 +158,9 @@ class QuantActBaseHandler(ABC):
         # If these values are scalar then they can be set as attributes
         # of the MultiThreshold node, if not they get inserted as adder and mul nodes
         # behind the MultiTrheshold nodes.
-        bias_compatible = adder_bias.shape == (1,) or len(adder_bias.shape) == 0
-        scale_compatible = mul_scale.shape == (1,) or len(mul_scale.shape) == 0
-        if scale_compatible and bias_compatible and True:
+        bias_scalar = adder_bias.shape == (1,) or len(adder_bias.shape) == 0
+        scale_scalar = mul_scale.shape == (1,) or len(mul_scale.shape) == 0
+        if scale_scalar and bias_scalar and True:
             # Get Quant parameters
             mul_scale = np.atleast_1d(mul_scale)
             # ONNX only accepts 64bit floats as attributes
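
The comment in this hunk carries the actual reasoning: per-tensor (scalar) scale and bias values can be attached to the MultiThreshold node as attributes, while per-channel values have to become standalone Mul and Add nodes behind it. A small numpy sketch of the scalar test that the renamed bias_scalar/scale_scalar flags implement (illustrative only, not the FINN code path):

import numpy as np

def is_effectively_scalar(value):
    # Mirrors the shape test above: a 1-element 1-D array or a 0-D array
    # counts as a scalar that can become a node attribute.
    value = np.asarray(value)
    return value.shape == (1,) or value.ndim == 0

print(is_effectively_scalar(np.float32(0.5)))             # True  (0-D)
print(is_effectively_scalar(np.array([0.5])))             # True  (shape (1,))
print(is_effectively_scalar(np.array([0.5, 0.25, 1.0])))  # False (per-channel)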
@@ -184,7 +183,7 @@ class QuantActBaseHandler(ABC):
         # Set bias
         zero_bias = False
-        if bias_compatible:
+        if bias_scalar:
             adder_bias = np.atleast_1d(adder_bias)
             # ONNX only accepts 64bit floats as attributes
             adder_bias = adder_bias.astype(dtype=np.float64)[0]
@@ -228,7 +227,7 @@ class QuantActBaseHandler(ABC):
         # Set scale
         # Insert Mul node
         unity_scale = False
-        if scale_compatible:
+        if scale_scalar:
             mul_scale = np.atleast_1d(mul_scale)
             mul_scale = mul_scale.astype(dtype=np.float64)[0]
             mul_shape = tuple()
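
Both scalar branches end in the same conversion idiom noted in the comments: ONNX stores float attributes as 64-bit values, so the scalar is first lifted to a 1-element array, cast to float64, and then unpacked back into a plain scalar. A short illustration with an assumed example value:

import numpy as np

mul_scale = np.float32(0.0078125)  # e.g. a per-tensor scale of 1/128 (made-up value)
attr_value = np.atleast_1d(mul_scale).astype(dtype=np.float64)[0]
print(type(attr_value), attr_value)  # <class 'numpy.float64'> 0.0078125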
@@ -369,7 +368,10 @@ class QuantReluHandler(QuantActBaseHandler):
 class QuantIdentityHandler(QuantActBaseHandler):
     """Class for converting a quantized identity operation expressed in the QONNX
-    dialect to the FINN ONNX dialect."""
+    dialect to the FINN ONNX dialect.
+    This handler also takes care of quantized HardTanh activations, because
+    these are equivalent to quantized identity activations.
+    """
     valid_predecessor_op_types = [
         "BatchNormalization",
......
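
The extended docstring states that quantized HardTanh activations are equivalent to quantized identity activations. A self-contained numpy sketch of why that holds (toy quantizer for illustration, not FINN or Brevitas code): when the HardTanh bounds coincide with the range the quantizer can represent, the extra clipping changes nothing, because the quantizer already clamps to that range.

import numpy as np

def toy_quantize(x, scale, bitwidth=8):
    # Symmetric signed quantizer: round to integer steps, clamp to the
    # representable range, rescale back to floating point.
    n = 2 ** (bitwidth - 1)
    q = np.clip(np.round(x / scale), -n, n - 1)
    return q * scale

def hardtanh(x, lo, hi):
    return np.clip(x, lo, hi)

scale = 1.0 / 127
lo, hi = -128 * scale, 127 * scale  # bounds chosen to match the quantizer range
x = np.linspace(-2.0, 2.0, 401)
assert np.allclose(toy_quantize(hardtanh(x, lo, hi), scale), toy_quantize(x, scale))
print("quantized HardTanh == quantized identity when the ranges match")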