From ae210a2d53b563a56447c387873bf29afa627bcd Mon Sep 17 00:00:00 2001
From: auphelia <jakobapk@web.de>
Date: Tue, 21 Feb 2023 16:27:21 +0000
Subject: [PATCH] [Tests] Update export fct in conversion to hls layer tests

---
 tests/fpgadataflow/test_convert_to_hls_layers_cnv.py | 5 +++--
 tests/fpgadataflow/test_convert_to_hls_layers_fc.py  | 6 +++---
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/tests/fpgadataflow/test_convert_to_hls_layers_cnv.py b/tests/fpgadataflow/test_convert_to_hls_layers_cnv.py
index 9997f2843..73721b6cc 100644
--- a/tests/fpgadataflow/test_convert_to_hls_layers_cnv.py
+++ b/tests/fpgadataflow/test_convert_to_hls_layers_cnv.py
@@ -30,9 +30,10 @@ import pkg_resources as pk
 
 import pytest
 
-import brevitas.onnx as bo
 import numpy as np
 import os
+import torch
+from brevitas.export import export_finn_onnx
 from qonnx.core.modelwrapper import ModelWrapper
 from qonnx.custom_op.registry import getCustomOp
 from qonnx.transformation.bipolar_to_xnor import ConvertBipolarMatMulToXnorPopcount
@@ -61,7 +62,7 @@ export_onnx_path_cnv = "test_convert_to_hls_layers_cnv.onnx"
 @pytest.mark.parametrize("fused_activation", [True, False])
 def test_convert_to_hls_layers_cnv_w1a1(fused_activation):
     cnv = get_test_model_trained("CNV", 1, 1)
-    bo.export_finn_onnx(cnv, (1, 3, 32, 32), export_onnx_path_cnv)
+    export_finn_onnx(cnv, torch.randn(1, 3, 32, 32), export_onnx_path_cnv)
     model = ModelWrapper(export_onnx_path_cnv)
     model = model.transform(InferShapes())
     model = model.transform(FoldConstants())
diff --git a/tests/fpgadataflow/test_convert_to_hls_layers_fc.py b/tests/fpgadataflow/test_convert_to_hls_layers_fc.py
index fd4e3679d..5a45638ba 100644
--- a/tests/fpgadataflow/test_convert_to_hls_layers_fc.py
+++ b/tests/fpgadataflow/test_convert_to_hls_layers_fc.py
@@ -28,12 +28,12 @@
 
 import pytest
 
-import brevitas.onnx as bo
 import numpy as np
 import onnx
 import onnx.numpy_helper as nph
 import os
 import torch
+from brevitas.export import export_finn_onnx
 from pkgutil import get_data
 from qonnx.core.modelwrapper import ModelWrapper
 from qonnx.custom_op.registry import getCustomOp
@@ -59,7 +59,7 @@ export_onnx_path = "test_convert_to_hls_layers_fc.onnx"
 @pytest.mark.vivado
 def test_convert_to_hls_layers_tfc_w1a1():
     tfc = get_test_model_trained("TFC", 1, 1)
-    bo.export_finn_onnx(tfc, (1, 1, 28, 28), export_onnx_path)
+    export_finn_onnx(tfc, torch.randn(1, 1, 28, 28), export_onnx_path)
     model = ModelWrapper(export_onnx_path)
     model = model.transform(InferShapes())
     model = model.transform(FoldConstants())
@@ -130,7 +130,7 @@ def test_convert_to_hls_layers_tfc_w1a1():
 @pytest.mark.vivado
 def test_convert_to_hls_layers_tfc_w1a2():
     tfc = get_test_model_trained("TFC", 1, 2)
-    bo.export_finn_onnx(tfc, (1, 1, 28, 28), export_onnx_path)
+    export_finn_onnx(tfc, torch.randn(1, 1, 28, 28), export_onnx_path)
     model = ModelWrapper(export_onnx_path)
     model = model.transform(InferShapes())
     model = model.transform(FoldConstants())
-- 
GitLab