Skip to content
Snippets Groups Projects
Commit 28f93d5a authored by Hendrik Borras's avatar Hendrik Borras
Browse files

Refactored filter lambda for ConvertQuantActToMultiThreshold transformation to...

Refactored filter lambda for ConvertQuantActToMultiThreshold transformation to only use the function for filtering.
parent ff3d1413
No related branches found
No related tags found
No related merge requests found
......@@ -96,6 +96,9 @@ from finn.transformation.infer_shapes import InferShapes
from finn.transformation.lower_convs_to_matmul import LowerConvsToMatMul
from finn.transformation.move_reshape import RemoveCNVtoFCFlatten
from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN
from finn.transformation.qonnx.quant_act_to_multithreshold import (
default_filter_function_generator,
)
from finn.transformation.streamline import Streamline
from finn.transformation.streamline.reorder import MakeMaxPoolNHWC
from finn.util.config import extract_model_config_to_json
......@@ -208,7 +211,9 @@ def step_qonnx_to_finn(model: ModelWrapper, cfg: DataflowBuildConfig):
# QONNX to FINN-ONNX
model = model.transform(
ConvertQONNXtoFINN(
max_multithreshold_bit_width=cfg.max_multithreshold_bit_width
filter_function=default_filter_function_generator(
max_multithreshold_bit_width=cfg.max_multithreshold_bit_width
)
)
)
......
......@@ -36,6 +36,7 @@ from finn.transformation.qonnx.fold_quant_weights import FoldQuantWeights
from finn.transformation.qonnx.infer_QuantAvgPool2d import AvgPoolAndTruncToQuantAvgPool
from finn.transformation.qonnx.quant_act_to_multithreshold import (
ConvertQuantActToMultiThreshold,
default_filter_function_generator,
)
from finn.transformation.remove import RemoveEmptyPadding
......@@ -47,31 +48,28 @@ class ConvertQONNXtoFINN(Transformation):
the activations.
If incompatibilities are found a ValueError or RuntimeError is raised.
The optional keyword arguments `max_multithreshold_bit_width` and `filter_lambda`
present a way to control which Quant and BipolarQuant nodes in the activation path
are converted to MultiThreshold nodes.
The filters which are represented by `max_multithreshold_bit_width` and
`filter_lambda` are internally connected by an `AND` operation. A warning
will be emitted when a Quant node is not converted to a MultiThreshold node.
The optional keyword argument `filter_function`
presents a way to control which Quant and BipolarQuant nodes in the activation path
are converted to MultiThreshold nodes. A warning will be emitted when a Quant node
is not converted to a MultiThreshold node.
:param max_multithreshold_bit_width: The value of max_multithreshold_bit_width is
checked against the bit width of any given Quant node and the transformation to a
MultiTrheshold node is rejected, when the bitwidth of the Quant node is larger
than value of max_multithreshold_bit_with. Defaults to: 8
:type max_multithreshold_bit_width: `int`, optional
:param filter_lambda: Each candidate Quant and BinaryQant node is first evaluated
by this lambda function. If the function returns False,
:param filter_function: Each candidate Quant and BinaryQant node is first evaluated
by this function. If the function returns False,
then the node is not converted to a MultiTrheshold node.
Defaults to: lambda q_node: True
:type filter_lambda: `lambda`, optional
The function is given the model and candidate node as parameters.
Per default a filter function is inserted, which disables the conversion of
Quant nodes, which have a bit width of larger than 8.
Defaults to: default_filter_function_generator(max_multithreshold_bit_width=8)
"""
def __init__(
self, max_multithreshold_bit_width=8, filter_lambda=lambda q_node: True
self,
filter_function=default_filter_function_generator(
max_multithreshold_bit_width=8
),
):
super().__init__()
self.max_multithreshold_bit_width = max_multithreshold_bit_width
self._filter_lambda = filter_lambda
self._filter_function = filter_function
def apply(self, model):
# Extract the bias from Conv node
......@@ -86,8 +84,7 @@ class ConvertQONNXtoFINN(Transformation):
# Convert activations
model = model.transform(
ConvertQuantActToMultiThreshold(
max_multithreshold_bit_width=self.max_multithreshold_bit_width,
filter_lambda=self._filter_lambda,
filter_function=self._filter_function,
)
)
# Recompute datatypes
......
......@@ -33,35 +33,64 @@ from finn.transformation.base import Transformation
from finn.transformation.qonnx.qonnx_activation_handlers import QuantActBaseHandler
def default_filter_function_generator(max_multithreshold_bit_width=8):
"""
This function generates the default filter function for the
ConvertQuantActToMultiThreshold transformation. Per default the returned
function disables the conversion of Quant nodes which have a bit width above 8 bit.
This function generator can be used as a template to write custom
filter functions.
"""
def filter_function(model, q_node):
if q_node.op_type == "Quant":
bit_width = model.get_initializer(q_node.input[3])
elif q_node.op_type == "BipolarQuant":
bit_width = 1.0
else:
raise RuntimeError("Got an unexpected quantizer node type")
if bit_width is None:
raise ValueError("Quant nodes must have a static bit width.")
if bit_width > max_multithreshold_bit_width:
warnings.warn(
f'The Quant node with name: "{q_node.name}" was not converted to a '
f"MultiThreshold node, because its bit width of {bit_width} is "
f"higher than the configured maximum bit width of "
f"{max_multithreshold_bit_width}."
)
return False
return True
return filter_function
class ConvertQuantActToMultiThreshold(Transformation):
"""
Converts Quant nodes in the activation path to MultiThreshold nodes.
The optional keyword arguments `max_multithreshold_bit_width` and `filter_lambda`
present a way to control which Quant and BipolarQuant nodes in the activation path
are converted to MultiThreshold nodes.
The filters which are represented by `max_multithreshold_bit_width` and
`filter_lambda` are internally connected by an `AND` operation. A warning
will be emitted when a Quant node is not converted to a MultiThreshold node.
:param max_multithreshold_bit_width: The value of max_multithreshold_bit_width is
checked against the bit width of any given Quant node and the transformation to a
MultiTrheshold node is rejected, when the bitwidth of the Quant node is larger
than value of max_multithreshold_bit_with. Defaults to: 8
:type max_multithreshold_bit_width: `int`, optional
:param filter_lambda: Each candidate Quant and BinaryQant node is first evaluated
by this lambda function. If the function returns False,
The optional keyword argument `filter_function`
presents a way to control which Quant and BipolarQuant nodes in the activation path
are converted to MultiThreshold nodes. A warning will be emitted when a Quant node
is not converted to a MultiThreshold node.
:param filter_function: Each candidate Quant and BinaryQant node is first evaluated
by this function. If the function returns False,
then the node is not converted to a MultiTrheshold node.
Defaults to: lambda q_node: True
:type filter_lambda: `lambda`, optional
The function is given the model and candidate node as parameters.
Per default a filter function is inserted, which disables the conversion of
Quant nodes, which have a bit width of larger than 8.
Defaults to: default_filter_function_generator(max_multithreshold_bit_width=8)
"""
def __init__(
self, max_multithreshold_bit_width=8, filter_lambda=lambda q_node: True
self,
filter_function=default_filter_function_generator(
max_multithreshold_bit_width=8
),
):
super().__init__()
self.max_multithreshold_bit_width = max_multithreshold_bit_width
self._filter_lambda = filter_lambda
self._filter_function = filter_function
def apply(self, model):
graph = model.graph
......@@ -91,28 +120,11 @@ class ConvertQuantActToMultiThreshold(Transformation):
"Only Quant nodes with zero-point == 0 are currently supported."
)
# Check if the bit width is low enough
if n.op_type == "Quant":
bit_width = model.get_initializer(n.input[3])
elif n.op_type == "BipolarQuant":
bit_width = 1.0
else:
raise RuntimeError("Got an unexpected quantizer node type")
if bit_width is None:
raise ValueError("Quant nodes must have a static bit width.")
if bit_width > self.max_multithreshold_bit_width:
warnings.warn(
f'The Quant node with name: "{n.name}" was not converted to a '
f"MultiThreshold node, because its bit width of {bit_width} is "
f"higher than the configured maximum bit width of "
f"{self.max_multithreshold_bit_width}."
)
continue
# Check that this node passes the user filter
if not self._filter_lambda(n):
if not self._filter_function(model, n):
warnings.warn(
f'The Quant node with name: "{n.name}" was not converted to a '
f"MultiThreshold node, because the filtering lambda function "
f"MultiThreshold node, because the filtering function "
f"returned False for this node."
)
continue
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment