Commit a74bb6f0 authored by Hendrik Borras

Moved allowed_predecessors definition into handler classes.

parent e0da0526
@@ -10,16 +10,6 @@ from finn.transformation.base import Transformation
from finn.transformation.infer_datatypes import InferDataTypes
from finn.transformation.infer_shapes import InferShapes
allowed_identity_predecessor = [
"BatchNormalization",
"Sub",
None,
]
allowed_relu_predecessor = [
"Relu",
]
class ConvertQONNXtoFINN(Transformation):
"""Converts QONNX dialect to FINN ONNX dialect.
@@ -72,14 +62,23 @@ class ConvertQuantActToMultiThreshold(Transformation):
"Only Quant nodes with zero-point == 0 are currently supported."
)
# ToDo: Check for activation functions behind (or infront of?)
# the Quant node, such as ReLu
# Check for possible ambiguity in handler selection
valid_predecessors = []
for cls in QuantActBaseHandler.__subclasses__():
valid_predecessors.extend(cls.valid_predecessor_op_types)
if len(valid_predecessors) != len(set(valid_predecessors)):
raise RuntimeError(
"Two or more activation handlers declare the same "
"type of valid predecessor node. "
"This leads to ambiguity in the handler selection "
"and must thus be avoided."
)
# Check that this is an idendity operation
if predecessor_op_type in allowed_identity_predecessor:
handler = QuantIdentityHandler(model, n, node_ind)
elif predecessor_op_type in allowed_relu_predecessor:
handler = QuantReluHandler(model, n, node_ind)
# Try to find a fitting handler for this Quant activation node
for handler_cls in QuantActBaseHandler.__subclasses__():
if predecessor_op_type in handler_cls.valid_predecessor_op_types:
handler = handler_cls(model, n, node_ind)
break
else:
raise ValueError(
f"Quant nodes in the activation path and with predecessor "
@@ -204,25 +203,31 @@ class QuantActBaseHandler(ABC):
self._q_node = quant_node
self._q_index = quant_node_index
@property
@classmethod
@abstractmethod
def valid_predecessor_op_types(self):
raise NotImplementedError()
@abstractmethod
def _check_compatibility(self):
raise NotImplementedError()
@abstractmethod
def _calculate_act_bias(self):
pass
raise NotImplementedError()
@abstractmethod
def _calculate_thresholds(self):
pass
raise NotImplementedError()
@abstractmethod
def _calculate_act_scale(self):
pass
raise NotImplementedError()
@abstractmethod
def _remove_activation_node(self):
pass
@abstractmethod
def _check_compatibility(self):
pass
raise NotImplementedError()
def _extract_output_datatype(self):
dtype = self._model.get_tensor_datatype(self._q_node.output[0]).name
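The base class now declares valid_predecessor_op_types as an abstract attribute, so every concrete handler is forced to provide one; the subclasses later in the diff satisfy it with a plain class-level list. A minimal sketch of why that is enough for abc (simplified to @property/@abstractmethod, hypothetical names, not FINN code):

# Illustrative sketch only (hypothetical names), not part of the diff.
from abc import ABC, abstractmethod

class HandlerBase(ABC):
    @property
    @abstractmethod
    def valid_predecessor_op_types(self):
        raise NotImplementedError()

class ReluHandler(HandlerBase):
    # Overriding the abstract name with a plain class attribute is enough
    # to make the subclass concrete.
    valid_predecessor_op_types = ["Relu"]

# HandlerBase() raises TypeError (abstract attribute not provided);
# the subclass instantiates fine and exposes its list:
print(ReluHandler().valid_predecessor_op_types)  # ['Relu']

The same hunk also swaps pass for raise NotImplementedError() in the abstract method bodies, so accidentally calling a base implementation (e.g. via super()) fails loudly instead of silently doing nothing.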
@@ -242,6 +247,7 @@
def replace_quant_node(self):
# Check that we actually support what the user is trying to do
self._check_compatibility()
# Shorten instance variables
model = self._model
graph = model.graph
@@ -362,6 +368,10 @@ class QuantReluHandler(QuantActBaseHandler):
# zero_pt = model.get_initializer(n.input[2])
# signed = q_inst.get_nodeattr("signed")
valid_predecessor_op_types = [
"Relu",
]
def _check_compatibility(self):
q_inst = getCustomOp(self._q_node)
narrow = q_inst.get_nodeattr("narrow")
@@ -451,6 +461,12 @@ class QuantIdentityHandler(QuantActBaseHandler):
# zero_pt = model.get_initializer(n.input[2])
# signed = q_inst.get_nodeattr("signed")
valid_predecessor_op_types = [
"BatchNormalization",
"Sub",
None,
]
def _check_compatibility(self):
# Gather parameters to check
q_inst = getCustomOp(self._q_node)
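The None entry in QuantIdentityHandler's list is what lets the generic dispatch also cover Quant nodes with no predecessor node at all. The diff does not show how predecessor_op_type is derived, so the sketch below simply assumes it is None in that case; the ordinary membership test then needs no special case:

# Illustrative sketch only, not part of the diff.
valid_predecessor_op_types = ["BatchNormalization", "Sub", None]

# Assumed convention: a Quant node sitting directly behind a graph input
# has no producing node, so its predecessor op type is None.
predecessor_op_type = None
print(predecessor_op_type in valid_predecessor_op_types)  # True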