diff --git a/src/finn/transformation/qonnx/qonnx_activation_handlers.py b/src/finn/transformation/qonnx/qonnx_activation_handlers.py index 9819086d826a51d1df5240d88c4fda8513cc9ba6..bbe5e1a0e319a8f62e9a1bcd4f0857f36295049e 100644 --- a/src/finn/transformation/qonnx/qonnx_activation_handlers.py +++ b/src/finn/transformation/qonnx/qonnx_activation_handlers.py @@ -286,6 +286,7 @@ class QuantReluHandler(QuantActBaseHandler): def valid_predecessor_op_types(self): return [ "Relu", + "Selu", ] def _check_compatibility(self): @@ -293,16 +294,19 @@ class QuantReluHandler(QuantActBaseHandler): q_inst = getCustomOp(self._q_node) narrow = q_inst.get_nodeattr("narrow") signed = q_inst.get_nodeattr("signed") - if signed or narrow: - raise ValueError( - "FINN only supports unsigned and non-narrow Quant nodes " - "for Relu activations." - ) if not self._model.get_initializer(self._q_node.input[2]) == 0: raise ValueError( "Only Quant nodes with zero-point == 0 " "are currently supported for ReLu activations." ) + act_node = self._model.find_direct_predecessors(self._q_node) + act_node = act_node[0] + if act_node.op_type == "Relu": + if signed or narrow: + raise ValueError( + "FINN only supports unsigned and non-narrow Quant nodes " + "for Relu activations." + ) elif self._q_node.op_type == "BipolarQuant": return else: @@ -312,7 +316,31 @@ class QuantReluHandler(QuantActBaseHandler): # No bias allowed for Relu activations, see: https://github.com/Xilinx/ # brevitas/blob/a5bfd6dc5e030f0047ac1ee47932b60e8e873e17/src/brevitas/ # export/onnx/finn/handler/act.py#L48 - bias = np.array([0.0], dtype=np_default_dtype) + act_node = self._model.find_direct_predecessors(self._q_node) + act_node = act_node[0] + if act_node.op_type == "Relu": + bias = np.array([0.0], dtype=np_default_dtype) + elif act_node.op_type == "Selu": + # Gather parameters + q_inst = getCustomOp(self._q_node) + if self._q_node.op_type == "Quant": + bit_width = self._model.get_initializer(self._q_node.input[3]) + narrow = q_inst.get_nodeattr("narrow") + elif self._q_node.op_type == "BipolarQuant": + bit_width = 1.0 + else: + raise RuntimeError("Got an unexpected quantizer node type") + # Calculate bias, see: https://github.com/Xilinx/brevitas/blob/ + # a5bfd6dc5e030f0047ac1ee47932b60e8e873e17/src/brevitas/export/ + # onnx/finn/handler/act.py#L64 + if bit_width == 1.0: + bias = np.array([-0.5], dtype=np_default_dtype) + else: + if narrow: + min_non_scaled_val = -(2 ** (bit_width - 1) - 1) + else: + min_non_scaled_val = -(2 ** (bit_width - 1)) + bias = np.array([min_non_scaled_val], dtype=np_default_dtype) return bias def _calculate_thresholds(self): @@ -326,24 +354,53 @@ class QuantReluHandler(QuantActBaseHandler): quant_scale = self._model.get_initializer(self._q_node.input[1]).astype( np.float32 ) - # q_inst = getCustomOp(self._q_node) - # narrow = q_inst.get_nodeattr("narrow") + act_node = self._model.find_direct_predecessors(self._q_node) + act_node = act_node[0] + if act_node.op_type == "Relu": - # Calculate thersholds, see: https://github.com/Xilinx/brevitas/blob/ - # a5bfd6dc5e030f0047ac1ee47932b60e8e873e17/src/brevitas/export/ - # onnx/finn/handler/act.py#L21 - num_distinct_values = 2**bit_width - num_thresholds = int(num_distinct_values - 1) - flat_scale = quant_scale.flatten().astype(np.float32) - num_scale_channels = flat_scale.shape[0] - step = np.abs(flat_scale).astype(np.float32) - min_threshold = step / 2 - thresholds = np.empty( - (num_scale_channels, num_thresholds), dtype=np_default_dtype - ) - for c in range(num_scale_channels): - for t in range(num_thresholds): - thresholds[c][t] = min_threshold[c] + step[c] * t + # Calculate thersholds, see: https://github.com/Xilinx/brevitas/blob/ + # a5bfd6dc5e030f0047ac1ee47932b60e8e873e17/src/brevitas/export/ + # onnx/finn/handler/act.py#L21 + num_distinct_values = 2**bit_width + num_thresholds = int(num_distinct_values - 1) + flat_scale = quant_scale.flatten().astype(np.float32) + num_scale_channels = flat_scale.shape[0] + step = np.abs(flat_scale).astype(np.float32) + min_threshold = step / 2 + thresholds = np.empty( + (num_scale_channels, num_thresholds), dtype=np_default_dtype + ) + for c in range(num_scale_channels): + for t in range(num_thresholds): + thresholds[c][t] = min_threshold[c] + step[c] * t + + elif act_node.op_type == "Selu": + q_inst = getCustomOp(self._q_node) + narrow = q_inst.get_nodeattr("narrow") + if narrow: + num_distinct_values = 2**bit_width - 1 + else: + num_distinct_values = 2**bit_width + + num_thresholds = int(num_distinct_values - 1) + flat_scale = quant_scale.flatten().astype(np.float32) + num_scale_channels = flat_scale.shape[0] + scale = np.abs(flat_scale).astype(np.float32) + half_scale = scale / 2 + # alpha and lambda + # from https://pytorch.org/docs/stable/generated/torch.nn.SELU.html + alpha = 1.6732632423543772848170429916717 + selu_scale = 1.0507009873554804934193349852946 + thresholds = np.empty( + (num_scale_channels, num_thresholds), dtype=np_default_dtype + ) + for c in range(num_scale_channels): + for t in range(num_thresholds): + step = -1.0 + half_scale + scale[c] * t + if step <= 0: + thresholds[c][t] = np.log(step / (alpha * selu_scale) + 1) + else: + thresholds[c][t] = step / selu_scale # ToDo: The index 1 needs to be changed to -1 for the channels last format num_output_channels = self._model.get_tensor_shape(self._q_node.output[0])[1] @@ -371,10 +428,10 @@ class QuantReluHandler(QuantActBaseHandler): "the Quant node must exist." ) act_node = act_node[0] - if not act_node.op_type == "Relu": + if act_node.op_type not in self.valid_predecessor_op_types(): raise RuntimeError( - "The predecesor of the Quant node must be Relu for handling " - "of Relu activations." + "The predecesor of the Quant node must be Relu or Selu for handling " + "of activations." ) # Reroute upstream tensor diff --git a/tests/brevitas/test_brevitas_selu_act_export.py b/tests/brevitas/test_brevitas_selu_act_export.py new file mode 100644 index 0000000000000000000000000000000000000000..3f4807c5d7286c265a856d73e9aaa886f342555e --- /dev/null +++ b/tests/brevitas/test_brevitas_selu_act_export.py @@ -0,0 +1,74 @@ +# Copyright (c) 2023, Advanced Micro Devices, Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of Xilinx nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import pytest + +import numpy as np +import onnx # noqa +import os +import torch +from brevitas.export import export_qonnx +from brevitas.nn import QuantIdentity +from qonnx.core.modelwrapper import ModelWrapper +from qonnx.util.basic import get_preferred_onnx_opset +from qonnx.util.cleanup import cleanup as qonnx_cleanup + +import finn.core.onnx_exec as oxe +from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN + + +@pytest.mark.brevitas_export +@pytest.mark.parametrize("abits", [2, 4, 8]) +@pytest.mark.parametrize("ishape", [(1, 15), (1, 32, 1, 1)]) +@pytest.mark.parametrize("narrow", [True, False]) +def test_brevitas_act_export_selu(abits, ishape, narrow): + export_path = "test_brevitas_selu_act_export_%s.onnx" % str(abits) + b_act = torch.nn.Sequential( + torch.nn.SELU(), QuantIdentity(bit_width=abits, narrow=narrow) + ) + + export_qonnx( + b_act, + torch.randn(ishape), + export_path, + opset_version=get_preferred_onnx_opset(), + ) + qonnx_cleanup(export_path, out_file=export_path) + model = ModelWrapper(export_path) + model = model.transform(ConvertQONNXtoFINN()) + + inp_tensor = np.random.uniform(low=-1.0, high=6.0, size=ishape).astype(np.float32) + idict = {model.graph.input[0].name: inp_tensor} + odict = oxe.execute_onnx(model, idict, True) + produced = odict[model.graph.output[0].name] + inp_tensor = torch.from_numpy(inp_tensor).float() + b_act.eval() + expected = b_act.forward(inp_tensor).detach().numpy() + + assert np.isclose(produced, expected, atol=1e-3).all() + os.remove(export_path)