diff --git a/tests/brevitas/test_brevitas_non_scaled_QuantHardTanh_export.py b/tests/brevitas/test_brevitas_non_scaled_QuantHardTanh_export.py new file mode 100644 index 0000000000000000000000000000000000000000..a50e8eae74b5d148f6d8d7ae67433ee51e094217 --- /dev/null +++ b/tests/brevitas/test_brevitas_non_scaled_QuantHardTanh_export.py @@ -0,0 +1,44 @@ +import onnx # noqa +import numpy as np +import torch +import brevitas.onnx as bo +from brevitas.nn import QuantHardTanh +from brevitas.core.restrict_val import RestrictValueType +from brevitas.core.scaling import ScalingImplType +from models.common import get_quant_type +import pytest +from finn.core.modelwrapper import ModelWrapper +import finn.core.onnx_exec as oxe +from finn.transformation.infer_shapes import InferShapes + +export_onnx_path = "test_act.onnx" + + +@pytest.mark.parametrize("abits", [1, 2, 4, 8]) +@pytest.mark.parametrize("narrow_range", [False, True]) +@pytest.mark.parametrize("max_val", [1.0, 1 - 2 ** (-7)]) +def test_brevitas_act_export(abits, narrow_range, max_val): + act_quant_type = get_quant_type(abits) + min_val = -1.0 + ishape = (1, 10) + b_act = QuantHardTanh( + bit_width=abits, + quant_type=act_quant_type, + max_val=max_val, + min_val=min_val, + restrict_scaling_type=RestrictValueType.LOG_FP, + scaling_impl_type=ScalingImplType.CONST, + narrow_range=narrow_range, + ) + bo.export_finn_onnx(b_act, ishape, export_onnx_path) + model = ModelWrapper(export_onnx_path) + model = model.transform(InferShapes()) + inp_tensor = np.random.uniform(low=min_val, high=max_val, size=ishape).astype( + np.float32 + ) + idict = {model.graph.input[0].name: inp_tensor} + odict = oxe.execute_onnx(model, idict, True) + produced = odict[model.graph.output[0].name] + inp_tensor = torch.from_numpy(inp_tensor).float() + expected = b_act.forward(inp_tensor).detach().numpy() + assert np.isclose(produced, expected, atol=1e-3).all() diff --git a/tests/brevitas/test_brevitas_act_export.py b/tests/brevitas/test_brevitas_scaled_QHardTanh_export.py similarity index 98% rename from tests/brevitas/test_brevitas_act_export.py rename to tests/brevitas/test_brevitas_scaled_QHardTanh_export.py index 77ad1fb1cb2f9cb116ad2f8961b0878d95d73d34..02d7e27c8a94372ce18a943668cea1daf27c8cfa 100644 --- a/tests/brevitas/test_brevitas_act_export.py +++ b/tests/brevitas/test_brevitas_scaled_QHardTanh_export.py @@ -14,7 +14,7 @@ from finn.transformation.infer_shapes import InferShapes export_onnx_path = "test_act.onnx" -@pytest.mark.parametrize("abits", [1, 2, 4, 8]) +@pytest.mark.parametrize("abits", [2, 4, 8]) @pytest.mark.parametrize("narrow_range", [False, True]) @pytest.mark.parametrize("min_val", [-1.0, -(1 - 2 ** (-7)), -2]) @pytest.mark.parametrize("max_val", [1.0, 1 - 2 ** (-7), 2])