Skip to content
Snippets Groups Projects
Commit ae210a2d authored by auphelia's avatar auphelia
Browse files

[Tests] Update export fct in conversion to hls layer tests

parent 1b1ee9b3
No related branches found
No related tags found
No related merge requests found
......@@ -30,9 +30,10 @@ import pkg_resources as pk
import pytest
import brevitas.onnx as bo
import numpy as np
import os
import torch
from brevitas.export import export_finn_onnx
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.custom_op.registry import getCustomOp
from qonnx.transformation.bipolar_to_xnor import ConvertBipolarMatMulToXnorPopcount
......@@ -61,7 +62,7 @@ export_onnx_path_cnv = "test_convert_to_hls_layers_cnv.onnx"
@pytest.mark.parametrize("fused_activation", [True, False])
def test_convert_to_hls_layers_cnv_w1a1(fused_activation):
cnv = get_test_model_trained("CNV", 1, 1)
bo.export_finn_onnx(cnv, (1, 3, 32, 32), export_onnx_path_cnv)
export_finn_onnx(cnv, torch.randn(1, 3, 32, 32), export_onnx_path_cnv)
model = ModelWrapper(export_onnx_path_cnv)
model = model.transform(InferShapes())
model = model.transform(FoldConstants())
......
......@@ -28,12 +28,12 @@
import pytest
import brevitas.onnx as bo
import numpy as np
import onnx
import onnx.numpy_helper as nph
import os
import torch
from brevitas.export import export_finn_onnx
from pkgutil import get_data
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.custom_op.registry import getCustomOp
......@@ -59,7 +59,7 @@ export_onnx_path = "test_convert_to_hls_layers_fc.onnx"
@pytest.mark.vivado
def test_convert_to_hls_layers_tfc_w1a1():
tfc = get_test_model_trained("TFC", 1, 1)
bo.export_finn_onnx(tfc, (1, 1, 28, 28), export_onnx_path)
export_finn_onnx(tfc, torch.randn(1, 1, 28, 28), export_onnx_path)
model = ModelWrapper(export_onnx_path)
model = model.transform(InferShapes())
model = model.transform(FoldConstants())
......@@ -130,7 +130,7 @@ def test_convert_to_hls_layers_tfc_w1a1():
@pytest.mark.vivado
def test_convert_to_hls_layers_tfc_w1a2():
tfc = get_test_model_trained("TFC", 1, 2)
bo.export_finn_onnx(tfc, (1, 1, 28, 28), export_onnx_path)
export_finn_onnx(tfc, torch.randn(1, 1, 28, 28), export_onnx_path)
model = ModelWrapper(export_onnx_path)
model = model.transform(InferShapes())
model = model.transform(FoldConstants())
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment