Skip to content
Snippets Groups Projects
Commit f3948618 authored by Hendrik Borras's avatar Hendrik Borras
Browse files

Added QONNX_export test to test_brevitas_cnv_export_exec test.

parent 078a4d74
No related branches found
No related tags found
No related merge requests found
......@@ -86,7 +86,7 @@ RUN pip install -e git+https://github.com/fbcotter/dataset_loading.git@0.0.4#egg
# git-based Python repo dependencies
# these are installed in editable mode for easier co-development
ARG FINN_BASE_COMMIT="22886049c96048d7150efe261a82b9c3f469c100"
ARG FINN_BASE_COMMIT="352cb9c41676fa509f57f20a32c7362c6c09039a"
ARG QONNX_COMMIT="834610ba3f668971fe2800fde7f8d0c10d825d5b"
ARG FINN_EXP_COMMIT="f82c0d9868bb88ea045dfadb28508d327d287221"
ARG BREVITAS_COMMIT="0eaff006407955153594254728baeb988edcd042"
......
......@@ -34,12 +34,15 @@ import brevitas.onnx as bo
import numpy as np
import os
import torch
from brevitas.export.onnx.generic.manager import BrevitasONNXManager
from qonnx.util.cleanup import cleanup as qonnx_cleanup
import finn.core.onnx_exec as oxe
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.fold_constants import FoldConstants
from finn.transformation.general import GiveUniqueNodeNames, RemoveStaticGraphInputs
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN
from finn.util.test import get_test_model_trained
export_onnx_path = "test_brevitas_cnv.onnx"
......@@ -47,11 +50,20 @@ export_onnx_path = "test_brevitas_cnv.onnx"
@pytest.mark.parametrize("abits", [1, 2])
@pytest.mark.parametrize("wbits", [1, 2])
def test_brevitas_cnv_export_exec(wbits, abits):
@pytest.mark.parametrize("QONNX_export", [False, True])
def test_brevitas_cnv_export_exec(wbits, abits, QONNX_export):
if wbits > abits:
pytest.skip("No wbits > abits cases at the moment")
cnv = get_test_model_trained("CNV", wbits, abits)
bo.export_finn_onnx(cnv, (1, 3, 32, 32), export_onnx_path)
ishape = (1, 3, 32, 32)
if QONNX_export:
BrevitasONNXManager.export(cnv, ishape, export_onnx_path)
qonnx_cleanup(export_onnx_path, out_file=export_onnx_path)
model = ModelWrapper(export_onnx_path)
model = model.transform(ConvertQONNXtoFINN())
model.save(export_onnx_path)
else:
bo.export_finn_onnx(cnv, ishape, export_onnx_path)
model = ModelWrapper(export_onnx_path)
model = model.transform(GiveUniqueNodeNames())
model = model.transform(InferShapes())
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment