diff --git a/src/finn/transformation/convert_qonnx_to_finn.py b/src/finn/transformation/convert_qonnx_to_finn.py
index 143145ee81264edd4a7e77e5beb3cabef5e837c5..4ff37ae1cee94beebe8ef28c81a029c07cf4ed3f 100644
--- a/src/finn/transformation/convert_qonnx_to_finn.py
+++ b/src/finn/transformation/convert_qonnx_to_finn.py
@@ -95,7 +95,8 @@ class ConvertQuantActToMultiThreshold(Transformation):
 
 class FoldQuantWeights(Transformation):
     """Merges Quant nodes, which are used as weights into the initializer
-    of the weight tensor."""
+    of the weight tensor.
+    """
 
     def apply(self, model):
         graph = model.graph
@@ -193,11 +194,19 @@ class FoldQuantWeights(Transformation):
         return (model, graph_modified)
 
 
+# TODO: Should these handlers live in their own file?
 class QuantActBaseHandler(ABC):
     """Base class for converting quantized activation expressed in the QONNX dialect
-    to the FINN ONNX dialect."""
+    to the FINN ONNX dialect.
+    :param model: The model on which this handler should operate.
+    :type model: :class:`finn.core.modelwrapper.ModelWrapper`
+    :param quant_node: The Quant node which a given handler should replace.
+    :param quant_node_index: The index of the Quant node in the given model.
+    :type quant_node_index: `int`
+    """
 
     def __init__(self, model: ModelWrapper, quant_node, quant_node_index: int):
+        """Constructor"""
         super().__init__()
         self._model = model
         self._q_node = quant_node
@@ -207,35 +216,55 @@ class QuantActBaseHandler(ABC):
     @classmethod
     @abstractmethod
     def valid_predecessor_op_types(self):
+        """Defines which op types the preceding node is allowed to have for
+        this type of activation.
+        """
         raise NotImplementedError()
 
     @abstractmethod
     def _check_compatibility(self):
+        """Check for compatibility with FINN.
+        There are many more possible combinations of QONNX settings
+        than what is supported by FINN.
+        """
         raise NotImplementedError()
 
     @abstractmethod
     def _calculate_act_bias(self):
+        """Calculate the activation bias,
+        which is introduced as an Add node behind the MultiThreshold node.
+        """
         raise NotImplementedError()
 
     @abstractmethod
     def _calculate_thresholds(self):
+        """Calculate the threshold array for the MultiThreshold node."""
         raise NotImplementedError()
 
     @abstractmethod
     def _calculate_act_scale(self):
+        """Calculate the activation scale,
+        which is introduced as a Mul node behind the Add node
+        for the activation bias.
+        """
         raise NotImplementedError()
 
     @abstractmethod
     def _remove_activation_node(self):
+        """Remove the activation node in front of the Quant node."""
         raise NotImplementedError()
 
     def _extract_output_datatype(self):
+        """Get the output datatype for the MultiThreshold node."""
         dtype = self._model.get_tensor_datatype(self._q_node.output[0]).name
         if "SCALED" in dtype:
             dtype = dtype.replace("SCALED", "")
         return dtype
 
     def calculate_node_parameters(self):
+        """Calculate all parameters required for replacing the QONNX style activation
+        with a FINN style one.
+        """
         out_dtype = self._extract_output_datatype()
         return {
             "out_dtype": out_dtype,
@@ -245,6 +274,8 @@ class QuantActBaseHandler(ABC):
         }
 
     def replace_quant_node(self):
+        """Replace the given QONNX style activation with a FINN style one."""
+
+        # Check that we actually support what the user is trying to do
         self._check_compatibility()