Skip to content
Snippets Groups Projects
Commit 1b1ee9b3 authored by auphelia's avatar auphelia
Browse files

[Tests] Update export fct in transformation tests

parent 15a35526
No related branches found
No related tags found
No related merge requests found
......@@ -28,10 +28,11 @@
import pytest
import brevitas.onnx as bo
import onnx
import onnx.numpy_helper as nph
import os
import torch
from brevitas.export import export_finn_onnx
from pkgutil import get_data
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.transformation.fold_constants import FoldConstants
......@@ -47,7 +48,7 @@ export_onnx_path = "test_sign_to_thres.onnx"
@pytest.mark.streamline
def test_sign_to_thres():
lfc = get_test_model_trained("LFC", 1, 1)
bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
export_finn_onnx(lfc, torch.randn(1, 1, 28, 28), export_onnx_path)
model = ModelWrapper(export_onnx_path)
model = model.transform(InferShapes())
model = model.transform(FoldConstants())
......
......@@ -30,8 +30,9 @@ import pkg_resources as pk
import pytest
import brevitas.onnx as bo
import numpy as np
import torch
from brevitas.export import export_finn_onnx
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.transformation.fold_constants import FoldConstants
from qonnx.transformation.general import (
......@@ -63,7 +64,7 @@ def test_streamline_cnv(size, wbits, abits):
nname = "%s_%dW%dA" % (size, wbits, abits)
finn_onnx = export_onnx_path + "/%s.onnx" % nname
fc = get_test_model_trained(size, wbits, abits)
bo.export_finn_onnx(fc, (1, 3, 32, 32), finn_onnx)
export_finn_onnx(fc, torch.randn(1, 3, 32, 32), finn_onnx)
model = ModelWrapper(finn_onnx)
model = model.transform(InferShapes())
model = model.transform(FoldConstants())
......
......@@ -28,10 +28,11 @@
import pytest
import brevitas.onnx as bo
import numpy as np
import onnx
import onnx.numpy_helper as nph
import torch
from brevitas.export import export_finn_onnx
from pkgutil import get_data
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.transformation.fold_constants import FoldConstants
......@@ -66,7 +67,7 @@ def test_streamline_fc(size, wbits, abits):
nname = "%s_%dW%dA" % (size, wbits, abits)
finn_onnx = export_onnx_path + "/%s.onnx" % nname
fc = get_test_model_trained(size, wbits, abits)
bo.export_finn_onnx(fc, (1, 1, 28, 28), finn_onnx)
export_finn_onnx(fc, torch.randn(1, 1, 28, 28), finn_onnx)
model = ModelWrapper(finn_onnx)
model = model.transform(InferShapes())
model = model.transform(FoldConstants())
......
......@@ -30,11 +30,12 @@ import pkg_resources as pk
import pytest
import brevitas.onnx as bo
import numpy as np
import onnx
import onnx.numpy_helper as nph
import os
import torch
from brevitas.export import export_finn_onnx
from pkgutil import get_data
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.transformation.batchnorm_to_affine import BatchNormToAffine
......@@ -50,7 +51,7 @@ export_onnx_path = "test_output_bn2affine.onnx"
@pytest.mark.transform
def test_batchnorm_to_affine_cnv_w1a1():
lfc = get_test_model_trained("CNV", 1, 1)
bo.export_finn_onnx(lfc, (1, 3, 32, 32), export_onnx_path)
export_finn_onnx(lfc, torch.randn(1, 3, 32, 32), export_onnx_path)
model = ModelWrapper(export_onnx_path)
model = model.transform(InferShapes())
model = model.transform(FoldConstants())
......@@ -75,7 +76,7 @@ def test_batchnorm_to_affine_cnv_w1a1():
@pytest.mark.transform
def test_batchnorm_to_affine_lfc_w1a1():
lfc = get_test_model_trained("LFC", 1, 1)
bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
export_finn_onnx(lfc, torch.randn(1, 1, 28, 28), export_onnx_path)
model = ModelWrapper(export_onnx_path)
model = model.transform(InferShapes())
model = model.transform(FoldConstants())
......
......@@ -28,9 +28,10 @@
import pytest
import brevitas.onnx as bo
import os
import qonnx.core.data_layout as DataLayout
import torch
from brevitas.export import export_finn_onnx
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.transformation.bipolar_to_xnor import ConvertBipolarMatMulToXnorPopcount
from qonnx.transformation.fold_constants import FoldConstants
......@@ -51,7 +52,7 @@ export_onnx_path_cnv = "test_infer_data_layouts.onnx"
@pytest.mark.transform
def test_infer_data_layouts_cnv():
cnv = get_test_model_trained("CNV", 1, 1)
bo.export_finn_onnx(cnv, (1, 3, 32, 32), export_onnx_path_cnv)
export_finn_onnx(cnv, torch.randn(1, 3, 32, 32), export_onnx_path_cnv)
model = ModelWrapper(export_onnx_path_cnv)
model = model.transform(InferShapes())
model = model.transform(FoldConstants())
......
......@@ -28,8 +28,9 @@
import pytest
import brevitas.onnx as bo
import os
import torch
from brevitas.export import export_finn_onnx
from qonnx.core.datatype import DataType
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.transformation.fold_constants import FoldConstants
......@@ -45,7 +46,7 @@ export_onnx_path = "test_infer_datatypes.onnx"
@pytest.mark.transform
def test_infer_datatypes_lfc():
lfc = get_test_model_trained("LFC", 1, 1)
bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
export_finn_onnx(lfc, torch.randn(1, 1, 28, 28), export_onnx_path)
model = ModelWrapper(export_onnx_path)
model = model.transform(InferShapes())
model = model.transform(FoldConstants())
......
......@@ -31,12 +31,11 @@ import pkg_resources as pk
import pytest
import brevitas.export.onnx.generic as b_onnx
import brevitas.onnx as bo
import numpy as np
import onnx
import onnx.numpy_helper as nph
import torch
from brevitas.export import export_finn_onnx, export_qonnx
from pkgutil import get_data
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.transformation.fold_constants import FoldConstants
......@@ -117,8 +116,10 @@ def test_QONNX_to_FINN(model_name, wbits, abits):
torch_input_tensor = torch.from_numpy(input_tensor).float()
brev_output = brev_model.forward(torch_input_tensor).detach().numpy()
# Get "clean" FINN model and it's output
_ = bo.export_finn_onnx(brev_model, in_shape, finn_base_path.format("raw"))
# Get "clean" FINN model and its output
_ = export_finn_onnx(
brev_model, torch.randn(in_shape), finn_base_path.format("raw")
)
model = ModelWrapper(finn_base_path.format("raw"))
model = model.transform(GiveUniqueNodeNames())
model = model.transform(InferShapes())
......@@ -137,10 +138,7 @@ def test_QONNX_to_FINN(model_name, wbits, abits):
).all(), "The output of the Brevitas model and the FINN model should match."
# Get the equivalent QONNX model
b_onnx.function.DOMAIN_STRING = "qonnx.custom_op.general"
_ = b_onnx.manager.BrevitasONNXManager.export(
brev_model, in_shape, qonnx_base_path.format("raw")
)
_ = export_qonnx(brev_model, torch.randn(in_shape), qonnx_base_path.format("raw"))
cleanup(qonnx_base_path.format("raw"), out_file=qonnx_base_path.format("clean"))
# Compare output
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment