From fd993b99747d065440b4aeaff4fd83e6ef160a94 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Thu, 14 May 2020 11:24:01 +0100
Subject: [PATCH] [Test] rename act export tests consistently

---
 .../brevitas/test_brevitas_non_scaled_QuantHardTanh_export.py | 4 +++-
 tests/brevitas/test_brevitas_relu_act_export.py               | 2 +-
 tests/brevitas/test_brevitas_scaled_QHardTanh_export.py       | 4 +++-
 3 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/tests/brevitas/test_brevitas_non_scaled_QuantHardTanh_export.py b/tests/brevitas/test_brevitas_non_scaled_QuantHardTanh_export.py
index c22f30c6f..b66348a99 100644
--- a/tests/brevitas/test_brevitas_non_scaled_QuantHardTanh_export.py
+++ b/tests/brevitas/test_brevitas_non_scaled_QuantHardTanh_export.py
@@ -1,3 +1,4 @@
+import os
 import onnx  # noqa
 import numpy as np
 import torch
@@ -17,7 +18,7 @@ export_onnx_path = "test_act.onnx"
 @pytest.mark.parametrize("abits", [1, 2, 4, 8])
 @pytest.mark.parametrize("narrow_range", [False, True])
 @pytest.mark.parametrize("max_val", [1.0, 1 - 2 ** (-7)])
-def test_brevitas_act_export(abits, narrow_range, max_val):
+def test_brevitas_act_export_qhardtanh_nonscaled(abits, narrow_range, max_val):
     def get_quant_type(bit_width):
         if bit_width is None:
             return QuantType.FP
@@ -50,3 +51,4 @@ def test_brevitas_act_export(abits, narrow_range, max_val):
     inp_tensor = torch.from_numpy(inp_tensor).float()
     expected = b_act.forward(inp_tensor).detach().numpy()
     assert np.isclose(produced, expected, atol=1e-3).all()
+    os.remove(export_onnx_path)
diff --git a/tests/brevitas/test_brevitas_relu_act_export.py b/tests/brevitas/test_brevitas_relu_act_export.py
index 4c450243b..c9d8f2d81 100644
--- a/tests/brevitas/test_brevitas_relu_act_export.py
+++ b/tests/brevitas/test_brevitas_relu_act_export.py
@@ -20,7 +20,7 @@ export_onnx_path = "test_act.onnx"
 @pytest.mark.parametrize(
     "scaling_impl_type", [ScalingImplType.CONST, ScalingImplType.PARAMETER]
 )
-def test_brevitas_relu_act_export(abits, max_val, scaling_impl_type):
+def test_brevitas_act_export_relu(abits, max_val, scaling_impl_type):
     min_val = -1.0
     ishape = (1, 15)
     b_act = QuantReLU(
diff --git a/tests/brevitas/test_brevitas_scaled_QHardTanh_export.py b/tests/brevitas/test_brevitas_scaled_QHardTanh_export.py
index 9af58d4e3..d499f1517 100644
--- a/tests/brevitas/test_brevitas_scaled_QHardTanh_export.py
+++ b/tests/brevitas/test_brevitas_scaled_QHardTanh_export.py
@@ -22,7 +22,9 @@ export_onnx_path = "test_act.onnx"
 @pytest.mark.parametrize(
     "scaling_impl_type", [ScalingImplType.CONST, ScalingImplType.PARAMETER]
 )
-def test_brevitas_act_export(abits, narrow_range, min_val, max_val, scaling_impl_type):
+def test_brevitas_act_export_qhardtanh_scaled(
+    abits, narrow_range, min_val, max_val, scaling_impl_type
+):
     def get_quant_type(bit_width):
         if bit_width is None:
             return QuantType.FP
-- 
GitLab