Skip to content
Snippets Groups Projects
Commit 828de10f authored by Hendrik Borras's avatar Hendrik Borras
Browse files

Added a QONNX_export parameter to the test_brevitas_act_export_qhardtanh_nonscaled test.

parent 0e584d48
No related branches found
No related tags found
No related merge requests found
......@@ -36,11 +36,14 @@ import torch
from brevitas.core.quant import QuantType
from brevitas.core.restrict_val import RestrictValueType
from brevitas.core.scaling import ScalingImplType
from brevitas.export.onnx.generic.manager import BrevitasONNXManager
from brevitas.nn import QuantHardTanh
from qonnx.util.cleanup import cleanup as qonnx_cleanup
import finn.core.onnx_exec as oxe
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN
export_onnx_path = "test_brevitas_non_scaled_QuantHardTanh_export.onnx"
......@@ -48,7 +51,10 @@ export_onnx_path = "test_brevitas_non_scaled_QuantHardTanh_export.onnx"
@pytest.mark.parametrize("abits", [1, 2, 4, 8])
@pytest.mark.parametrize("narrow_range", [False, True])
@pytest.mark.parametrize("max_val", [1.0, 1 - 2 ** (-7)])
def test_brevitas_act_export_qhardtanh_nonscaled(abits, narrow_range, max_val):
@pytest.mark.parametrize("QONNX_export", [False, True])
def test_brevitas_act_export_qhardtanh_nonscaled(
abits, narrow_range, max_val, QONNX_export
):
def get_quant_type(bit_width):
if bit_width is None:
return QuantType.FP
......@@ -69,7 +75,15 @@ def test_brevitas_act_export_qhardtanh_nonscaled(abits, narrow_range, max_val):
scaling_impl_type=ScalingImplType.CONST,
narrow_range=narrow_range,
)
bo.export_finn_onnx(b_act, ishape, export_onnx_path)
if QONNX_export:
m_path = export_onnx_path
BrevitasONNXManager.export(b_act, ishape, m_path)
qonnx_cleanup(m_path, out_file=m_path)
model = ModelWrapper(m_path)
model = model.transform(ConvertQONNXtoFINN())
model.save(m_path)
else:
bo.export_finn_onnx(b_act, ishape, export_onnx_path)
model = ModelWrapper(export_onnx_path)
model = model.transform(InferShapes())
inp_tensor = np.random.uniform(low=min_val, high=max_val, size=ishape).astype(
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.