Commit 5f3cf652 authored by Hendrik Borras

Added support for BinaryQuant conversion.

parent 01805541
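For context, a BinaryQuant node quantizes a tensor to the two bipolar values ±scale. A minimal numpy sketch of the assumed semantics (an input tensor plus a scale input, mirroring the Quant node's input layout; illustrative only, not the QONNX reference implementation):

    import numpy as np

    def binary_quant(x, scale):
        # Bipolar quantization: each element becomes -scale or +scale.
        # Zero is assumed to map to +scale here.
        return np.where(x >= 0, 1.0, -1.0) * scale

    print(binary_quant(np.array([-0.7, 0.0, 0.3]), np.array(1.0)))  # [-1.  1.  1.]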
@@ -86,10 +86,10 @@ RUN pip install -e git+https://github.com/fbcotter/dataset_loading.git@0.0.4#egg
 # git-based Python repo dependencies
 # these are installed in editable mode for easier co-development
-ARG FINN_BASE_COMMIT="138f177a2729334444d5f58d96c32a1bf4d4b6c2"
-ARG QONNX_COMMIT="02b15f56d199576cefe9aff672d5e009349e402a"
+ARG FINN_BASE_COMMIT="ec3997c3f4276f7746bfd08a1a9508bd02a132fa"
+ARG QONNX_COMMIT="834610ba3f668971fe2800fde7f8d0c10d825d5b"
 ARG FINN_EXP_COMMIT="f82c0d9868bb88ea045dfadb28508d327d287221"
-ARG BREVITAS_COMMIT="462f86cdc60f9915baf13afd1676fb21da44c2ee"
+ARG BREVITAS_COMMIT="0eaff006407955153594254728baeb988edcd042"
 ARG PYVERILATOR_COMMIT="0c3eb9343500fc1352a02c020a736c8c2db47e8e"
 ARG CNPY_COMMIT="4e8810b1a8637695171ed346ce68f6984e585ef4"
 ARG HLSLIB_COMMIT="fbb07135b3d991602e8abe3f2c51212c11fd392b"
......
@@ -31,10 +31,24 @@ from onnx import TensorProto, helper
 import finn.core.onnx_exec as oxe
 from finn.core.datatype import DataType
+from finn.custom_op.registry import getCustomOp
 from finn.transformation.base import Transformation
 from finn.transformation.infer_shapes import InferShapes
+
+
+def get_dtype(bit_width: int, signed: bool) -> DataType:
+    bit_width = int(bit_width)
+    signed = bool(signed)
+    if bit_width == 1:
+        finn_dt = DataType["BIPOLAR"]
+    else:
+        if signed:
+            finn_dt = DataType["INT" + str(bit_width)]
+        else:
+            finn_dt = DataType["UINT" + str(bit_width)]
+    return finn_dt
+
+
 class FoldQuantWeights(Transformation):
     """Merges Quant nodes, which are used as weights, into the initializer
     of the weight tensor.
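The new get_dtype helper maps a (bit_width, signed) pair to a FINN DataType, treating the 1-bit case as BIPOLAR rather than INT1/UINT1. Expected mappings, for illustration:

    # get_dtype(4, True)  -> DataType["INT4"]
    # get_dtype(4, False) -> DataType["UINT4"]
    # get_dtype(1, True)  -> DataType["BIPOLAR"]  (signedness is ignored at 1 bit)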
@@ -47,7 +61,7 @@ class FoldQuantWeights(Transformation):
         execution_context = model.make_empty_exec_context()
         for n in graph.node:
             node_ind += 1
-            if n.op_type == "Quant":
+            if n.op_type == "Quant" or n.op_type == "BinaryQuant":
                 node_inp_inits = list(map(lambda x: model.get_initializer(x), n.input))
                 node_inp_dyn = list(filter(lambda x: x is None, node_inp_inits))
                 node_out = n.output[0]
@@ -55,18 +69,25 @@ class FoldQuantWeights(Transformation):
                 ishape = model.get_tensor_shape(n.input[0])
                 is_const_shape = (n.op_type == "Shape") and (ishape is not None)
                 if is_all_constant_inputs or is_const_shape:
-                    if not model.get_initializer(n.input[2]) == 0:
+                    if (
+                        n.op_type == "Quant"
+                        and not model.get_initializer(n.input[2]) == 0
+                    ):
                         raise ValueError(
                             "Only Quant nodes with zero-point == 0 "
                             "are currently supported."
                         )
+                    scale = model.get_initializer(n.input[1])
+                    unity_scale = (scale.flatten() == 1.0).all()
                     # this node has no dynamic inputs, only constant ones -- so we can
                     # do constant folding.
                     oxe.execute_node(n, execution_context, graph)
                     q_node_output = execution_context[node_out]
-                    # Check if the datatype can be directly constant folded
-                    dtype = model.get_tensor_datatype(n.output[0])
-                    if "SCALED" in dtype.name:
+                    # Check if we can directly constant fold
+                    if unity_scale:
+                        # use the execution result as an initializer
+                        model.set_initializer(node_out, q_node_output)
+                    else:
                         # Reshape scale for Conv if required
                         if model.is_fork_node(n):
                             raise RuntimeError(
@@ -95,14 +116,23 @@ class FoldQuantWeights(Transformation):
                                 f"at node: {target_node}."
                             )
-                        # For buth mul and Add:
+                        # For both Mul and Add:
                         # Move the scale factor behind the next operator
                         scale = model.get_initializer(n.input[1])
                         new_initializer = q_node_output / scale
                         # Round, to correct for floating point errors
                         new_initializer = np.round(new_initializer)
                         model.set_initializer(node_out, new_initializer)
-                        new_dtype = DataType[dtype.name.replace("SCALED", "")]
+                        if n.op_type == "Quant":
+                            bit_width = model.get_initializer(n.input[3])
+                            q_inst = getCustomOp(n)
+                            signed = q_inst.get_nodeattr("signed")
+                        elif n.op_type == "BinaryQuant":
+                            bit_width = 1.0
+                            signed = True
+                        else:
+                            raise RuntimeError("Got an unexpected quantizer node type")
+                        new_dtype = get_dtype(bit_width, signed)
                         model.set_tensor_datatype(node_out, new_dtype)
                         if target_node.op_type == "Conv" and len(scale.shape) > 0:
@@ -175,9 +205,6 @@ class FoldQuantWeights(Transformation):
                             )
                             graph.node.insert(node_ind, div_node)
-                    else:
-                        # use the execution result as an initializer
-                        model.set_initializer(node_out, q_node_output)
                     # remove old node
                     graph.node.remove(n)
                     graph_modified = True
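The non-unity-scale branch rests on a simple identity: a quantizer output q·s can be stored as the integer tensor q, with the scale s re-applied behind the consuming node, as long as that consumer is linear in the weight. A hedged numpy sketch of the algebra (hypothetical shapes and values):

    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.standard_normal(4)
    q = rng.integers(-3, 4, size=(3, 4)).astype(np.float64)  # integer weight tensor
    s = 0.125                                                # quantizer scale

    # MatMul with scaled weights == MatMul with integer weights, scaled afterwards
    assert np.allclose((q * s) @ x, (q @ x) * s)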
......
@@ -160,7 +160,7 @@ class QuantActBaseHandler(ABC):
         # behind the MultiThreshold nodes.
         bias_scalar = adder_bias.shape == (1,) or len(adder_bias.shape) == 0
         scale_scalar = mul_scale.shape == (1,) or len(mul_scale.shape) == 0
-        if scale_scalar and bias_scalar and False:
+        if scale_scalar and bias_scalar and self._q_node.op_type == "BinaryQuant":
             # Get Quant parameters
             mul_scale = np.atleast_1d(mul_scale)
             # ONNX only accepts 64bit floats as attributes
@@ -173,10 +173,11 @@ class QuantActBaseHandler(ABC):
             # FINN applies scale first then bias,
             # which is the other way around in Brevitas,
             # we thus need to adjust the bias in the MultiThreshold node
-            mt_inst.set_nodeattr("out_bias", adder_bias[0] * mul_scale[0])
+            finn_bias = adder_bias[0] * mul_scale[0]
+            mt_inst.set_nodeattr("out_bias", finn_bias)
             # If the bias and scale are integers, then the output will be as well.
-            if adder_bias % 1 == 0 and mul_scale % 1 == 0:
+            if finn_bias % 1 == 0 and mul_scale % 1 == 0:
                 mt_inst.set_nodeattr("out_dtype", out_dtype)
         else:
             # Set datatype
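The finn_bias line encodes the order-of-operations difference named in the comment: Brevitas computes (x + bias) · scale, while FINN's MultiThreshold applies the scale first, x · scale + out_bias. Equating the two gives out_bias = bias · scale. A quick illustrative check:

    import numpy as np

    x, bias, scale = 3.0, -1.5, 0.25
    finn_bias = bias * scale
    assert np.isclose((x + bias) * scale, x * scale + finn_bias)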
@@ -289,19 +290,24 @@ class QuantReluHandler(QuantActBaseHandler):
     ]

     def _check_compatibility(self):
-        q_inst = getCustomOp(self._q_node)
-        narrow = q_inst.get_nodeattr("narrow")
-        signed = q_inst.get_nodeattr("signed")
-        if signed or narrow:
-            raise ValueError(
-                "FINN only supports unsigned and non-narrow Quant nodes "
-                "for Relu activations."
-            )
-        if not self._model.get_initializer(self._q_node.input[2]) == 0:
-            raise ValueError(
-                "Only Quant nodes with zero-point == 0 "
-                "are currently supported for ReLu activations."
-            )
+        if self._q_node.op_type == "Quant":
+            q_inst = getCustomOp(self._q_node)
+            narrow = q_inst.get_nodeattr("narrow")
+            signed = q_inst.get_nodeattr("signed")
+            if signed or narrow:
+                raise ValueError(
+                    "FINN only supports unsigned and non-narrow Quant nodes "
+                    "for Relu activations."
+                )
+            if not self._model.get_initializer(self._q_node.input[2]) == 0:
+                raise ValueError(
+                    "Only Quant nodes with zero-point == 0 "
+                    "are currently supported for ReLu activations."
+                )
+        elif self._q_node.op_type == "BinaryQuant":
+            return
+        else:
+            raise RuntimeError("Got an unexpected quantizer node type")

     def _calculate_act_bias(self):
         # No bias allowed for Relu activations, see: https://github.com/Xilinx/
@@ -312,7 +318,12 @@ class QuantReluHandler(QuantActBaseHandler):
     def _calculate_thresholds(self):
         # Gather parameters
-        bit_width = self._model.get_initializer(self._q_node.input[3])
+        if self._q_node.op_type == "Quant":
+            bit_width = self._model.get_initializer(self._q_node.input[3])
+        elif self._q_node.op_type == "BinaryQuant":
+            bit_width = 1.0
+        else:
+            raise RuntimeError("Got an unexpected quantizer node type")
         quant_scale = self._model.get_initializer(self._q_node.input[1])
         # q_inst = getCustomOp(self._q_node)
         # narrow = q_inst.get_nodeattr("narrow")
@@ -388,27 +399,42 @@ class QuantIdentityHandler(QuantActBaseHandler):
     def _check_compatibility(self):
         # Gather parameters to check
-        q_inst = getCustomOp(self._q_node)
-        signed = q_inst.get_nodeattr("signed")
-        if not signed:
-            raise ValueError(
-                "FINN only supports signed Quant nodes for identity activations."
-            )
-        if not self._model.get_initializer(self._q_node.input[2]) == 0:
-            raise ValueError(
-                "Only Quant nodes with zero-point == 0 "
-                "are currently supported for identity activations."
-            )
+        if self._q_node.op_type == "Quant":
+            q_inst = getCustomOp(self._q_node)
+            signed = q_inst.get_nodeattr("signed")
+            if not signed:
+                raise ValueError(
+                    "FINN only supports signed Quant nodes for identity activations."
+                )
+            if not self._model.get_initializer(self._q_node.input[2]) == 0:
+                raise ValueError(
+                    "Only Quant nodes with zero-point == 0 "
+                    "are currently supported for identity activations."
+                )
+        elif self._q_node.op_type == "BinaryQuant":
+            quant_scale = self._model.get_initializer(self._q_node.input[1])
+            if (quant_scale.flatten().shape[0] != 1) or quant_scale.flatten()[0] != 1.0:
+                raise ValueError(
+                    "FINN only supports Bipolar identity activations "
+                    "without per-channel scaling, and the scale must be 1."
+                )
+        else:
+            raise RuntimeError("Got an unexpected quantizer node type")

     def _calculate_act_bias(self):
         # Gather parameters
-        bit_width = self._model.get_initializer(self._q_node.input[3])
         q_inst = getCustomOp(self._q_node)
-        narrow = q_inst.get_nodeattr("narrow")
+        if self._q_node.op_type == "Quant":
+            bit_width = self._model.get_initializer(self._q_node.input[3])
+            narrow = q_inst.get_nodeattr("narrow")
+        elif self._q_node.op_type == "BinaryQuant":
+            bit_width = 1.0
+        else:
+            raise RuntimeError("Got an unexpected quantizer node type")
         # Calculate bias, see: https://github.com/Xilinx/brevitas/blob/
         # a5bfd6dc5e030f0047ac1ee47932b60e8e873e17/src/brevitas/export/
         # onnx/finn/handler/act.py#L64
-        if bit_width == 1:
+        if bit_width == 1.0:
             bias = np.array([-0.5])
         else:
             if narrow:
@@ -420,45 +446,62 @@ class QuantIdentityHandler(QuantActBaseHandler):
     def _calculate_thresholds(self):
         # Gather parameters
-        bit_width = self._model.get_initializer(self._q_node.input[3])
         quant_scale = self._model.get_initializer(self._q_node.input[1])
         q_inst = getCustomOp(self._q_node)
-        narrow = q_inst.get_nodeattr("narrow")
+        if self._q_node.op_type == "Quant":
+            bit_width = self._model.get_initializer(self._q_node.input[3])
+            narrow = q_inst.get_nodeattr("narrow")
+        elif self._q_node.op_type == "BinaryQuant":
+            bit_width = 1.0
+        else:
+            raise RuntimeError("Got an unexpected quantizer node type")
         # Calculate thresholds, see: https://github.com/Xilinx/brevitas/
         # blob/a5bfd6dc5e030f0047ac1ee47932b60e8e873e17/src/brevitas/
         # export/onnx/finn/handler/act.py#L76
-        if narrow:
-            num_distinct_values = 2 ** bit_width - 1
+        if bit_width == 1.0:
+            thresholds = np.empty([1, 1])
+            thresholds[0] = 0
+            return thresholds
         else:
-            num_distinct_values = 2 ** bit_width
-
-        num_thresholds = int(num_distinct_values - 1)
-        flat_scale = quant_scale.flatten()
-        num_scale_channels = flat_scale.shape[0]
-        step = np.abs(flat_scale)
-        half_step = step / 2.0
-        thresholds = np.empty((num_scale_channels, num_thresholds))
-        # compute the value of the smallest threshold, we'll neg-bias all
-        # generated thresholds by this much
-        min_threshold = -half_step - step * ((num_thresholds // 2) - 1)
-        if not narrow:
-            min_threshold -= step
-        for c in range(num_scale_channels):
-            for t in range(num_thresholds):
-                thresholds[c][t] = min_threshold[c] + step[c] * t
-
-        # ToDo: The index 1 needs to be changed to -1 for the channels last format
-        num_output_channels = self._model.get_tensor_shape(self._q_node.output[0])[1]
-        final_shape = (num_output_channels, num_thresholds)
-        if thresholds.shape != final_shape:
-            thresholds = np.broadcast_to(thresholds, final_shape)
-
-        return thresholds
+            if narrow:
+                num_distinct_values = 2 ** bit_width - 1
+            else:
+                num_distinct_values = 2 ** bit_width
+
+            num_thresholds = int(num_distinct_values - 1)
+            flat_scale = quant_scale.flatten()
+            num_scale_channels = flat_scale.shape[0]
+            step = np.abs(flat_scale)
+            half_step = step / 2.0
+            thresholds = np.empty((num_scale_channels, num_thresholds))
+            # compute the value of the smallest threshold, we'll neg-bias all
+            # generated thresholds by this much
+            min_threshold = -half_step - step * ((num_thresholds // 2) - 1)
+            if not narrow:
+                min_threshold -= step
+            for c in range(num_scale_channels):
+                for t in range(num_thresholds):
+                    thresholds[c][t] = min_threshold[c] + step[c] * t
+
+            # ToDo: The index 1 needs to be changed to -1 for the channels last format
+            num_output_channels = self._model.get_tensor_shape(self._q_node.output[0])[
+                1
+            ]
+            final_shape = (num_output_channels, num_thresholds)
+            if thresholds.shape != final_shape:
+                thresholds = np.broadcast_to(thresholds, final_shape)
+
+            return thresholds

     def _calculate_act_scale(self):
         # Gather parameters
-        bit_width = self._model.get_initializer(self._q_node.input[3])
+        if self._q_node.op_type == "Quant":
+            bit_width = self._model.get_initializer(self._q_node.input[3])
+        elif self._q_node.op_type == "BinaryQuant":
+            bit_width = 1.0
+        else:
+            raise RuntimeError("Got an unexpected quantizer node type")
         quant_scale = self._model.get_initializer(self._q_node.input[1])
         # Calculate scale, see: https://github.com/Xilinx/brevitas/
         # blob/a5bfd6dc5e030f0047ac1ee47932b60e8e873e17/src/brevitas/
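Worked example for the multi-bit branch above: a signed, narrow 2-bit quantizer has 2^2 - 1 = 3 distinct values, hence 2 thresholds, and with scale s they land at ±s/2; the 1-bit branch instead returns the single threshold 0. A small numpy check mirroring the loop (scalar scale assumed):

    import numpy as np

    s = np.array([0.5])                 # quantizer scale
    num_thresholds = 2                  # narrow 2-bit: 3 values -> 2 thresholds
    step = np.abs(s)
    half_step = step / 2.0
    min_threshold = -half_step - step * ((num_thresholds // 2) - 1)
    thresholds = np.array(
        [[min_threshold[0] + step[0] * t for t in range(num_thresholds)]]
    )
    print(thresholds)                   # [[-0.25  0.25]], i.e. -s/2 and +s/2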
@@ -466,12 +509,10 @@ class QuantIdentityHandler(QuantActBaseHandler):
         if bit_width != 1:
             scale = quant_scale
         else:
-            # ToDo: This needs testing and/or rewriting when the BinarayQuant op
-            # comes around
             assert (
                 quant_scale.flatten().shape[0] == 1
             ), "Unsupported BIPOLAR per channel scale"
-            assert quant_scale.flatten().item() == 1.0, "Unsupported BIPOLAR scale != 1"
+            assert quant_scale.flatten()[0] == 1.0, "Unsupported BIPOLAR scale != 1"
             scale = quant_scale * 2
         return scale
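The factor of two here pairs with the -0.5 bias from _calculate_act_bias: a MultiThreshold with the single threshold 0 emits outputs in {0, 1}, and applying scale then bias (FINN order) maps them onto the bipolar values {-1, +1}. Illustrative check:

    import numpy as np

    o = np.array([0.0, 1.0])          # MultiThreshold outputs for a 1-bit activation
    scale = 2.0                       # _calculate_act_scale result for bit_width == 1
    bias = -0.5                       # _calculate_act_bias result for bit_width == 1
    print(o * scale + bias * scale)   # [-1.  1.] -> BIPOLAR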
......
@@ -70,7 +70,7 @@ class ConvertQuantActToMultiThreshold(Transformation):
         for n in graph.node:
             node_ind += 1
-            if n.op_type == "Quant":
+            if n.op_type == "Quant" or n.op_type == "BinaryQuant":
                 # Check that the node is in the activation path
                 inp = model.get_initializer(n.input[0])
                 out = model.get_initializer(n.output[0])
@@ -83,15 +83,21 @@ class ConvertQuantActToMultiThreshold(Transformation):
                     predecessor_op_type = predecessor
                 if model.is_fork_node(n):
                     raise ValueError(
-                        "Forking Quant nodes are not currently supported by FINN."
+                        "Forking Quant/BinaryQuant nodes are currently "
+                        "not supported by FINN."
                     )
-                if not model.get_initializer(n.input[2]) == 0:
+                if n.op_type == "Quant" and not model.get_initializer(n.input[2]) == 0:
                     raise ValueError(
                         "Only Quant nodes with zero-point == 0 are currently supported."
                     )

                 # Check if the bit width is low enough
-                bit_width = model.get_initializer(n.input[3])
+                if n.op_type == "Quant":
+                    bit_width = model.get_initializer(n.input[3])
+                elif n.op_type == "BinaryQuant":
+                    bit_width = 1.0
+                else:
+                    raise RuntimeError("Got an unexpected quantizer node type")
                 if bit_width is None:
                     raise ValueError("Quant nodes must have a static bit width.")
                 if bit_width > self.max_multithreshold_bit_width:
......
@@ -89,6 +89,7 @@ def analysis_testing_for_no_quant_nodes(model):
 # ToDo: Add KWS networks, when they are ready to be added to finn-examples.
 # ToDo: Add RadioML_VGG10, if possible
+# This test currently takes about 4 min and 42 seconds
 @pytest.mark.parametrize("abits", [1, 2])
 @pytest.mark.parametrize("wbits", [1, 2])
 @pytest.mark.parametrize("model_name", ["TFC", "SFC", "LFC", "CNV", "mobilenet"])
@@ -97,13 +98,9 @@ def test_QONNX_to_FINN(model_name, wbits, abits):
         pytest.skip("No wbits > abits cases at the moment")
     if model_name == "LFC" and wbits == 2 and abits == 2:
         pytest.skip("No LFC-w2a2 present at the moment")
-    if model_name == "mobilenet" and wbits < 2 and abits < 2:
+    if model_name == "mobilenet" and (wbits != 2 or abits != 2):
         pytest.skip("Mobilenet only runs at W2A2, though it's technically W4A4.")
-    # ToDo: Remove the following restriction when QONNX supports binary operations.
-    if wbits == 1 or abits == 1:
-        pytest.skip("wbits == 1 or abits == 1 is currently not supported by QONNX.")

     brev_model, in_shape, input_tensor = get_brev_model_and_sample_inputs(
         model_name, wbits, abits
     )
......