From 5428315898e1656653c543a704ca61adb4e3c9df Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <yamanu@amd.com>
Date: Wed, 11 Jan 2023 12:29:54 +0300
Subject: [PATCH] [QONNX] make valid_predecessor_op_type a method, fallback to
 QuantIdentityHandler

---
 .../qonnx/qonnx_activation_handlers.py        | 30 ++++++++++---------
 .../qonnx/quant_act_to_multithreshold.py      | 20 +++++++------
 2 files changed, 27 insertions(+), 23 deletions(-)

diff --git a/src/finn/transformation/qonnx/qonnx_activation_handlers.py b/src/finn/transformation/qonnx/qonnx_activation_handlers.py
index a50a58507..9819086d8 100644
--- a/src/finn/transformation/qonnx/qonnx_activation_handlers.py
+++ b/src/finn/transformation/qonnx/qonnx_activation_handlers.py
@@ -52,9 +52,7 @@ class QuantActBaseHandler(ABC):
         self._q_node = quant_node
         self._q_index = quant_node_index
 
-    @property
     @classmethod
-    @abstractmethod
     def valid_predecessor_op_types(self):
         """Defines which op types the preceding node is allowed to have for
         this type of activation.
@@ -284,9 +282,11 @@ class QuantReluHandler(QuantActBaseHandler):
     """Class for converting a quantized relu operation expressed in the QONNX
     dialect to the FINN ONNX dialect."""
 
-    valid_predecessor_op_types = [
-        "Relu",
-    ]
+    @classmethod
+    def valid_predecessor_op_types(self):
+        return [
+            "Relu",
+        ]
 
     def _check_compatibility(self):
         if self._q_node.op_type == "Quant":
@@ -391,15 +391,17 @@ class QuantIdentityHandler(QuantActBaseHandler):
     these are equivalent to quantized identity activations.
     """
 
-    valid_predecessor_op_types = [
-        "BatchNormalization",
-        "Sub",
-        "Add",
-        "Mul",
-        "Div",
-        "DebugMarker",
-        None,
-    ]
+    @classmethod
+    def valid_predecessor_op_types(self):
+        return [
+            "BatchNormalization",
+            "Sub",
+            "Add",
+            "Mul",
+            "Div",
+            "DebugMarker",
+            None,
+        ]
 
     def _check_compatibility(self):
         # Gather parameters to check
diff --git a/src/finn/transformation/qonnx/quant_act_to_multithreshold.py b/src/finn/transformation/qonnx/quant_act_to_multithreshold.py
index 77025ecdf..e0f893f35 100644
--- a/src/finn/transformation/qonnx/quant_act_to_multithreshold.py
+++ b/src/finn/transformation/qonnx/quant_act_to_multithreshold.py
@@ -30,7 +30,10 @@
 import warnings
 from qonnx.transformation.base import Transformation
 
-from finn.transformation.qonnx.qonnx_activation_handlers import QuantActBaseHandler
+from finn.transformation.qonnx.qonnx_activation_handlers import (
+    QuantActBaseHandler,
+    QuantIdentityHandler,
+)
 
 
 def default_filter_function_generator(max_multithreshold_bit_width=8):
@@ -127,7 +130,7 @@ class ConvertQuantActToMultiThreshold(Transformation):
                 # Check for possible ambiguity in handler selection
                 valid_predecessors = []
                 for cls in QuantActBaseHandler.__subclasses__():
-                    valid_predecessors.extend(cls.valid_predecessor_op_types)
+                    valid_predecessors.extend(cls.valid_predecessor_op_types())
                 if len(valid_predecessors) != len(set(valid_predecessors)):
                     raise RuntimeError(
                         "Two or more activation handlers declare the same "
@@ -138,16 +141,15 @@ class ConvertQuantActToMultiThreshold(Transformation):
 
                 # Try to find a fitting handler for this Quant activation node
                 for handler_cls in QuantActBaseHandler.__subclasses__():
-                    if predecessor_op_type in handler_cls.valid_predecessor_op_types:
+                    if predecessor_op_type in handler_cls.valid_predecessor_op_types():
                         handler = handler_cls(model, n, node_ind)
                         break
                 else:
-                    raise ValueError(
-                        f"Quant nodes in the activation path and with predecessor "
-                        f"nodes of type {predecessor_op_type} are currently not "
-                        f"supported by FINN and can not be converted to "
-                        f"MultiThreshold nodes."
-                    )
+                    # fall back to QuantIdentityHandler here
+                    # it may still not work due to its particular restrictions,
+                    # but better than just erroring out without trying
+                    handler = QuantIdentityHandler(model, n, node_ind)
+
                 model = handler.replace_quant_node()
                 graph_modified = True
                 return (model, graph_modified)
-- 
GitLab