diff --git a/tests/brevitas/test_brevitas_avg_pool_export.py b/tests/brevitas/test_brevitas_avg_pool_export.py
index 4674c832741b38aa5dded3c9482d6028682f5fdf..4b88b0f787fb7780079ee82d6802d5cbff410748 100644
--- a/tests/brevitas/test_brevitas_avg_pool_export.py
+++ b/tests/brevitas/test_brevitas_avg_pool_export.py
@@ -25,31 +25,29 @@
 # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
 import os
 
-import onnx  # noqa
 import torch
 import numpy as np
-import brevitas.onnx as bo
-from brevitas.nn import QuantAvgPool2d
-from brevitas.quant_tensor import QuantTensor
-from brevitas.core.quant import QuantType
+import pytest
+import finn.core.onnx_exec as oxe
 from finn.core.modelwrapper import ModelWrapper
 from finn.core.datatype import DataType
 from finn.transformation.infer_shapes import InferShapes
 from finn.transformation.infer_datatypes import InferDataTypes
 from finn.util.basic import gen_finn_dt_tensor
-import finn.core.onnx_exec as oxe
 
-import pytest
+from brevitas.export import FINNManager
+from brevitas.nn import QuantAvgPool2d
+from brevitas.quant_tensor import QuantTensor
+
 
 export_onnx_path = "test_brevitas_avg_pool_export.onnx"
 
 
 @pytest.mark.parametrize("kernel_size", [2, 3])
 @pytest.mark.parametrize("stride", [1, 2])
-@pytest.mark.parametrize("signed", [False, True])
+@pytest.mark.parametrize("signed", [True, False])
 @pytest.mark.parametrize("bit_width", [2, 4])
 @pytest.mark.parametrize("input_bit_width", [4, 8, 16])
 @pytest.mark.parametrize("channels", [2, 4])
@@ -57,90 +55,47 @@ export_onnx_path = "test_brevitas_avg_pool_export.onnx"
 def test_brevitas_avg_pool_export(
     kernel_size, stride, signed, bit_width, input_bit_width, channels, idim
 ):
-    ishape = (1, channels, idim, idim)
-    ibw_tensor = torch.Tensor([input_bit_width])
 
-    b_avgpool = QuantAvgPool2d(
-        kernel_size=kernel_size,
-        stride=stride,
-        bit_width=bit_width,
-        quant_type=QuantType.INT,
-    )
-    # call forward pass manually once to cache scale factor and bitwidth
-    input_tensor = torch.from_numpy(np.zeros(ishape)).float()
-    scale = np.ones((1, channels, 1, 1))
-    zpt = torch.from_numpy(np.zeros((1))).float()
-    output_scale = torch.from_numpy(scale).float()
-    input_quant_tensor = QuantTensor(
-        value=input_tensor,
-        scale=output_scale,
-        bit_width=ibw_tensor,
-        signed=signed,
-        zero_point=zpt,
-    )
-    bo.export_finn_onnx(
-        b_avgpool, export_path=export_onnx_path, input_t=input_quant_tensor
+    quant_avgpool = QuantAvgPool2d(
+        kernel_size=kernel_size, stride=stride, bit_width=bit_width
     )
-    model = ModelWrapper(export_onnx_path)
+    quant_avgpool.eval()
 
-    # determine input FINN datatype
-    if signed is True:
-        prefix = "INT"
-    else:
-        prefix = "UINT"
+    # determine input FINN datatype and generate test input
+    prefix = "INT" if signed else "UINT"
     dt_name = prefix + str(input_bit_width)
     dtype = DataType[dt_name]
-    model = model.transform(InferShapes())
-    model = model.transform(InferDataTypes())
-
-    # execution with input tensor using integers and scale = 1
-    # calculate golden output
-    inp = gen_finn_dt_tensor(dtype, ishape)
-    input_tensor = torch.from_numpy(inp).float()
-    input_quant_tensor = QuantTensor(
-        value=input_tensor,
-        scale=output_scale,
-        bit_width=ibw_tensor,
-        signed=signed,
-        zero_point=zpt,
-    )
-    b_avgpool.eval()
-    expected = b_avgpool.forward(input_quant_tensor).tensor.detach().numpy()
-
-    # finn execution
-    idict = {model.graph.input[0].name: inp}
-    odict = oxe.execute_onnx(model, idict, True)
-    produced = odict[model.graph.output[0].name]
-    assert (expected == produced).all()
-
-    # execution with input tensor using float and scale != 1
-    scale = np.random.uniform(low=0, high=1, size=(1, channels, 1, 1)).astype(
+    input_shape = (1, channels, idim, idim)
+    input_array = gen_finn_dt_tensor(dtype, input_shape)
+    # Brevitas QuantAvgPool layers need a QuantTensor input to export
+    # correctly, i.e. one carrying the scale factor, zero point,
+    # bit width and signedness of the quantized input
+    scale_array = np.random.uniform(low=0, high=1, size=(1, channels, 1, 1)).astype(
         np.float32
     )
-    inp_tensor = inp * scale
-    input_tensor = torch.from_numpy(inp_tensor).float()
-    input_scale = torch.from_numpy(scale).float()
+    input_tensor = torch.from_numpy(input_array * scale_array).float()
+    scale_tensor = torch.from_numpy(scale_array).float()
+    zp = torch.tensor(0.0)
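+    # QuantTensor args: value, scale, zero_point, bit_width, signed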
     input_quant_tensor = QuantTensor(
-        value=input_tensor,
-        scale=input_scale,
-        bit_width=ibw_tensor,
-        signed=signed,
-        zero_point=zpt,
+        input_tensor, scale_tensor, zp, input_bit_width, signed
     )
-    # export again to set the scale values correctly
-    bo.export_finn_onnx(
-        b_avgpool, export_path=export_onnx_path, input_t=input_quant_tensor
+
+    # export to FINN-ONNX via the FINNManager backend
+    FINNManager.export(
+        quant_avgpool, export_path=export_onnx_path, input_t=input_quant_tensor
     )
     model = ModelWrapper(export_onnx_path)
     model = model.transform(InferShapes())
     model = model.transform(InferDataTypes())
-    b_avgpool.eval()
-    expected = b_avgpool.forward(input_quant_tensor).tensor.detach().numpy()
-    # finn execution
-    idict = {model.graph.input[0].name: inp_tensor}
-    odict = oxe.execute_onnx(model, idict, True)
-    produced = odict[model.graph.output[0].name]
-
-    assert np.isclose(expected, produced).all()
 
+    # reference brevitas output
+    ref_output_array = quant_avgpool(input_quant_tensor).tensor.detach().numpy()
+    # finn output
+    idict = {model.graph.input[0].name: input_array}
+    odict = oxe.execute_onnx(model, idict, True)
+    finn_output = odict[model.graph.output[0].name]
+    # compare outputs
+    assert np.isclose(ref_output_array, finn_output).all()
+    # cleanup
     os.remove(export_onnx_path)
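
Note: as a reference for the new export flow exercised above, here is a minimal
standalone sketch. The module, QuantTensor construction and FINNManager.export
call mirror the diff; the concrete shapes, values and output path are made up
for illustration, and the positional QuantTensor signature is assumed to match
the Brevitas version this test targets.

import torch
from brevitas.export import FINNManager
from brevitas.nn import QuantAvgPool2d
from brevitas.quant_tensor import QuantTensor

# 4-bit quantized 2x2 average pooling layer, as parametrized in the test
pool = QuantAvgPool2d(kernel_size=2, stride=2, bit_width=4)
pool.eval()

# quantized dummy input: 4-bit signed integers times a per-channel scale
ints = torch.randint(low=-8, high=8, size=(1, 2, 4, 4)).float()
scale = torch.full((1, 2, 1, 1), 0.5)
quant_in = QuantTensor(ints * scale, scale, torch.tensor(0.0), 4, True)

# exporting with a QuantTensor input lets the FINN backend record the scale,
# zero point, bit width and signedness in the exported FINN-ONNX graph
FINNManager.export(pool, export_path="avg_pool.onnx", input_t=quant_in)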