diff --git a/tests/brevitas/test_brevitas_avg_pool_export.py b/tests/brevitas/test_brevitas_avg_pool_export.py
index 669601ecb6ebfd6758d3382ab097a1e93dc848c7..9c35910366dda25e9e3fccf8789bfdaac90f26f4 100644
--- a/tests/brevitas/test_brevitas_avg_pool_export.py
+++ b/tests/brevitas/test_brevitas_avg_pool_export.py
@@ -30,8 +30,7 @@ import pytest
 import numpy as np
 import os
 import torch
-from brevitas.export import FINNManager
-from brevitas.export.onnx.generic.manager import BrevitasONNXManager
+from brevitas.export import export_finn_onnx, export_qonnx
 from brevitas.nn import QuantAvgPool2d
 from brevitas.quant_tensor import QuantTensor
 from qonnx.core.datatype import DataType
@@ -97,14 +96,14 @@ def test_brevitas_avg_pool_export(
 
     # export
     if QONNX_export:
-        BrevitasONNXManager.export(
+        export_qonnx(
             quant_avgpool,
             export_path=export_onnx_path,
             input_t=input_quant_tensor,
         )
         model = ModelWrapper(export_onnx_path)
 
-        # Statically set the additional inputs generated by the BrevitasONNXManager
+        # Statically set the additional inputs generated by the Brevitas ONNX export
         model.graph.input.remove(model.graph.input[3])
         model.graph.input.remove(model.graph.input[2])
         model.graph.input.remove(model.graph.input[1])
@@ -118,7 +117,7 @@ def test_brevitas_avg_pool_export(
         model = model.transform(ConvertQONNXtoFINN())
         model.save(export_onnx_path)
     else:
-        FINNManager.export(
+        export_finn_onnx(
             quant_avgpool, export_path=export_onnx_path, input_t=input_quant_tensor
         )
     model = ModelWrapper(export_onnx_path)
diff --git a/tests/brevitas/test_brevitas_cnv.py b/tests/brevitas/test_brevitas_cnv.py
index 62aab2e3c2b85c6462c24194c917bdc2d8eec448..1a96815105b70a9bc58d51a8214c15bbc09aa69c 100644
--- a/tests/brevitas/test_brevitas_cnv.py
+++ b/tests/brevitas/test_brevitas_cnv.py
@@ -30,11 +30,10 @@ import pkg_resources as pk
 
 import pytest
 
-import brevitas.onnx as bo
 import numpy as np
 import os
 import torch
-from brevitas.export.onnx.generic.manager import BrevitasONNXManager
+from brevitas.export import export_finn_onnx, export_qonnx
 from qonnx.core.modelwrapper import ModelWrapper
 from qonnx.transformation.fold_constants import FoldConstants
 from qonnx.transformation.general import GiveUniqueNodeNames, RemoveStaticGraphInputs
@@ -58,13 +57,13 @@ def test_brevitas_cnv_export_exec(wbits, abits, QONNX_export):
     cnv = get_test_model_trained("CNV", wbits, abits)
     ishape = (1, 3, 32, 32)
     if QONNX_export:
-        BrevitasONNXManager.export(cnv, ishape, export_onnx_path)
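+        # the functional export APIs take an example input tensor (used for
+        # tracing) instead of an input shape tuple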
+        export_qonnx(cnv, torch.randn(ishape), export_onnx_path)
         qonnx_cleanup(export_onnx_path, out_file=export_onnx_path)
         model = ModelWrapper(export_onnx_path)
         model = model.transform(ConvertQONNXtoFINN())
         model.save(export_onnx_path)
     else:
-        bo.export_finn_onnx(cnv, ishape, export_onnx_path)
+        export_finn_onnx(cnv, torch.randn(ishape), export_onnx_path)
     model = ModelWrapper(export_onnx_path)
     model = model.transform(GiveUniqueNodeNames())
     model = model.transform(InferShapes())
diff --git a/tests/brevitas/test_brevitas_debug.py b/tests/brevitas/test_brevitas_debug.py
index 181d610fff7a703a8ccbcf3bbb19bed2e5d7e89d..547c026e2174e1b46a0e72967076f32db73b18a5 100644
--- a/tests/brevitas/test_brevitas_debug.py
+++ b/tests/brevitas/test_brevitas_debug.py
@@ -34,7 +34,7 @@ import onnx
 import onnx.numpy_helper as nph
 import os
 import torch
-from brevitas.export.onnx.generic.manager import BrevitasONNXManager
+from brevitas.export import export_finn_onnx, export_qonnx
 from pkgutil import get_data
 from qonnx.core.modelwrapper import ModelWrapper
 from qonnx.transformation.fold_constants import FoldConstants
@@ -58,7 +58,7 @@ def test_brevitas_debug(QONNX_export, QONNX_FINN_conversion):
     ishape = (1, 1, 28, 28)
     if QONNX_export:
         dbg_hook = bo.enable_debug(fc, proxy_level=True)
-        BrevitasONNXManager.export(fc, ishape, finn_onnx)
+        export_qonnx(fc, torch.randn(ishape), finn_onnx)
         # DebugMarkers have the brevitas.onnx domain, so that needs adjusting
         model = ModelWrapper(finn_onnx)
         dbg_nodes = model.get_nodes_by_op_type("DebugMarker")
@@ -72,7 +72,7 @@ def test_brevitas_debug(QONNX_export, QONNX_FINN_conversion):
             model.save(finn_onnx)
     else:
         dbg_hook = bo.enable_debug(fc)
-        bo.export_finn_onnx(fc, ishape, finn_onnx)
+        export_finn_onnx(fc, torch.randn(ishape), finn_onnx)
         model = ModelWrapper(finn_onnx)
         # DebugMarkers have the brevitas.onnx domain, so that needs adjusting
         # ToDo: We should probably have a transformation pass that does this
diff --git a/tests/brevitas/test_brevitas_fc.py b/tests/brevitas/test_brevitas_fc.py
index 211fdb629b7c0465a145a094bab428064227afc9..3aaa96f9a5f74112cdfe2a90c425eec55661a3b1 100644
--- a/tests/brevitas/test_brevitas_fc.py
+++ b/tests/brevitas/test_brevitas_fc.py
@@ -28,12 +28,11 @@
 
 import pytest
 
-import brevitas.onnx as bo
 import numpy as np
 import onnx
 import onnx.numpy_helper as nph
 import torch
-from brevitas.export.onnx.generic.manager import BrevitasONNXManager
+from brevitas.export import export_finn_onnx, export_qonnx
 from pkgutil import get_data
 from qonnx.core.modelwrapper import ModelWrapper
 from qonnx.transformation.fold_constants import FoldConstants
@@ -68,13 +67,13 @@ def test_brevitas_fc_onnx_export_and_exec(size, wbits, abits, QONNX_export):
     fc = get_test_model_trained(size, wbits, abits)
     ishape = (1, 1, 28, 28)
     if QONNX_export:
-        BrevitasONNXManager.export(fc, ishape, finn_onnx)
+        export_qonnx(fc, torch.randn(ishape), finn_onnx)
         qonnx_cleanup(finn_onnx, out_file=finn_onnx)
         model = ModelWrapper(finn_onnx)
         model = model.transform(ConvertQONNXtoFINN())
         model.save(finn_onnx)
     else:
-        bo.export_finn_onnx(fc, ishape, finn_onnx)
+        export_finn_onnx(fc, torch.randn(ishape), finn_onnx)
     model = ModelWrapper(finn_onnx)
     model = model.transform(InferShapes())
     model = model.transform(FoldConstants())
diff --git a/tests/brevitas/test_brevitas_mobilenet.py b/tests/brevitas/test_brevitas_mobilenet.py
index b1475b6f4ec8c4a6ed34b4249b961031780d4be8..c8405241722e28a28652b0cc1857f25a4aa1dc6e 100644
--- a/tests/brevitas/test_brevitas_mobilenet.py
+++ b/tests/brevitas/test_brevitas_mobilenet.py
@@ -28,9 +28,9 @@
 
 import pytest
 
-import brevitas.onnx as bo
 import numpy as np
 import torch
+from brevitas.export import export_finn_onnx
 from PIL import Image
 from qonnx.core.datatype import DataType
 from qonnx.core.modelwrapper import ModelWrapper
@@ -54,7 +54,6 @@ from finn.util.test import crop_center, get_test_model_trained, resize_smaller_s
 
 
 @pytest.mark.brevitas_export
-@pytest.mark.xfail
 def test_brevitas_mobilenet():
     # get single image as input and prepare image
     img = Image.open(get_finn_root() + "/tests/brevitas/king_charles.jpg")
@@ -76,7 +75,7 @@ def test_brevitas_mobilenet():
     std = 0.226
     ch = 3
     preproc = NormalizePreProc(mean, std, ch)
-    bo.export_finn_onnx(preproc, (1, 3, 224, 224), preproc_onnx)
+    export_finn_onnx(preproc, torch.randn(1, 3, 224, 224), preproc_onnx)
     preproc_model = ModelWrapper(preproc_onnx)
     # set input finn datatype to UINT8
     preproc_model.set_tensor_datatype(
@@ -89,7 +88,7 @@ def test_brevitas_mobilenet():
 
     finn_onnx = export_onnx_path + "/quant_mobilenet_v1_4b_exported.onnx"
     mobilenet = get_test_model_trained("mobilenet", 4, 4)
-    bo.export_finn_onnx(mobilenet, (1, 3, 224, 224), finn_onnx)
+    export_finn_onnx(mobilenet, torch.randn(1, 3, 224, 224), finn_onnx)
 
     # do forward pass in PyTorch/Brevitas
     input_tensor = preproc.forward(img_torch)
diff --git a/tests/brevitas/test_brevitas_non_scaled_quanthardtanh_export.py b/tests/brevitas/test_brevitas_non_scaled_quanthardtanh_export.py
index 5d70acb10264dc10a3681589075507f06a9c903b..ad6a7e53de993b76f5b35dadd4e257c8bd88f4de 100644
--- a/tests/brevitas/test_brevitas_non_scaled_quanthardtanh_export.py
+++ b/tests/brevitas/test_brevitas_non_scaled_quanthardtanh_export.py
@@ -28,7 +28,6 @@
 
 import pytest
 
-import brevitas.onnx as bo
 import numpy as np
 import onnx  # noqa
 import os
@@ -36,7 +35,7 @@ import torch
 from brevitas.core.quant import QuantType
 from brevitas.core.restrict_val import RestrictValueType
 from brevitas.core.scaling import ScalingImplType
-from brevitas.export.onnx.generic.manager import BrevitasONNXManager
+from brevitas.export import export_finn_onnx, export_qonnx
 from brevitas.nn import QuantHardTanh
 from qonnx.core.modelwrapper import ModelWrapper
 from qonnx.transformation.infer_shapes import InferShapes
@@ -78,13 +77,13 @@ def test_brevitas_act_export_qhardtanh_nonscaled(
     )
     if QONNX_export:
         m_path = export_onnx_path
-        BrevitasONNXManager.export(b_act, ishape, m_path)
+        export_qonnx(b_act, torch.randn(ishape), m_path)
         qonnx_cleanup(m_path, out_file=m_path)
         model = ModelWrapper(m_path)
         model = model.transform(ConvertQONNXtoFINN())
         model.save(m_path)
     else:
-        bo.export_finn_onnx(b_act, ishape, export_onnx_path)
+        export_finn_onnx(b_act, torch.randn(ishape), export_onnx_path)
     model = ModelWrapper(export_onnx_path)
     model = model.transform(InferShapes())
     inp_tensor = np.random.uniform(low=min_val, high=max_val, size=ishape).astype(
diff --git a/tests/brevitas/test_brevitas_qconv2d.py b/tests/brevitas/test_brevitas_qconv2d.py
index 214c55e5fd8b8c25c1ccca880f76690556af6397..faeb3ff48e2d7157008a87eab544766c83dc37d2 100644
--- a/tests/brevitas/test_brevitas_qconv2d.py
+++ b/tests/brevitas/test_brevitas_qconv2d.py
@@ -28,7 +28,6 @@
 
 import pytest
 
-import brevitas.onnx as bo
 import numpy as np
 import os
 import torch
@@ -36,7 +35,7 @@ from brevitas.core.quant import QuantType
 from brevitas.core.restrict_val import RestrictValueType
 from brevitas.core.scaling import ScalingImplType
 from brevitas.core.stats import StatsOp
-from brevitas.export.onnx.generic.manager import BrevitasONNXManager
+from brevitas.export import export_finn_onnx, export_qonnx
 from brevitas.nn import QuantConv2d
 from qonnx.core.datatype import DataType
 from qonnx.core.modelwrapper import ModelWrapper
@@ -96,13 +95,13 @@ def test_brevitas_QConv2d(dw, bias, in_channels, QONNX_export):
     b_conv.eval()
     if QONNX_export:
         m_path = export_onnx_path
-        BrevitasONNXManager.export(b_conv, ishape, m_path)
+        export_qonnx(b_conv, torch.randn(ishape), m_path)
         qonnx_cleanup(m_path, out_file=m_path)
         model = ModelWrapper(m_path)
         model = model.transform(ConvertQONNXtoFINN())
         model.save(m_path)
     else:
-        bo.export_finn_onnx(b_conv, ishape, export_onnx_path)
+        export_finn_onnx(b_conv, torch.randn(ishape), export_onnx_path)
     model = ModelWrapper(export_onnx_path)
     model = model.transform(InferShapes())
     inp_tensor = np.random.uniform(low=-1.0, high=1.0, size=ishape).astype(np.float32)
diff --git a/tests/brevitas/test_brevitas_qlinear.py b/tests/brevitas/test_brevitas_qlinear.py
index bcd75a545544122c1faacf4c321b19a489defe85..1ad52fb5df9fff6584fb6b649481377f32fa666d 100644
--- a/tests/brevitas/test_brevitas_qlinear.py
+++ b/tests/brevitas/test_brevitas_qlinear.py
@@ -28,12 +28,11 @@
 
 import pytest
 
-import brevitas.onnx as bo
 import numpy as np
 import os
 import torch
 from brevitas.core.quant import QuantType
-from brevitas.export.onnx.generic.manager import BrevitasONNXManager
+from brevitas.export import export_finn_onnx, export_qonnx
 from brevitas.nn import QuantLinear
 from qonnx.core.datatype import DataType
 from qonnx.core.modelwrapper import ModelWrapper
@@ -75,13 +74,13 @@ def test_brevitas_qlinear(
     b_linear.eval()
     if QONNX_export:
         m_path = export_onnx_path
-        BrevitasONNXManager.export(b_linear, i_shape, m_path)
+        export_qonnx(b_linear, torch.randn(i_shape), m_path)
         qonnx_cleanup(m_path, out_file=m_path)
         model = ModelWrapper(m_path)
         model = model.transform(ConvertQONNXtoFINN())
         model.save(m_path)
     else:
-        bo.export_finn_onnx(b_linear, i_shape, export_onnx_path)
+        export_finn_onnx(b_linear, torch.randn(i_shape), export_onnx_path)
     model = ModelWrapper(export_onnx_path)
     model = model.transform(InferShapes())
     inp_tensor = gen_finn_dt_tensor(i_dtype, i_shape)
diff --git a/tests/brevitas/test_brevitas_relu_act_export.py b/tests/brevitas/test_brevitas_relu_act_export.py
index 3dc46ec31e49d7115b19b3373d54be6ddc29bb80..1900763bdd4d8c70369abc4f2ba0c33b02607e26 100644
--- a/tests/brevitas/test_brevitas_relu_act_export.py
+++ b/tests/brevitas/test_brevitas_relu_act_export.py
@@ -28,7 +28,6 @@
 
 import pytest
 
-import brevitas.onnx as bo
 import numpy as np
 import onnx  # noqa
 import os
@@ -36,7 +35,7 @@ import torch
 from brevitas.core.quant import QuantType
 from brevitas.core.restrict_val import RestrictValueType
 from brevitas.core.scaling import ScalingImplType
-from brevitas.export.onnx.generic.manager import BrevitasONNXManager
+from brevitas.export import export_finn_onnx, export_qonnx
 from brevitas.nn import QuantReLU
 from qonnx.core.modelwrapper import ModelWrapper
 from qonnx.transformation.infer_shapes import InferShapes
@@ -51,18 +50,16 @@ export_onnx_path = "test_brevitas_relu_act_export.onnx"
 
 @pytest.mark.brevitas_export
 @pytest.mark.parametrize("abits", [2, 4, 8])
-@pytest.mark.parametrize("max_val", [1.0, 1.5, 1 - 2 ** (-7)])
 @pytest.mark.parametrize(
     "scaling_impl_type", [ScalingImplType.CONST, ScalingImplType.PARAMETER]
 )
 @pytest.mark.parametrize("QONNX_export", [False, True])
-def test_brevitas_act_export_relu(abits, max_val, scaling_impl_type, QONNX_export):
-    min_val = -1.0
+def test_brevitas_act_export_relu(abits, scaling_impl_type, QONNX_export):
     ishape = (1, 15)
 
     b_act = QuantReLU(
         bit_width=abits,
-        max_val=max_val,
+        max_val=6.0,
         scaling_impl_type=scaling_impl_type,
         restrict_scaling_type=RestrictValueType.LOG_FP,
         quant_type=QuantType.INT,
@@ -79,18 +76,16 @@ scaling_impl.learned_value": torch.tensor(
         b_act.load_state_dict(checkpoint)
     if QONNX_export:
         m_path = export_onnx_path
-        BrevitasONNXManager.export(b_act, ishape, m_path)
+        export_qonnx(b_act, torch.randn(ishape), m_path)
         qonnx_cleanup(m_path, out_file=m_path)
         model = ModelWrapper(m_path)
         model = model.transform(ConvertQONNXtoFINN())
         model.save(m_path)
     else:
-        bo.export_finn_onnx(b_act, ishape, export_onnx_path)
+        export_finn_onnx(b_act, torch.randn(ishape), export_onnx_path)
     model = ModelWrapper(export_onnx_path)
     model = model.transform(InferShapes())
-    inp_tensor = np.random.uniform(low=min_val, high=max_val, size=ishape).astype(
-        np.float32
-    )
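+    # sample inputs from below zero up to the fixed max_val of 6.0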
+    inp_tensor = np.random.uniform(low=-1.0, high=6.0, size=ishape).astype(np.float32)
     idict = {model.graph.input[0].name: inp_tensor}
     odict = oxe.execute_onnx(model, idict, True)
     produced = odict[model.graph.output[0].name]
@@ -98,7 +93,7 @@ scaling_impl.learned_value": torch.tensor(
     b_act.eval()
     expected = b_act.forward(inp_tensor).detach().numpy()
     if not np.isclose(produced, expected, atol=1e-3).all():
-        print(abits, max_val, scaling_impl_type)
+        print(abits, scaling_impl_type)
         print("scale: ", b_act.quant_act_scale().type(torch.FloatTensor).detach())
         if abits < 5:
             print(
@@ -115,27 +110,25 @@ scaling_impl.learned_value": torch.tensor(
 
 @pytest.mark.brevitas_export
 @pytest.mark.parametrize("abits", [2, 4, 8])
-@pytest.mark.parametrize("max_val", [1.0, 1.5, 1 - 2 ** (-7)])
-@pytest.mark.parametrize("scaling_per_channel", [True, False])
+@pytest.mark.parametrize("scaling_per_output_channel", [True, False])
 @pytest.mark.parametrize("QONNX_export", [False, True])
 def test_brevitas_act_export_relu_imagenet(
-    abits, max_val, scaling_per_channel, QONNX_export
+    abits, scaling_per_output_channel, QONNX_export
 ):
     out_channels = 32
     ishape = (1, out_channels, 1, 1)
-    min_val = -1.0
     b_act = QuantReLU(
         bit_width=abits,
         quant_type=QuantType.INT,
         scaling_impl_type=ScalingImplType.PARAMETER,
-        scaling_per_channel=scaling_per_channel,
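+        # Brevitas renamed this kwarg from scaling_per_channel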
+        scaling_per_output_channel=scaling_per_output_channel,
         restrict_scaling_type=RestrictValueType.LOG_FP,
         scaling_min_val=2e-16,
         max_val=6.0,
         return_quant_tensor=False,
         per_channel_broadcastable_shape=(1, out_channels, 1, 1),
     )
-    if scaling_per_channel is True:
+    if scaling_per_output_channel is True:
         rand_tensor = (2) * torch.rand((1, out_channels, 1, 1))
     else:
         rand_tensor = torch.tensor(1.2398)
@@ -148,18 +141,16 @@ scaling_impl.learned_value": rand_tensor.type(
     b_act.load_state_dict(checkpoint)
     if QONNX_export:
         m_path = export_onnx_path
-        BrevitasONNXManager.export(b_act, ishape, m_path)
+        export_qonnx(b_act, torch.randn(ishape), m_path)
         qonnx_cleanup(m_path, out_file=m_path)
         model = ModelWrapper(m_path)
         model = model.transform(ConvertQONNXtoFINN())
         model.save(m_path)
     else:
-        bo.export_finn_onnx(b_act, ishape, export_onnx_path)
+        export_finn_onnx(b_act, torch.randn(ishape), export_onnx_path)
     model = ModelWrapper(export_onnx_path)
     model = model.transform(InferShapes())
-    inp_tensor = np.random.uniform(low=min_val, high=max_val, size=ishape).astype(
-        np.float32
-    )
+    inp_tensor = np.random.uniform(low=-1.0, high=6.0, size=ishape).astype(np.float32)
     idict = {model.graph.input[0].name: inp_tensor}
     odict = oxe.execute_onnx(model, idict, True)
     produced = odict[model.graph.output[0].name]
@@ -167,7 +158,7 @@ scaling_impl.learned_value": rand_tensor.type(
     b_act.eval()
     expected = b_act.forward(inp_tensor).detach().numpy()
     if not np.isclose(produced, expected, atol=1e-3).all():
-        print(abits, max_val)
+        print(abits)
         print("scale: ", b_act.quant_act_scale().type(torch.FloatTensor).detach())
         if abits < 5:
             print(
@@ -190,7 +181,7 @@ class PyTorchTestModel(nn.Module):
             bit_width=abits,
             quant_type=QuantType.INT,
             scaling_impl_type=ScalingImplType.PARAMETER,
-            scaling_per_channel=True,
+            scaling_per_output_channel=True,
             restrict_scaling_type=RestrictValueType.LOG_FP,
             scaling_min_val=2e-16,
             max_val=6.0,
@@ -208,15 +199,13 @@ class PyTorchTestModel(nn.Module):
 
 @pytest.mark.brevitas_export
 @pytest.mark.parametrize("abits", [2, 4, 8])
-@pytest.mark.parametrize("max_val", [1.0, 1.5, 1 - 2 ** (-7)])
-@pytest.mark.parametrize("scaling_per_channel", [True])
+@pytest.mark.parametrize("scaling_per_output_channel", [True])
 @pytest.mark.parametrize("QONNX_export", [True])
 def test_brevitas_act_export_relu_forking(
-    abits, max_val, scaling_per_channel, QONNX_export
+    abits, scaling_per_output_channel, QONNX_export
 ):
     out_channels = 32
     ishape = (1, out_channels, 1, 1)
-    min_val = -1.0
     model_pyt = PyTorchTestModel(abits)
 
     rand_tensor = (2) * torch.rand((1, out_channels, 1, 1))
@@ -229,7 +218,7 @@ def test_brevitas_act_export_relu_forking(
 
     if QONNX_export:
         m_path = export_onnx_path
-        BrevitasONNXManager.export(model_pyt, ishape, m_path)
+        export_qonnx(model_pyt, torch.randn(ishape), m_path)
         qonnx_cleanup(m_path, out_file=m_path)
         model = ModelWrapper(m_path)
         model = model.transform(ConvertQONNXtoFINN())
@@ -237,9 +226,7 @@ def test_brevitas_act_export_relu_forking(
 
     model = ModelWrapper(export_onnx_path)
     model = model.transform(InferShapes())
-    inp_tensor = np.random.uniform(low=min_val, high=max_val, size=ishape).astype(
-        np.float32
-    )
+    inp_tensor = np.random.uniform(low=-1.0, high=6.0, size=ishape).astype(np.float32)
     idict = {model.graph.input[0].name: inp_tensor}
     odict = oxe.execute_onnx(model, idict, True)
     produced = odict[model.graph.output[0].name]
@@ -247,7 +234,7 @@ def test_brevitas_act_export_relu_forking(
     model_pyt.eval()
     expected = model_pyt.forward(inp_tensor).detach().numpy()
     if not np.isclose(produced, expected, atol=1e-3).all():
-        print(abits, max_val)
+        print(abits)
         print("scale: ", model_pyt.quant_act_scale().type(torch.FloatTensor).detach())
         if abits < 5:
             print(
diff --git a/tests/brevitas/test_brevitas_scaled_qhardtanh_export.py b/tests/brevitas/test_brevitas_scaled_qhardtanh_export.py
index 403d406105e8e60e6ef87f833c495dc2974de68c..d35cc8d2dda58f2be188622cdac59c19cee25e13 100644
--- a/tests/brevitas/test_brevitas_scaled_qhardtanh_export.py
+++ b/tests/brevitas/test_brevitas_scaled_qhardtanh_export.py
@@ -28,7 +28,6 @@
 
 import pytest
 
-import brevitas.onnx as bo
 import numpy as np
 import onnx  # noqa
 import os
@@ -36,7 +35,7 @@ import torch
 from brevitas.core.quant import QuantType
 from brevitas.core.restrict_val import RestrictValueType
 from brevitas.core.scaling import ScalingImplType
-from brevitas.export.onnx.generic.manager import BrevitasONNXManager
+from brevitas.export import export_finn_onnx, export_qonnx
 from brevitas.nn import QuantHardTanh
 from qonnx.core.modelwrapper import ModelWrapper
 from qonnx.transformation.infer_shapes import InferShapes
@@ -91,13 +90,13 @@ tensor_quant.scaling_impl.learned_value": torch.tensor(
         b_act.load_state_dict(checkpoint)
     if QONNX_export:
         m_path = export_onnx_path
-        BrevitasONNXManager.export(b_act, ishape, m_path)
+        export_qonnx(b_act, torch.randn(ishape), m_path)
         qonnx_cleanup(m_path, out_file=m_path)
         model = ModelWrapper(m_path)
         model = model.transform(ConvertQONNXtoFINN())
         model.save(m_path)
     else:
-        bo.export_finn_onnx(b_act, ishape, export_onnx_path)
+        export_finn_onnx(b_act, torch.randn(ishape), export_onnx_path)
     model = ModelWrapper(export_onnx_path)
     model = model.transform(InferShapes())
     inp_tensor = np.random.uniform(low=min_val, high=max_val, size=ishape).astype(
diff --git a/tests/brevitas/test_brevitas_validate_mobilenet.py b/tests/brevitas/test_brevitas_validate_mobilenet.py
index 55915838e8a10d19d3aa6446d0bb667785bbd905..20e8ddad501e8b07502decef6eacd4afe061917a 100644
--- a/tests/brevitas/test_brevitas_validate_mobilenet.py
+++ b/tests/brevitas/test_brevitas_validate_mobilenet.py
@@ -35,6 +35,7 @@ import os
 import torch
 import torchvision.datasets as datasets
 import torchvision.transforms as transforms
+from brevitas.export import export_finn_onnx
 from qonnx.core.modelwrapper import ModelWrapper
 from qonnx.transformation.fold_constants import FoldConstants
 from qonnx.transformation.general import (
@@ -113,7 +114,7 @@ def test_brevitas_compare_exported_mobilenet():
     # export preprocessing
     preproc_onnx = export_onnx_path + "/quant_mobilenet_v1_4b_preproc.onnx"
     preproc = NormalizePreProc(mean, std, ch)
-    bo.export_finn_onnx(preproc, (1, 3, 224, 224), preproc_onnx)
+    export_finn_onnx(preproc, torch.randn(1, 3, 224, 224), preproc_onnx)
     preproc_model = ModelWrapper(preproc_onnx)
     preproc_model = preproc_model.transform(InferShapes())
     preproc_model = preproc_model.transform(GiveUniqueNodeNames())
@@ -124,7 +125,7 @@ def test_brevitas_compare_exported_mobilenet():
     mobilenet = get_test_model_trained("mobilenet", 4, 4)
     if debug_mode:
         dbg_hook = bo.enable_debug(mobilenet)
-    bo.export_finn_onnx(mobilenet, (1, 3, 224, 224), finn_onnx)
+    export_finn_onnx(mobilenet, torch.randn(1, 3, 224, 224), finn_onnx)
     model = ModelWrapper(finn_onnx)
     model = model.transform(InferShapes())
     model = model.transform(FoldConstants())