From 828de10fa8afbcb54596058205050edbeaa59f83 Mon Sep 17 00:00:00 2001 From: Hendrik Borras <hendrikborras@web.de> Date: Fri, 15 Oct 2021 14:24:04 +0100 Subject: [PATCH] Added QONNX_export test to test_brevitas_act_export_qhardtanh_nonscaled test. --- ...brevitas_non_scaled_QuantHardTanh_export.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/tests/brevitas/test_brevitas_non_scaled_QuantHardTanh_export.py b/tests/brevitas/test_brevitas_non_scaled_QuantHardTanh_export.py index 6ddf71a5c..b530b4bd8 100644 --- a/tests/brevitas/test_brevitas_non_scaled_QuantHardTanh_export.py +++ b/tests/brevitas/test_brevitas_non_scaled_QuantHardTanh_export.py @@ -36,11 +36,14 @@ import torch from brevitas.core.quant import QuantType from brevitas.core.restrict_val import RestrictValueType from brevitas.core.scaling import ScalingImplType +from brevitas.export.onnx.generic.manager import BrevitasONNXManager from brevitas.nn import QuantHardTanh +from qonnx.util.cleanup import cleanup as qonnx_cleanup import finn.core.onnx_exec as oxe from finn.core.modelwrapper import ModelWrapper from finn.transformation.infer_shapes import InferShapes +from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN export_onnx_path = "test_brevitas_non_scaled_QuantHardTanh_export.onnx" @@ -48,7 +51,10 @@ export_onnx_path = "test_brevitas_non_scaled_QuantHardTanh_export.onnx" @pytest.mark.parametrize("abits", [1, 2, 4, 8]) @pytest.mark.parametrize("narrow_range", [False, True]) @pytest.mark.parametrize("max_val", [1.0, 1 - 2 ** (-7)]) -def test_brevitas_act_export_qhardtanh_nonscaled(abits, narrow_range, max_val): +@pytest.mark.parametrize("QONNX_export", [False, True]) +def test_brevitas_act_export_qhardtanh_nonscaled( + abits, narrow_range, max_val, QONNX_export +): def get_quant_type(bit_width): if bit_width is None: return QuantType.FP @@ -69,7 +75,15 @@ def test_brevitas_act_export_qhardtanh_nonscaled(abits, narrow_range, max_val): scaling_impl_type=ScalingImplType.CONST, narrow_range=narrow_range, ) - bo.export_finn_onnx(b_act, ishape, export_onnx_path) + if QONNX_export: + m_path = export_onnx_path + BrevitasONNXManager.export(b_act, ishape, m_path) + qonnx_cleanup(m_path, out_file=m_path) + model = ModelWrapper(m_path) + model = model.transform(ConvertQONNXtoFINN()) + model.save(m_path) + else: + bo.export_finn_onnx(b_act, ishape, export_onnx_path) model = ModelWrapper(export_onnx_path) model = model.transform(InferShapes()) inp_tensor = np.random.uniform(low=min_val, high=max_val, size=ishape).astype( -- GitLab