From f39486184f73ff88601ec6f47cfbad4ce309f278 Mon Sep 17 00:00:00 2001 From: Hendrik Borras <hendrikborras@web.de> Date: Fri, 15 Oct 2021 14:20:38 +0100 Subject: [PATCH] Added QONNX_export test to test_brevitas_cnv_export_exec test. --- docker/Dockerfile.finn | 2 +- tests/brevitas/test_brevitas_cnv.py | 16 ++++++++++++++-- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/docker/Dockerfile.finn b/docker/Dockerfile.finn index 3c165e7f8..d809d99b9 100644 --- a/docker/Dockerfile.finn +++ b/docker/Dockerfile.finn @@ -86,7 +86,7 @@ RUN pip install -e git+https://github.com/fbcotter/dataset_loading.git@0.0.4#egg # git-based Python repo dependencies # these are installed in editable mode for easier co-development -ARG FINN_BASE_COMMIT="22886049c96048d7150efe261a82b9c3f469c100" +ARG FINN_BASE_COMMIT="352cb9c41676fa509f57f20a32c7362c6c09039a" ARG QONNX_COMMIT="834610ba3f668971fe2800fde7f8d0c10d825d5b" ARG FINN_EXP_COMMIT="f82c0d9868bb88ea045dfadb28508d327d287221" ARG BREVITAS_COMMIT="0eaff006407955153594254728baeb988edcd042" diff --git a/tests/brevitas/test_brevitas_cnv.py b/tests/brevitas/test_brevitas_cnv.py index 8a1783ae9..78ca36136 100644 --- a/tests/brevitas/test_brevitas_cnv.py +++ b/tests/brevitas/test_brevitas_cnv.py @@ -34,12 +34,15 @@ import brevitas.onnx as bo import numpy as np import os import torch +from brevitas.export.onnx.generic.manager import BrevitasONNXManager +from qonnx.util.cleanup import cleanup as qonnx_cleanup import finn.core.onnx_exec as oxe from finn.core.modelwrapper import ModelWrapper from finn.transformation.fold_constants import FoldConstants from finn.transformation.general import GiveUniqueNodeNames, RemoveStaticGraphInputs from finn.transformation.infer_shapes import InferShapes +from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN from finn.util.test import get_test_model_trained export_onnx_path = "test_brevitas_cnv.onnx" @@ -47,11 +50,20 @@ export_onnx_path = "test_brevitas_cnv.onnx" @pytest.mark.parametrize("abits", [1, 2]) @pytest.mark.parametrize("wbits", [1, 2]) -def test_brevitas_cnv_export_exec(wbits, abits): +@pytest.mark.parametrize("QONNX_export", [False, True]) +def test_brevitas_cnv_export_exec(wbits, abits, QONNX_export): if wbits > abits: pytest.skip("No wbits > abits cases at the moment") cnv = get_test_model_trained("CNV", wbits, abits) - bo.export_finn_onnx(cnv, (1, 3, 32, 32), export_onnx_path) + ishape = (1, 3, 32, 32) + if QONNX_export: + BrevitasONNXManager.export(cnv, ishape, export_onnx_path) + qonnx_cleanup(export_onnx_path, out_file=export_onnx_path) + model = ModelWrapper(export_onnx_path) + model = model.transform(ConvertQONNXtoFINN()) + model.save(export_onnx_path) + else: + bo.export_finn_onnx(cnv, ishape, export_onnx_path) model = ModelWrapper(export_onnx_path) model = model.transform(GiveUniqueNodeNames()) model = model.transform(InferShapes()) -- GitLab