Skip to content
Snippets Groups Projects
Commit ae210a2d authored by auphelia's avatar auphelia
Browse files

[Tests] Update export fct in conversion to hls layer tests

parent 1b1ee9b3
No related branches found
No related tags found
No related merge requests found
...@@ -30,9 +30,10 @@ import pkg_resources as pk ...@@ -30,9 +30,10 @@ import pkg_resources as pk
import pytest import pytest
import brevitas.onnx as bo
import numpy as np import numpy as np
import os import os
import torch
from brevitas.export import export_finn_onnx
from qonnx.core.modelwrapper import ModelWrapper from qonnx.core.modelwrapper import ModelWrapper
from qonnx.custom_op.registry import getCustomOp from qonnx.custom_op.registry import getCustomOp
from qonnx.transformation.bipolar_to_xnor import ConvertBipolarMatMulToXnorPopcount from qonnx.transformation.bipolar_to_xnor import ConvertBipolarMatMulToXnorPopcount
...@@ -61,7 +62,7 @@ export_onnx_path_cnv = "test_convert_to_hls_layers_cnv.onnx" ...@@ -61,7 +62,7 @@ export_onnx_path_cnv = "test_convert_to_hls_layers_cnv.onnx"
@pytest.mark.parametrize("fused_activation", [True, False]) @pytest.mark.parametrize("fused_activation", [True, False])
def test_convert_to_hls_layers_cnv_w1a1(fused_activation): def test_convert_to_hls_layers_cnv_w1a1(fused_activation):
cnv = get_test_model_trained("CNV", 1, 1) cnv = get_test_model_trained("CNV", 1, 1)
bo.export_finn_onnx(cnv, (1, 3, 32, 32), export_onnx_path_cnv) export_finn_onnx(cnv, torch.randn(1, 3, 32, 32), export_onnx_path_cnv)
model = ModelWrapper(export_onnx_path_cnv) model = ModelWrapper(export_onnx_path_cnv)
model = model.transform(InferShapes()) model = model.transform(InferShapes())
model = model.transform(FoldConstants()) model = model.transform(FoldConstants())
......
...@@ -28,12 +28,12 @@ ...@@ -28,12 +28,12 @@
import pytest import pytest
import brevitas.onnx as bo
import numpy as np import numpy as np
import onnx import onnx
import onnx.numpy_helper as nph import onnx.numpy_helper as nph
import os import os
import torch import torch
from brevitas.export import export_finn_onnx
from pkgutil import get_data from pkgutil import get_data
from qonnx.core.modelwrapper import ModelWrapper from qonnx.core.modelwrapper import ModelWrapper
from qonnx.custom_op.registry import getCustomOp from qonnx.custom_op.registry import getCustomOp
...@@ -59,7 +59,7 @@ export_onnx_path = "test_convert_to_hls_layers_fc.onnx" ...@@ -59,7 +59,7 @@ export_onnx_path = "test_convert_to_hls_layers_fc.onnx"
@pytest.mark.vivado @pytest.mark.vivado
def test_convert_to_hls_layers_tfc_w1a1(): def test_convert_to_hls_layers_tfc_w1a1():
tfc = get_test_model_trained("TFC", 1, 1) tfc = get_test_model_trained("TFC", 1, 1)
bo.export_finn_onnx(tfc, (1, 1, 28, 28), export_onnx_path) export_finn_onnx(tfc, torch.randn(1, 1, 28, 28), export_onnx_path)
model = ModelWrapper(export_onnx_path) model = ModelWrapper(export_onnx_path)
model = model.transform(InferShapes()) model = model.transform(InferShapes())
model = model.transform(FoldConstants()) model = model.transform(FoldConstants())
...@@ -130,7 +130,7 @@ def test_convert_to_hls_layers_tfc_w1a1(): ...@@ -130,7 +130,7 @@ def test_convert_to_hls_layers_tfc_w1a1():
@pytest.mark.vivado @pytest.mark.vivado
def test_convert_to_hls_layers_tfc_w1a2(): def test_convert_to_hls_layers_tfc_w1a2():
tfc = get_test_model_trained("TFC", 1, 2) tfc = get_test_model_trained("TFC", 1, 2)
bo.export_finn_onnx(tfc, (1, 1, 28, 28), export_onnx_path) export_finn_onnx(tfc, torch.randn(1, 1, 28, 28), export_onnx_path)
model = ModelWrapper(export_onnx_path) model = ModelWrapper(export_onnx_path)
model = model.transform(InferShapes()) model = model.transform(InferShapes())
model = model.transform(FoldConstants()) model = model.transform(FoldConstants())
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment