Skip to content
Snippets Groups Projects
Commit 54283158 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[QONNX] make valid_predecessor_op_type a method, fallback to QuantIdentityHandler

parent d9731b3f
No related branches found
No related tags found
No related merge requests found
......@@ -52,9 +52,7 @@ class QuantActBaseHandler(ABC):
self._q_node = quant_node
self._q_index = quant_node_index
@property
@classmethod
@abstractmethod
def valid_predecessor_op_types(self):
"""Defines which op types the preceding node is allowed to have for
this type of activation.
......@@ -284,9 +282,11 @@ class QuantReluHandler(QuantActBaseHandler):
"""Class for converting a quantized relu operation expressed in the QONNX
dialect to the FINN ONNX dialect."""
valid_predecessor_op_types = [
"Relu",
]
@classmethod
def valid_predecessor_op_types(self):
return [
"Relu",
]
def _check_compatibility(self):
if self._q_node.op_type == "Quant":
......@@ -391,15 +391,17 @@ class QuantIdentityHandler(QuantActBaseHandler):
these are equivalent to quantized identity activations.
"""
valid_predecessor_op_types = [
"BatchNormalization",
"Sub",
"Add",
"Mul",
"Div",
"DebugMarker",
None,
]
@classmethod
def valid_predecessor_op_types(self):
return [
"BatchNormalization",
"Sub",
"Add",
"Mul",
"Div",
"DebugMarker",
None,
]
def _check_compatibility(self):
# Gather parameters to check
......
......@@ -30,7 +30,10 @@
import warnings
from qonnx.transformation.base import Transformation
from finn.transformation.qonnx.qonnx_activation_handlers import QuantActBaseHandler
from finn.transformation.qonnx.qonnx_activation_handlers import (
QuantActBaseHandler,
QuantIdentityHandler,
)
def default_filter_function_generator(max_multithreshold_bit_width=8):
......@@ -127,7 +130,7 @@ class ConvertQuantActToMultiThreshold(Transformation):
# Check for possible ambiguity in handler selection
valid_predecessors = []
for cls in QuantActBaseHandler.__subclasses__():
valid_predecessors.extend(cls.valid_predecessor_op_types)
valid_predecessors.extend(cls.valid_predecessor_op_types())
if len(valid_predecessors) != len(set(valid_predecessors)):
raise RuntimeError(
"Two or more activation handlers declare the same "
......@@ -138,16 +141,15 @@ class ConvertQuantActToMultiThreshold(Transformation):
# Try to find a fitting handler for this Quant activation node
for handler_cls in QuantActBaseHandler.__subclasses__():
if predecessor_op_type in handler_cls.valid_predecessor_op_types:
if predecessor_op_type in handler_cls.valid_predecessor_op_types():
handler = handler_cls(model, n, node_ind)
break
else:
raise ValueError(
f"Quant nodes in the activation path and with predecessor "
f"nodes of type {predecessor_op_type} are currently not "
f"supported by FINN and can not be converted to "
f"MultiThreshold nodes."
)
# fall back to QuantIdentityHandler here
# it may still not work due to its particular restrictions,
# but better than just erroring out without trying
handler = QuantIdentityHandler(model, n, node_ind)
model = handler.replace_quant_node()
graph_modified = True
return (model, graph_modified)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment