Commit d8495f34 authored by Hendrik Borras

Added handler for QuantRelu activations.

parent c7381f09
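
With this change, a Relu node followed by a Quant node in the activation path is converted to a MultiThreshold node instead of raising a RuntimeError. A minimal usage sketch of the transformation (the module paths and file names are assumptions based on the file this diff touches, not part of the commit):

from finn.core.modelwrapper import ModelWrapper
from finn.transformation.qonnx.quant_act_to_multithreshold import (
    ConvertQuantActToMultiThreshold,
)

# Assumed file names; any QONNX model with Relu -> Quant activations applies.
model = ModelWrapper("model_with_quant_relu.onnx")
# Relu -> Quant pairs in the activation path are rewritten to MultiThreshold;
# standalone (identity) Quant activations were already handled before this commit.
model = model.transform(ConvertQuantActToMultiThreshold())
model.save("model_with_multithreshold.onnx")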
@@ -15,6 +15,10 @@ allowed_identity_predecessor = [
    None,
]

allowed_relu_predecessor = [
    "Relu",
]


class ConvertQuantActToMultiThreshold(Transformation):
    """Converts Quant nodes in the activation path to MultiThreshold nodes."""
@@ -52,6 +56,8 @@ class ConvertQuantActToMultiThreshold(Transformation):
                # Check that this is an identity operation
                if predecessor_op_type in allowed_identity_predecessor:
                    handler = QuantIdentityHandler(model, n, node_ind)
                elif predecessor_op_type in allowed_relu_predecessor:
                    handler = QuantReluHandler(model, n, node_ind)
                else:
                    raise RuntimeError(
                        f"Quant nodes in the activation path and with predecessor "
@@ -67,6 +73,9 @@ class ConvertQuantActToMultiThreshold(Transformation):


class FoldQuantWeights(Transformation):
    """Merges Quant nodes which are used as weights into the initializer
    of the weight tensor."""

    def apply(self, model):
        graph = model.graph
        node_ind = 0
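
For reference, the quant-dequant function that a weight Quant node computes, and which such a fold evaluates once on the weight initializer, looks roughly as follows. This is a sketch of the usual QONNX Quant semantics, not code from this file; fold_quant_weight is a hypothetical helper:

import numpy as np

def fold_quant_weight(weight, scale, zero_pt, bit_width, signed=True):
    # Integer range of the Quant node (non-narrow); narrow ranges drop one value.
    if signed:
        q_min, q_max = -(2 ** (bit_width - 1)), 2 ** (bit_width - 1) - 1
    else:
        q_min, q_max = 0, 2 ** bit_width - 1
    q = np.clip(np.round(weight / scale + zero_pt), q_min, q_max)
    # The folded initializer carries the quantized values at float precision,
    # so the Quant node itself can be removed from the graph.
    return (q - zero_pt) * scale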
@@ -170,6 +179,10 @@ class QuantActBaseHandler(ABC):
    def _calculate_act_scale(self):
        pass

    @abstractmethod
    def _remove_activation_node(self):
        pass

    def _extract_output_datatype(self):
        dtype = self._model.get_tensor_datatype(self._q_node.output[0]).name
        if "SCALED" in dtype:
@@ -287,6 +300,9 @@ class QuantActBaseHandler(ABC):
            graph.node.insert(running_node_index, mul_node)
            running_node_index += 1

        # Remove activation node
        self._remove_activation_node()

        # Remove the Quant node
        graph.node.remove(n)
@@ -294,6 +310,85 @@ class QuantActBaseHandler(ABC):
        return self._model


class QuantReluHandler(QuantActBaseHandler):
    """Class for converting a quantized relu operation expressed in the QONNX
    dialect to the FINN ONNX dialect."""

    # ToDo: zero_pt and signed should have some sort of influence or
    # should at least get checked for correct range or value
    # zero_pt = model.get_initializer(n.input[2])
    # signed = q_inst.get_nodeattr("signed")

    def _calculate_act_bias(self):
        # No bias allowed for Relu activations, see: https://github.com/Xilinx/
        # brevitas/blob/a5bfd6dc5e030f0047ac1ee47932b60e8e873e17/src/brevitas/
        # export/onnx/finn/handler/act.py#L48
        bias = np.array([0.0])
        return bias
    def _calculate_thresholds(self):
        # Gather parameters
        bit_width = self._model.get_initializer(self._q_node.input[3])
        quant_scale = self._model.get_initializer(self._q_node.input[1])
        # q_inst = getCustomOp(self._q_node)
        # narrow = q_inst.get_nodeattr("narrow")

        # Calculate thresholds, see: https://github.com/Xilinx/brevitas/blob/
        # a5bfd6dc5e030f0047ac1ee47932b60e8e873e17/src/brevitas/export/
        # onnx/finn/handler/act.py#L21
        num_distinct_values = 2 ** bit_width
        num_thresholds = int(num_distinct_values - 1)
        flat_scale = quant_scale.flatten()
        num_scale_channels = flat_scale.shape[0]
        step = np.abs(flat_scale)
        min_threshold = step / 2
        thresholds = np.empty((num_scale_channels, num_thresholds))
        for c in range(num_scale_channels):
            for t in range(num_thresholds):
                thresholds[c][t] = min_threshold[c] + step[c] * t

        # ToDo: The index 1 needs to be changed to -1 for the channels-last format
        num_output_channels = self._model.get_tensor_shape(self._q_node.output[0])[1]
        final_shape = (num_output_channels, num_thresholds)
        if thresholds.shape != final_shape:
            thresholds = np.broadcast_to(thresholds, final_shape)
        return thresholds
    def _calculate_act_scale(self):
        # Gather parameters
        quant_scale = self._model.get_initializer(self._q_node.input[1])

        # Calculate scale, see: https://github.com/Xilinx/brevitas/blob/
        # a5bfd6dc5e030f0047ac1ee47932b60e8e873e17/src/brevitas/export/
        # onnx/finn/handler/act.py#L40
        scale = quant_scale
        return scale
    def _remove_activation_node(self):
        # Find the activation node
        act_node = self._model.find_direct_predecessors(self._q_node)
        if act_node is None:
            raise RuntimeError(
                "For handling of Relu activations a predecessor to "
                "the Quant node must exist."
            )
        act_node = act_node[0]
        if act_node.op_type != "Relu":
            raise RuntimeError(
                "The predecessor of the Quant node must be Relu for handling "
                "of Relu activations."
            )

        # Reroute possible predecessors: each predecessor now writes directly
        # into the tensor the Quant node reads, bypassing the Relu node.
        act_predecessors = self._model.find_direct_predecessors(act_node)
        if act_predecessors is not None:
            for act_pre in act_predecessors:
                act_pre.output[0] = act_node.output[0]

        # Remove the activation node
        self._model.graph.node.remove(act_node)

class QuantIdentityHandler(QuantActBaseHandler):
    """Class for converting a quantized identity operation expressed in the QONNX
    dialect to the FINN ONNX dialect."""
@@ -377,3 +472,7 @@ class QuantIdentityHandler(QuantActBaseHandler):
            assert quant_scale.flatten().item() == 1.0, "Unsupported BIPOLAR scale != 1"
            scale = quant_scale * 2
        return scale
    def _remove_activation_node(self):
        # The Quant identity activation by definition has no explicit activation node
        return
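
As a sanity check of the threshold construction in QuantReluHandler._calculate_thresholds, a small standalone sketch with arbitrarily chosen example values:

import numpy as np

bit_width = 3
scale = np.array([0.5])  # per-tensor quantization scale
num_thresholds = int(2 ** bit_width - 1)  # 7 non-zero output levels
step = np.abs(scale.flatten())
min_threshold = step / 2
thresholds = min_threshold[:, None] + step[:, None] * np.arange(num_thresholds)
print(thresholds)  # [[0.25 0.75 1.25 1.75 2.25 2.75 3.25]]

Spot check: an input of 1.0 crosses the thresholds 0.25 and 0.75, so MultiThreshold counts 2; multiplied by the activation scale of 0.5 returned by _calculate_act_scale, this reproduces the quantized ReLU output of 1.0.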