Skip to content
Snippets Groups Projects
Commit 44892ae4 authored by Tobi-Alonso's avatar Tobi-Alonso
Browse files

Scaled and non-scaled QuantHardTanh activation tests separated. Non-scaled...

Scaled and non-scaled QuantHardTanh activation tests separated. Non-scaled test includes cases with scale parameter == 1. In scaled test, bipolar activation is excluded as it's not currently supported
parent d216e361
No related branches found
No related tags found
No related merge requests found
import onnx # noqa
import numpy as np
import torch
import brevitas.onnx as bo
from brevitas.nn import QuantHardTanh
from brevitas.core.restrict_val import RestrictValueType
from brevitas.core.scaling import ScalingImplType
from models.common import get_quant_type
import pytest
from finn.core.modelwrapper import ModelWrapper
import finn.core.onnx_exec as oxe
from finn.transformation.infer_shapes import InferShapes
export_onnx_path = "test_act.onnx"
@pytest.mark.parametrize("abits", [1, 2, 4, 8])
@pytest.mark.parametrize("narrow_range", [False, True])
@pytest.mark.parametrize("max_val", [1.0, 1 - 2 ** (-7)])
def test_brevitas_act_export(abits, narrow_range, max_val):
act_quant_type = get_quant_type(abits)
min_val = -1.0
ishape = (1, 10)
b_act = QuantHardTanh(
bit_width=abits,
quant_type=act_quant_type,
max_val=max_val,
min_val=min_val,
restrict_scaling_type=RestrictValueType.LOG_FP,
scaling_impl_type=ScalingImplType.CONST,
narrow_range=narrow_range,
)
bo.export_finn_onnx(b_act, ishape, export_onnx_path)
model = ModelWrapper(export_onnx_path)
model = model.transform(InferShapes())
inp_tensor = np.random.uniform(low=min_val, high=max_val, size=ishape).astype(
np.float32
)
idict = {model.graph.input[0].name: inp_tensor}
odict = oxe.execute_onnx(model, idict, True)
produced = odict[model.graph.output[0].name]
inp_tensor = torch.from_numpy(inp_tensor).float()
expected = b_act.forward(inp_tensor).detach().numpy()
assert np.isclose(produced, expected, atol=1e-3).all()
......@@ -14,7 +14,7 @@ from finn.transformation.infer_shapes import InferShapes
export_onnx_path = "test_act.onnx"
@pytest.mark.parametrize("abits", [1, 2, 4, 8])
@pytest.mark.parametrize("abits", [2, 4, 8])
@pytest.mark.parametrize("narrow_range", [False, True])
@pytest.mark.parametrize("min_val", [-1.0, -(1 - 2 ** (-7)), -2])
@pytest.mark.parametrize("max_val", [1.0, 1 - 2 ** (-7), 2])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment