diff --git a/tests/brevitas/test_brevitas_qlinear.py b/tests/brevitas/test_brevitas_qlinear.py
index 873866b37727730b7cedd035f5edd93f7c1afe32..601da3c5fae8219a3a953af0b1f07fa9fad6d700 100644
--- a/tests/brevitas/test_brevitas_qlinear.py
+++ b/tests/brevitas/test_brevitas_qlinear.py
@@ -33,12 +33,15 @@ import numpy as np
 import os
 import torch
 from brevitas.core.quant import QuantType
+from brevitas.export.onnx.generic.manager import BrevitasONNXManager
 from brevitas.nn import QuantLinear
+from qonnx.util.cleanup import cleanup as qonnx_cleanup
 
 import finn.core.onnx_exec as oxe
 from finn.core.datatype import DataType
 from finn.core.modelwrapper import ModelWrapper
 from finn.transformation.infer_shapes import InferShapes
+from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN
 from finn.util.basic import gen_finn_dt_tensor
 
 export_onnx_path = "test_brevitas_qlinear.onnx"
@@ -49,7 +52,10 @@ export_onnx_path = "test_brevitas_qlinear.onnx"
 @pytest.mark.parametrize("in_features", [3])
 @pytest.mark.parametrize("w_bits", [4])
 @pytest.mark.parametrize("i_dtype", [DataType.UINT4])
-def test_brevitas_qlinear(bias, out_features, in_features, w_bits, i_dtype):
+@pytest.mark.parametrize("QONNX_export", [False, True])
+def test_brevitas_qlinear(
+    bias, out_features, in_features, w_bits, i_dtype, QONNX_export
+):
     i_shape = (1, in_features)
     w_shape = (out_features, in_features)
     b_linear = QuantLinear(
@@ -66,7 +72,19 @@ def test_brevitas_qlinear(bias, out_features, in_features, w_bits, i_dtype):
     )
     b_linear.weight.data = torch.from_numpy(weight_tensor_fp)
     b_linear.eval()
-    bo.export_finn_onnx(b_linear, i_shape, export_onnx_path)
+    if QONNX_export:
+        m_path = export_onnx_path
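+        # Export the Brevitas model in the QONNX dialect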
+        BrevitasONNXManager.export(b_linear, i_shape, m_path)
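+        # Clean up the exported QONNX model in place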
+        qonnx_cleanup(m_path, out_file=m_path)
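+        # Convert from the QONNX dialect to FINN-ONNX and save back to the same path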
+        model = ModelWrapper(m_path)
+        model = model.transform(ConvertQONNXtoFINN())
+        model.save(m_path)
+    else:
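+        # Export directly to the FINN-ONNX dialect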
+        bo.export_finn_onnx(b_linear, i_shape, export_onnx_path)
     model = ModelWrapper(export_onnx_path)
     model = model.transform(InferShapes())
     inp_tensor = gen_finn_dt_tensor(i_dtype, i_shape)