From 5428315898e1656653c543a704ca61adb4e3c9df Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <yamanu@amd.com> Date: Wed, 11 Jan 2023 12:29:54 +0300 Subject: [PATCH] [QONNX] make valid_predecessor_op_type a method, fallback to QuantIdentityHandler --- .../qonnx/qonnx_activation_handlers.py | 30 ++++++++++--------- .../qonnx/quant_act_to_multithreshold.py | 20 +++++++------ 2 files changed, 27 insertions(+), 23 deletions(-) diff --git a/src/finn/transformation/qonnx/qonnx_activation_handlers.py b/src/finn/transformation/qonnx/qonnx_activation_handlers.py index a50a58507..9819086d8 100644 --- a/src/finn/transformation/qonnx/qonnx_activation_handlers.py +++ b/src/finn/transformation/qonnx/qonnx_activation_handlers.py @@ -52,9 +52,7 @@ class QuantActBaseHandler(ABC): self._q_node = quant_node self._q_index = quant_node_index - @property @classmethod - @abstractmethod def valid_predecessor_op_types(self): """Defines which op types the preceding node is allowed to have for this type of activation. @@ -284,9 +282,11 @@ class QuantReluHandler(QuantActBaseHandler): """Class for converting a quantized relu operation expressed in the QONNX dialect to the FINN ONNX dialect.""" - valid_predecessor_op_types = [ - "Relu", - ] + @classmethod + def valid_predecessor_op_types(self): + return [ + "Relu", + ] def _check_compatibility(self): if self._q_node.op_type == "Quant": @@ -391,15 +391,17 @@ class QuantIdentityHandler(QuantActBaseHandler): these are equivalent to quantized identity activations. """ - valid_predecessor_op_types = [ - "BatchNormalization", - "Sub", - "Add", - "Mul", - "Div", - "DebugMarker", - None, - ] + @classmethod + def valid_predecessor_op_types(self): + return [ + "BatchNormalization", + "Sub", + "Add", + "Mul", + "Div", + "DebugMarker", + None, + ] def _check_compatibility(self): # Gather parameters to check diff --git a/src/finn/transformation/qonnx/quant_act_to_multithreshold.py b/src/finn/transformation/qonnx/quant_act_to_multithreshold.py index 77025ecdf..e0f893f35 100644 --- a/src/finn/transformation/qonnx/quant_act_to_multithreshold.py +++ b/src/finn/transformation/qonnx/quant_act_to_multithreshold.py @@ -30,7 +30,10 @@ import warnings from qonnx.transformation.base import Transformation -from finn.transformation.qonnx.qonnx_activation_handlers import QuantActBaseHandler +from finn.transformation.qonnx.qonnx_activation_handlers import ( + QuantActBaseHandler, + QuantIdentityHandler, +) def default_filter_function_generator(max_multithreshold_bit_width=8): @@ -127,7 +130,7 @@ class ConvertQuantActToMultiThreshold(Transformation): # Check for possible ambiguity in handler selection valid_predecessors = [] for cls in QuantActBaseHandler.__subclasses__(): - valid_predecessors.extend(cls.valid_predecessor_op_types) + valid_predecessors.extend(cls.valid_predecessor_op_types()) if len(valid_predecessors) != len(set(valid_predecessors)): raise RuntimeError( "Two or more activation handlers declare the same " @@ -138,16 +141,15 @@ class ConvertQuantActToMultiThreshold(Transformation): # Try to find a fitting handler for this Quant activation node for handler_cls in QuantActBaseHandler.__subclasses__(): - if predecessor_op_type in handler_cls.valid_predecessor_op_types: + if predecessor_op_type in handler_cls.valid_predecessor_op_types(): handler = handler_cls(model, n, node_ind) break else: - raise ValueError( - f"Quant nodes in the activation path and with predecessor " - f"nodes of type {predecessor_op_type} are currently not " - f"supported by FINN and can not be converted to " - f"MultiThreshold nodes." - ) + # fall back to QuantIdentityHandler here + # it may still not work due to its particular restrictions, + # but better than just erroring out without trying + handler = QuantIdentityHandler(model, n, node_ind) + model = handler.replace_quant_node() graph_modified = True return (model, graph_modified) -- GitLab