From fd993b99747d065440b4aeaff4fd83e6ef160a94 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <maltanar@gmail.com> Date: Thu, 14 May 2020 11:24:01 +0100 Subject: [PATCH] [Test] rename act export tests consistently --- .../brevitas/test_brevitas_non_scaled_QuantHardTanh_export.py | 4 +++- tests/brevitas/test_brevitas_relu_act_export.py | 2 +- tests/brevitas/test_brevitas_scaled_QHardTanh_export.py | 4 +++- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/brevitas/test_brevitas_non_scaled_QuantHardTanh_export.py b/tests/brevitas/test_brevitas_non_scaled_QuantHardTanh_export.py index c22f30c6f..b66348a99 100644 --- a/tests/brevitas/test_brevitas_non_scaled_QuantHardTanh_export.py +++ b/tests/brevitas/test_brevitas_non_scaled_QuantHardTanh_export.py @@ -1,3 +1,4 @@ +import os import onnx # noqa import numpy as np import torch @@ -17,7 +18,7 @@ export_onnx_path = "test_act.onnx" @pytest.mark.parametrize("abits", [1, 2, 4, 8]) @pytest.mark.parametrize("narrow_range", [False, True]) @pytest.mark.parametrize("max_val", [1.0, 1 - 2 ** (-7)]) -def test_brevitas_act_export(abits, narrow_range, max_val): +def test_brevitas_act_export_qhardtanh_nonscaled(abits, narrow_range, max_val): def get_quant_type(bit_width): if bit_width is None: return QuantType.FP @@ -50,3 +51,4 @@ def test_brevitas_act_export(abits, narrow_range, max_val): inp_tensor = torch.from_numpy(inp_tensor).float() expected = b_act.forward(inp_tensor).detach().numpy() assert np.isclose(produced, expected, atol=1e-3).all() + os.remove(export_onnx_path) diff --git a/tests/brevitas/test_brevitas_relu_act_export.py b/tests/brevitas/test_brevitas_relu_act_export.py index 4c450243b..c9d8f2d81 100644 --- a/tests/brevitas/test_brevitas_relu_act_export.py +++ b/tests/brevitas/test_brevitas_relu_act_export.py @@ -20,7 +20,7 @@ export_onnx_path = "test_act.onnx" @pytest.mark.parametrize( "scaling_impl_type", [ScalingImplType.CONST, ScalingImplType.PARAMETER] ) -def test_brevitas_relu_act_export(abits, max_val, scaling_impl_type): +def test_brevitas_act_export_relu(abits, max_val, scaling_impl_type): min_val = -1.0 ishape = (1, 15) b_act = QuantReLU( diff --git a/tests/brevitas/test_brevitas_scaled_QHardTanh_export.py b/tests/brevitas/test_brevitas_scaled_QHardTanh_export.py index 9af58d4e3..d499f1517 100644 --- a/tests/brevitas/test_brevitas_scaled_QHardTanh_export.py +++ b/tests/brevitas/test_brevitas_scaled_QHardTanh_export.py @@ -22,7 +22,9 @@ export_onnx_path = "test_act.onnx" @pytest.mark.parametrize( "scaling_impl_type", [ScalingImplType.CONST, ScalingImplType.PARAMETER] ) -def test_brevitas_act_export(abits, narrow_range, min_val, max_val, scaling_impl_type): +def test_brevitas_act_export_qhardtanh_scaled( + abits, narrow_range, min_val, max_val, scaling_impl_type +): def get_quant_type(bit_width): if bit_width is None: return QuantType.FP -- GitLab