From 285b9933410a0d2ef09315b69a33d3da5b11b893 Mon Sep 17 00:00:00 2001 From: auphelia <jakobapk@web.de> Date: Wed, 22 Feb 2023 10:54:57 +0000 Subject: [PATCH] [Tests] Update export fct in brevitas export tests --- .../brevitas/test_brevitas_avg_pool_export.py | 9 ++- tests/brevitas/test_brevitas_cnv.py | 7 +-- tests/brevitas/test_brevitas_debug.py | 6 +- tests/brevitas/test_brevitas_fc.py | 7 +-- tests/brevitas/test_brevitas_mobilenet.py | 7 +-- ...revitas_non_scaled_quanthardtanh_export.py | 7 +-- tests/brevitas/test_brevitas_qconv2d.py | 7 +-- tests/brevitas/test_brevitas_qlinear.py | 7 +-- .../brevitas/test_brevitas_relu_act_export.py | 55 +++++++------------ .../test_brevitas_scaled_qhardtanh_export.py | 7 +-- .../test_brevitas_validate_mobilenet.py | 5 +- 11 files changed, 52 insertions(+), 72 deletions(-) diff --git a/tests/brevitas/test_brevitas_avg_pool_export.py b/tests/brevitas/test_brevitas_avg_pool_export.py index 669601ecb..9c3591036 100644 --- a/tests/brevitas/test_brevitas_avg_pool_export.py +++ b/tests/brevitas/test_brevitas_avg_pool_export.py @@ -30,8 +30,7 @@ import pytest import numpy as np import os import torch -from brevitas.export import FINNManager -from brevitas.export.onnx.generic.manager import BrevitasONNXManager +from brevitas.export import export_finn_onnx, export_qonnx from brevitas.nn import QuantAvgPool2d from brevitas.quant_tensor import QuantTensor from qonnx.core.datatype import DataType @@ -97,14 +96,14 @@ def test_brevitas_avg_pool_export( # export if QONNX_export: - BrevitasONNXManager.export( + export_qonnx( quant_avgpool, export_path=export_onnx_path, input_t=input_quant_tensor, ) model = ModelWrapper(export_onnx_path) - # Statically set the additional inputs generated by the BrevitasONNXManager + # Statically set the additional inputs generated by the Brevitas ONNX export model.graph.input.remove(model.graph.input[3]) model.graph.input.remove(model.graph.input[2]) model.graph.input.remove(model.graph.input[1]) @@ -118,7 +117,7 @@ def test_brevitas_avg_pool_export( model = model.transform(ConvertQONNXtoFINN()) model.save(export_onnx_path) else: - FINNManager.export( + export_finn_onnx( quant_avgpool, export_path=export_onnx_path, input_t=input_quant_tensor ) model = ModelWrapper(export_onnx_path) diff --git a/tests/brevitas/test_brevitas_cnv.py b/tests/brevitas/test_brevitas_cnv.py index 62aab2e3c..1a9681510 100644 --- a/tests/brevitas/test_brevitas_cnv.py +++ b/tests/brevitas/test_brevitas_cnv.py @@ -30,11 +30,10 @@ import pkg_resources as pk import pytest -import brevitas.onnx as bo import numpy as np import os import torch -from brevitas.export.onnx.generic.manager import BrevitasONNXManager +from brevitas.export import export_finn_onnx, export_qonnx from qonnx.core.modelwrapper import ModelWrapper from qonnx.transformation.fold_constants import FoldConstants from qonnx.transformation.general import GiveUniqueNodeNames, RemoveStaticGraphInputs @@ -58,13 +57,13 @@ def test_brevitas_cnv_export_exec(wbits, abits, QONNX_export): cnv = get_test_model_trained("CNV", wbits, abits) ishape = (1, 3, 32, 32) if QONNX_export: - BrevitasONNXManager.export(cnv, ishape, export_onnx_path) + export_qonnx(cnv, torch.randn(ishape), export_onnx_path) qonnx_cleanup(export_onnx_path, out_file=export_onnx_path) model = ModelWrapper(export_onnx_path) model = model.transform(ConvertQONNXtoFINN()) model.save(export_onnx_path) else: - bo.export_finn_onnx(cnv, ishape, export_onnx_path) + export_finn_onnx(cnv, torch.randn(ishape), export_onnx_path) model = ModelWrapper(export_onnx_path) model = model.transform(GiveUniqueNodeNames()) model = model.transform(InferShapes()) diff --git a/tests/brevitas/test_brevitas_debug.py b/tests/brevitas/test_brevitas_debug.py index 181d610ff..547c026e2 100644 --- a/tests/brevitas/test_brevitas_debug.py +++ b/tests/brevitas/test_brevitas_debug.py @@ -34,7 +34,7 @@ import onnx import onnx.numpy_helper as nph import os import torch -from brevitas.export.onnx.generic.manager import BrevitasONNXManager +from brevitas.export import export_finn_onnx, export_qonnx from pkgutil import get_data from qonnx.core.modelwrapper import ModelWrapper from qonnx.transformation.fold_constants import FoldConstants @@ -58,7 +58,7 @@ def test_brevitas_debug(QONNX_export, QONNX_FINN_conversion): ishape = (1, 1, 28, 28) if QONNX_export: dbg_hook = bo.enable_debug(fc, proxy_level=True) - BrevitasONNXManager.export(fc, ishape, finn_onnx) + export_qonnx(fc, torch.randn(ishape), finn_onnx) # DebugMarkers have the brevitas.onnx domain, so that needs adjusting model = ModelWrapper(finn_onnx) dbg_nodes = model.get_nodes_by_op_type("DebugMarker") @@ -72,7 +72,7 @@ def test_brevitas_debug(QONNX_export, QONNX_FINN_conversion): model.save(finn_onnx) else: dbg_hook = bo.enable_debug(fc) - bo.export_finn_onnx(fc, ishape, finn_onnx) + export_finn_onnx(fc, torch.randn(ishape), finn_onnx) model = ModelWrapper(finn_onnx) # DebugMarkers have the brevitas.onnx domain, so that needs adjusting # ToDo: We should probably have transformation pass, which does this diff --git a/tests/brevitas/test_brevitas_fc.py b/tests/brevitas/test_brevitas_fc.py index 211fdb629..3aaa96f9a 100644 --- a/tests/brevitas/test_brevitas_fc.py +++ b/tests/brevitas/test_brevitas_fc.py @@ -28,12 +28,11 @@ import pytest -import brevitas.onnx as bo import numpy as np import onnx import onnx.numpy_helper as nph import torch -from brevitas.export.onnx.generic.manager import BrevitasONNXManager +from brevitas.export import export_finn_onnx, export_qonnx from pkgutil import get_data from qonnx.core.modelwrapper import ModelWrapper from qonnx.transformation.fold_constants import FoldConstants @@ -68,13 +67,13 @@ def test_brevitas_fc_onnx_export_and_exec(size, wbits, abits, QONNX_export): fc = get_test_model_trained(size, wbits, abits) ishape = (1, 1, 28, 28) if QONNX_export: - BrevitasONNXManager.export(fc, ishape, finn_onnx) + export_qonnx(fc, torch.randn(ishape), finn_onnx) qonnx_cleanup(finn_onnx, out_file=finn_onnx) model = ModelWrapper(finn_onnx) model = model.transform(ConvertQONNXtoFINN()) model.save(finn_onnx) else: - bo.export_finn_onnx(fc, ishape, finn_onnx) + export_finn_onnx(fc, torch.randn(ishape), finn_onnx) model = ModelWrapper(finn_onnx) model = model.transform(InferShapes()) model = model.transform(FoldConstants()) diff --git a/tests/brevitas/test_brevitas_mobilenet.py b/tests/brevitas/test_brevitas_mobilenet.py index b1475b6f4..c84052417 100644 --- a/tests/brevitas/test_brevitas_mobilenet.py +++ b/tests/brevitas/test_brevitas_mobilenet.py @@ -28,9 +28,9 @@ import pytest -import brevitas.onnx as bo import numpy as np import torch +from brevitas.export import export_finn_onnx from PIL import Image from qonnx.core.datatype import DataType from qonnx.core.modelwrapper import ModelWrapper @@ -54,7 +54,6 @@ from finn.util.test import crop_center, get_test_model_trained, resize_smaller_s @pytest.mark.brevitas_export -@pytest.mark.xfail def test_brevitas_mobilenet(): # get single image as input and prepare image img = Image.open(get_finn_root() + "/tests/brevitas/king_charles.jpg") @@ -76,7 +75,7 @@ def test_brevitas_mobilenet(): std = 0.226 ch = 3 preproc = NormalizePreProc(mean, std, ch) - bo.export_finn_onnx(preproc, (1, 3, 224, 224), preproc_onnx) + export_finn_onnx(preproc, torch.randn(1, 3, 224, 224), preproc_onnx) preproc_model = ModelWrapper(preproc_onnx) # set input finn datatype to UINT8 preproc_model.set_tensor_datatype( @@ -89,7 +88,7 @@ def test_brevitas_mobilenet(): finn_onnx = export_onnx_path + "/quant_mobilenet_v1_4b_exported.onnx" mobilenet = get_test_model_trained("mobilenet", 4, 4) - bo.export_finn_onnx(mobilenet, (1, 3, 224, 224), finn_onnx) + export_finn_onnx(mobilenet, torch.randn(1, 3, 224, 224), finn_onnx) # do forward pass in PyTorch/Brevitas input_tensor = preproc.forward(img_torch) diff --git a/tests/brevitas/test_brevitas_non_scaled_quanthardtanh_export.py b/tests/brevitas/test_brevitas_non_scaled_quanthardtanh_export.py index 5d70acb10..ad6a7e53d 100644 --- a/tests/brevitas/test_brevitas_non_scaled_quanthardtanh_export.py +++ b/tests/brevitas/test_brevitas_non_scaled_quanthardtanh_export.py @@ -28,7 +28,6 @@ import pytest -import brevitas.onnx as bo import numpy as np import onnx # noqa import os @@ -36,7 +35,7 @@ import torch from brevitas.core.quant import QuantType from brevitas.core.restrict_val import RestrictValueType from brevitas.core.scaling import ScalingImplType -from brevitas.export.onnx.generic.manager import BrevitasONNXManager +from brevitas.export import export_finn_onnx, export_qonnx from brevitas.nn import QuantHardTanh from qonnx.core.modelwrapper import ModelWrapper from qonnx.transformation.infer_shapes import InferShapes @@ -78,13 +77,13 @@ def test_brevitas_act_export_qhardtanh_nonscaled( ) if QONNX_export: m_path = export_onnx_path - BrevitasONNXManager.export(b_act, ishape, m_path) + export_qonnx(b_act, torch.randn(ishape), m_path) qonnx_cleanup(m_path, out_file=m_path) model = ModelWrapper(m_path) model = model.transform(ConvertQONNXtoFINN()) model.save(m_path) else: - bo.export_finn_onnx(b_act, ishape, export_onnx_path) + export_finn_onnx(b_act, torch.randn(ishape), export_onnx_path) model = ModelWrapper(export_onnx_path) model = model.transform(InferShapes()) inp_tensor = np.random.uniform(low=min_val, high=max_val, size=ishape).astype( diff --git a/tests/brevitas/test_brevitas_qconv2d.py b/tests/brevitas/test_brevitas_qconv2d.py index 214c55e5f..faeb3ff48 100644 --- a/tests/brevitas/test_brevitas_qconv2d.py +++ b/tests/brevitas/test_brevitas_qconv2d.py @@ -28,7 +28,6 @@ import pytest -import brevitas.onnx as bo import numpy as np import os import torch @@ -36,7 +35,7 @@ from brevitas.core.quant import QuantType from brevitas.core.restrict_val import RestrictValueType from brevitas.core.scaling import ScalingImplType from brevitas.core.stats import StatsOp -from brevitas.export.onnx.generic.manager import BrevitasONNXManager +from brevitas.export import export_finn_onnx, export_qonnx from brevitas.nn import QuantConv2d from qonnx.core.datatype import DataType from qonnx.core.modelwrapper import ModelWrapper @@ -96,13 +95,13 @@ def test_brevitas_QConv2d(dw, bias, in_channels, QONNX_export): b_conv.eval() if QONNX_export: m_path = export_onnx_path - BrevitasONNXManager.export(b_conv, ishape, m_path) + export_qonnx(b_conv, torch.randn(ishape), m_path) qonnx_cleanup(m_path, out_file=m_path) model = ModelWrapper(m_path) model = model.transform(ConvertQONNXtoFINN()) model.save(m_path) else: - bo.export_finn_onnx(b_conv, ishape, export_onnx_path) + export_finn_onnx(b_conv, torch.randn(ishape), export_onnx_path) model = ModelWrapper(export_onnx_path) model = model.transform(InferShapes()) inp_tensor = np.random.uniform(low=-1.0, high=1.0, size=ishape).astype(np.float32) diff --git a/tests/brevitas/test_brevitas_qlinear.py b/tests/brevitas/test_brevitas_qlinear.py index bcd75a545..1ad52fb5d 100644 --- a/tests/brevitas/test_brevitas_qlinear.py +++ b/tests/brevitas/test_brevitas_qlinear.py @@ -28,12 +28,11 @@ import pytest -import brevitas.onnx as bo import numpy as np import os import torch from brevitas.core.quant import QuantType -from brevitas.export.onnx.generic.manager import BrevitasONNXManager +from brevitas.export import export_finn_onnx, export_qonnx from brevitas.nn import QuantLinear from qonnx.core.datatype import DataType from qonnx.core.modelwrapper import ModelWrapper @@ -75,13 +74,13 @@ def test_brevitas_qlinear( b_linear.eval() if QONNX_export: m_path = export_onnx_path - BrevitasONNXManager.export(b_linear, i_shape, m_path) + export_qonnx(b_linear, torch.randn(i_shape), m_path) qonnx_cleanup(m_path, out_file=m_path) model = ModelWrapper(m_path) model = model.transform(ConvertQONNXtoFINN()) model.save(m_path) else: - bo.export_finn_onnx(b_linear, i_shape, export_onnx_path) + export_finn_onnx(b_linear, torch.randn(i_shape), export_onnx_path) model = ModelWrapper(export_onnx_path) model = model.transform(InferShapes()) inp_tensor = gen_finn_dt_tensor(i_dtype, i_shape) diff --git a/tests/brevitas/test_brevitas_relu_act_export.py b/tests/brevitas/test_brevitas_relu_act_export.py index 3dc46ec31..1900763bd 100644 --- a/tests/brevitas/test_brevitas_relu_act_export.py +++ b/tests/brevitas/test_brevitas_relu_act_export.py @@ -28,7 +28,6 @@ import pytest -import brevitas.onnx as bo import numpy as np import onnx # noqa import os @@ -36,7 +35,7 @@ import torch from brevitas.core.quant import QuantType from brevitas.core.restrict_val import RestrictValueType from brevitas.core.scaling import ScalingImplType -from brevitas.export.onnx.generic.manager import BrevitasONNXManager +from brevitas.export import export_finn_onnx, export_qonnx from brevitas.nn import QuantReLU from qonnx.core.modelwrapper import ModelWrapper from qonnx.transformation.infer_shapes import InferShapes @@ -51,18 +50,16 @@ export_onnx_path = "test_brevitas_relu_act_export.onnx" @pytest.mark.brevitas_export @pytest.mark.parametrize("abits", [2, 4, 8]) -@pytest.mark.parametrize("max_val", [1.0, 1.5, 1 - 2 ** (-7)]) @pytest.mark.parametrize( "scaling_impl_type", [ScalingImplType.CONST, ScalingImplType.PARAMETER] ) @pytest.mark.parametrize("QONNX_export", [False, True]) -def test_brevitas_act_export_relu(abits, max_val, scaling_impl_type, QONNX_export): - min_val = -1.0 +def test_brevitas_act_export_relu(abits, scaling_impl_type, QONNX_export): ishape = (1, 15) b_act = QuantReLU( bit_width=abits, - max_val=max_val, + max_val=6.0, scaling_impl_type=scaling_impl_type, restrict_scaling_type=RestrictValueType.LOG_FP, quant_type=QuantType.INT, @@ -79,18 +76,16 @@ scaling_impl.learned_value": torch.tensor( b_act.load_state_dict(checkpoint) if QONNX_export: m_path = export_onnx_path - BrevitasONNXManager.export(b_act, ishape, m_path) + export_qonnx(b_act, torch.randn(ishape), m_path) qonnx_cleanup(m_path, out_file=m_path) model = ModelWrapper(m_path) model = model.transform(ConvertQONNXtoFINN()) model.save(m_path) else: - bo.export_finn_onnx(b_act, ishape, export_onnx_path) + export_finn_onnx(b_act, torch.randn(ishape), export_onnx_path) model = ModelWrapper(export_onnx_path) model = model.transform(InferShapes()) - inp_tensor = np.random.uniform(low=min_val, high=max_val, size=ishape).astype( - np.float32 - ) + inp_tensor = np.random.uniform(low=-1.0, high=6.0, size=ishape).astype(np.float32) idict = {model.graph.input[0].name: inp_tensor} odict = oxe.execute_onnx(model, idict, True) produced = odict[model.graph.output[0].name] @@ -98,7 +93,7 @@ scaling_impl.learned_value": torch.tensor( b_act.eval() expected = b_act.forward(inp_tensor).detach().numpy() if not np.isclose(produced, expected, atol=1e-3).all(): - print(abits, max_val, scaling_impl_type) + print(abits, scaling_impl_type) print("scale: ", b_act.quant_act_scale().type(torch.FloatTensor).detach()) if abits < 5: print( @@ -115,27 +110,25 @@ scaling_impl.learned_value": torch.tensor( @pytest.mark.brevitas_export @pytest.mark.parametrize("abits", [2, 4, 8]) -@pytest.mark.parametrize("max_val", [1.0, 1.5, 1 - 2 ** (-7)]) -@pytest.mark.parametrize("scaling_per_channel", [True, False]) +@pytest.mark.parametrize("scaling_per_output_channel", [True, False]) @pytest.mark.parametrize("QONNX_export", [False, True]) def test_brevitas_act_export_relu_imagenet( - abits, max_val, scaling_per_channel, QONNX_export + abits, scaling_per_output_channel, QONNX_export ): out_channels = 32 ishape = (1, out_channels, 1, 1) - min_val = -1.0 b_act = QuantReLU( bit_width=abits, quant_type=QuantType.INT, scaling_impl_type=ScalingImplType.PARAMETER, - scaling_per_channel=scaling_per_channel, + scaling_per_output_channel=scaling_per_output_channel, restrict_scaling_type=RestrictValueType.LOG_FP, scaling_min_val=2e-16, max_val=6.0, return_quant_tensor=False, per_channel_broadcastable_shape=(1, out_channels, 1, 1), ) - if scaling_per_channel is True: + if scaling_per_output_channel is True: rand_tensor = (2) * torch.rand((1, out_channels, 1, 1)) else: rand_tensor = torch.tensor(1.2398) @@ -148,18 +141,16 @@ scaling_impl.learned_value": rand_tensor.type( b_act.load_state_dict(checkpoint) if QONNX_export: m_path = export_onnx_path - BrevitasONNXManager.export(b_act, ishape, m_path) + export_qonnx(b_act, torch.randn(ishape), m_path) qonnx_cleanup(m_path, out_file=m_path) model = ModelWrapper(m_path) model = model.transform(ConvertQONNXtoFINN()) model.save(m_path) else: - bo.export_finn_onnx(b_act, ishape, export_onnx_path) + export_finn_onnx(b_act, torch.randn(ishape), export_onnx_path) model = ModelWrapper(export_onnx_path) model = model.transform(InferShapes()) - inp_tensor = np.random.uniform(low=min_val, high=max_val, size=ishape).astype( - np.float32 - ) + inp_tensor = np.random.uniform(low=-1.0, high=6.0, size=ishape).astype(np.float32) idict = {model.graph.input[0].name: inp_tensor} odict = oxe.execute_onnx(model, idict, True) produced = odict[model.graph.output[0].name] @@ -167,7 +158,7 @@ scaling_impl.learned_value": rand_tensor.type( b_act.eval() expected = b_act.forward(inp_tensor).detach().numpy() if not np.isclose(produced, expected, atol=1e-3).all(): - print(abits, max_val) + print(abits) print("scale: ", b_act.quant_act_scale().type(torch.FloatTensor).detach()) if abits < 5: print( @@ -190,7 +181,7 @@ class PyTorchTestModel(nn.Module): bit_width=abits, quant_type=QuantType.INT, scaling_impl_type=ScalingImplType.PARAMETER, - scaling_per_channel=True, + scaling_per_output_channel=True, restrict_scaling_type=RestrictValueType.LOG_FP, scaling_min_val=2e-16, max_val=6.0, @@ -208,15 +199,13 @@ class PyTorchTestModel(nn.Module): @pytest.mark.brevitas_export @pytest.mark.parametrize("abits", [2, 4, 8]) -@pytest.mark.parametrize("max_val", [1.0, 1.5, 1 - 2 ** (-7)]) -@pytest.mark.parametrize("scaling_per_channel", [True]) +@pytest.mark.parametrize("scaling_per_output_channel", [True]) @pytest.mark.parametrize("QONNX_export", [True]) def test_brevitas_act_export_relu_forking( - abits, max_val, scaling_per_channel, QONNX_export + abits, scaling_per_output_channel, QONNX_export ): out_channels = 32 ishape = (1, out_channels, 1, 1) - min_val = -1.0 model_pyt = PyTorchTestModel(abits) rand_tensor = (2) * torch.rand((1, out_channels, 1, 1)) @@ -229,7 +218,7 @@ def test_brevitas_act_export_relu_forking( if QONNX_export: m_path = export_onnx_path - BrevitasONNXManager.export(model_pyt, ishape, m_path) + export_qonnx(model_pyt, torch.randn(ishape), m_path) qonnx_cleanup(m_path, out_file=m_path) model = ModelWrapper(m_path) model = model.transform(ConvertQONNXtoFINN()) @@ -237,9 +226,7 @@ def test_brevitas_act_export_relu_forking( model = ModelWrapper(export_onnx_path) model = model.transform(InferShapes()) - inp_tensor = np.random.uniform(low=min_val, high=max_val, size=ishape).astype( - np.float32 - ) + inp_tensor = np.random.uniform(low=-1.0, high=6.0, size=ishape).astype(np.float32) idict = {model.graph.input[0].name: inp_tensor} odict = oxe.execute_onnx(model, idict, True) produced = odict[model.graph.output[0].name] @@ -247,7 +234,7 @@ def test_brevitas_act_export_relu_forking( model_pyt.eval() expected = model_pyt.forward(inp_tensor).detach().numpy() if not np.isclose(produced, expected, atol=1e-3).all(): - print(abits, max_val) + print(abits) print("scale: ", model_pyt.quant_act_scale().type(torch.FloatTensor).detach()) if abits < 5: print( diff --git a/tests/brevitas/test_brevitas_scaled_qhardtanh_export.py b/tests/brevitas/test_brevitas_scaled_qhardtanh_export.py index 403d40610..d35cc8d2d 100644 --- a/tests/brevitas/test_brevitas_scaled_qhardtanh_export.py +++ b/tests/brevitas/test_brevitas_scaled_qhardtanh_export.py @@ -28,7 +28,6 @@ import pytest -import brevitas.onnx as bo import numpy as np import onnx # noqa import os @@ -36,7 +35,7 @@ import torch from brevitas.core.quant import QuantType from brevitas.core.restrict_val import RestrictValueType from brevitas.core.scaling import ScalingImplType -from brevitas.export.onnx.generic.manager import BrevitasONNXManager +from brevitas.export import export_finn_onnx, export_qonnx from brevitas.nn import QuantHardTanh from qonnx.core.modelwrapper import ModelWrapper from qonnx.transformation.infer_shapes import InferShapes @@ -91,13 +90,13 @@ tensor_quant.scaling_impl.learned_value": torch.tensor( b_act.load_state_dict(checkpoint) if QONNX_export: m_path = export_onnx_path - BrevitasONNXManager.export(b_act, ishape, m_path) + export_qonnx(b_act, torch.randn(ishape), m_path) qonnx_cleanup(m_path, out_file=m_path) model = ModelWrapper(m_path) model = model.transform(ConvertQONNXtoFINN()) model.save(m_path) else: - bo.export_finn_onnx(b_act, ishape, export_onnx_path) + export_finn_onnx(b_act, torch.randn(ishape), export_onnx_path) model = ModelWrapper(export_onnx_path) model = model.transform(InferShapes()) inp_tensor = np.random.uniform(low=min_val, high=max_val, size=ishape).astype( diff --git a/tests/brevitas/test_brevitas_validate_mobilenet.py b/tests/brevitas/test_brevitas_validate_mobilenet.py index 55915838e..20e8ddad5 100644 --- a/tests/brevitas/test_brevitas_validate_mobilenet.py +++ b/tests/brevitas/test_brevitas_validate_mobilenet.py @@ -35,6 +35,7 @@ import os import torch import torchvision.datasets as datasets import torchvision.transforms as transforms +from brevitas.export import export_finn_onnx from qonnx.core.modelwrapper import ModelWrapper from qonnx.transformation.fold_constants import FoldConstants from qonnx.transformation.general import ( @@ -113,7 +114,7 @@ def test_brevitas_compare_exported_mobilenet(): # export preprocessing preproc_onnx = export_onnx_path + "/quant_mobilenet_v1_4b_preproc.onnx" preproc = NormalizePreProc(mean, std, ch) - bo.export_finn_onnx(preproc, (1, 3, 224, 224), preproc_onnx) + export_finn_onnx(preproc, torch.randn(1, 3, 224, 224), preproc_onnx) preproc_model = ModelWrapper(preproc_onnx) preproc_model = preproc_model.transform(InferShapes()) preproc_model = preproc_model.transform(GiveUniqueNodeNames()) @@ -124,7 +125,7 @@ def test_brevitas_compare_exported_mobilenet(): mobilenet = get_test_model_trained("mobilenet", 4, 4) if debug_mode: dbg_hook = bo.enable_debug(mobilenet) - bo.export_finn_onnx(mobilenet, (1, 3, 224, 224), finn_onnx) + export_finn_onnx(mobilenet, torch.randn(1, 3, 224, 224), finn_onnx) model = ModelWrapper(finn_onnx) model = model.transform(InferShapes()) model = model.transform(FoldConstants()) -- GitLab