Skip to content
Snippets Groups Projects
Commit fd993b99 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Test] rename act export tests consistently

parent 5f99ef71
No related branches found
No related tags found
No related merge requests found
import os
import onnx # noqa
import numpy as np
import torch
......@@ -17,7 +18,7 @@ export_onnx_path = "test_act.onnx"
@pytest.mark.parametrize("abits", [1, 2, 4, 8])
@pytest.mark.parametrize("narrow_range", [False, True])
@pytest.mark.parametrize("max_val", [1.0, 1 - 2 ** (-7)])
def test_brevitas_act_export(abits, narrow_range, max_val):
def test_brevitas_act_export_qhardtanh_nonscaled(abits, narrow_range, max_val):
def get_quant_type(bit_width):
if bit_width is None:
return QuantType.FP
......@@ -50,3 +51,4 @@ def test_brevitas_act_export(abits, narrow_range, max_val):
inp_tensor = torch.from_numpy(inp_tensor).float()
expected = b_act.forward(inp_tensor).detach().numpy()
assert np.isclose(produced, expected, atol=1e-3).all()
os.remove(export_onnx_path)
......@@ -20,7 +20,7 @@ export_onnx_path = "test_act.onnx"
@pytest.mark.parametrize(
"scaling_impl_type", [ScalingImplType.CONST, ScalingImplType.PARAMETER]
)
def test_brevitas_relu_act_export(abits, max_val, scaling_impl_type):
def test_brevitas_act_export_relu(abits, max_val, scaling_impl_type):
min_val = -1.0
ishape = (1, 15)
b_act = QuantReLU(
......
......@@ -22,7 +22,9 @@ export_onnx_path = "test_act.onnx"
@pytest.mark.parametrize(
"scaling_impl_type", [ScalingImplType.CONST, ScalingImplType.PARAMETER]
)
def test_brevitas_act_export(abits, narrow_range, min_val, max_val, scaling_impl_type):
def test_brevitas_act_export_qhardtanh_scaled(
abits, narrow_range, min_val, max_val, scaling_impl_type
):
def get_quant_type(bit_width):
if bit_width is None:
return QuantType.FP
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment