diff --git a/src/finn/transformation/convert_qonnx_to_finn.py b/src/finn/transformation/convert_qonnx_to_finn.py index ef7ffe331fa0330a3e4487f4c4a298ab468f0741..143145ee81264edd4a7e77e5beb3cabef5e837c5 100644 --- a/src/finn/transformation/convert_qonnx_to_finn.py +++ b/src/finn/transformation/convert_qonnx_to_finn.py @@ -10,16 +10,6 @@ from finn.transformation.base import Transformation from finn.transformation.infer_datatypes import InferDataTypes from finn.transformation.infer_shapes import InferShapes -allowed_identity_predecessor = [ - "BatchNormalization", - "Sub", - None, -] - -allowed_relu_predecessor = [ - "Relu", -] - class ConvertQONNXtoFINN(Transformation): """Converts QONNX dialect to FINN ONNX dialect. @@ -72,14 +62,23 @@ class ConvertQuantActToMultiThreshold(Transformation): "Only Quant nodes with zero-point == 0 are currently supported." ) - # ToDo: Check for activation functions behind (or infront of?) - # the Quant node, such as ReLu + # Check for possible ambiguity in handler selection + valid_predecessors = [] + for cls in QuantActBaseHandler.__subclasses__(): + valid_predecessors.extend(cls.valid_predecessor_op_types) + if len(valid_predecessors) != len(set(valid_predecessors)): + raise RuntimeError( + "Two or more activation handlers declare the same " + "type of valid predecessor node. " + "This leads to ambiguity in the handler selection " + "and must thus be avoided." + ) - # Check that this is an idendity operation - if predecessor_op_type in allowed_identity_predecessor: - handler = QuantIdentityHandler(model, n, node_ind) - elif predecessor_op_type in allowed_relu_predecessor: - handler = QuantReluHandler(model, n, node_ind) + # Try to find a fitting handler for this Quant activation node + for handler_cls in QuantActBaseHandler.__subclasses__(): + if predecessor_op_type in handler_cls.valid_predecessor_op_types: + handler = handler_cls(model, n, node_ind) + break else: raise ValueError( f"Quant nodes in the activation path and with predecessor " @@ -204,25 +203,31 @@ class QuantActBaseHandler(ABC): self._q_node = quant_node self._q_index = quant_node_index + @property + @classmethod + @abstractmethod + def valid_predecessor_op_types(self): + raise NotImplementedError() + + @abstractmethod + def _check_compatibility(self): + raise NotImplementedError() + @abstractmethod def _calculate_act_bias(self): - pass + raise NotImplementedError() @abstractmethod def _calculate_thresholds(self): - pass + raise NotImplementedError() @abstractmethod def _calculate_act_scale(self): - pass + raise NotImplementedError() @abstractmethod def _remove_activation_node(self): - pass - - @abstractmethod - def _check_compatibility(self): - pass + raise NotImplementedError() def _extract_output_datatype(self): dtype = self._model.get_tensor_datatype(self._q_node.output[0]).name @@ -242,6 +247,7 @@ class QuantActBaseHandler(ABC): def replace_quant_node(self): # Check that we actually support what the user is trying to do self._check_compatibility() + # Shorten instance variables model = self._model graph = model.graph @@ -362,6 +368,10 @@ class QuantReluHandler(QuantActBaseHandler): # zero_pt = model.get_initializer(n.input[2]) # signed = q_inst.get_nodeattr("signed") + valid_predecessor_op_types = [ + "Relu", + ] + def _check_compatibility(self): q_inst = getCustomOp(self._q_node) narrow = q_inst.get_nodeattr("narrow") @@ -451,6 +461,12 @@ class QuantIdentityHandler(QuantActBaseHandler): # zero_pt = model.get_initializer(n.input[2]) # signed = q_inst.get_nodeattr("signed") + valid_predecessor_op_types = [ + "BatchNormalization", + "Sub", + None, + ] + def _check_compatibility(self): # Gather parameters to check q_inst = getCustomOp(self._q_node)