diff --git a/tests/test_brevitas_cnv.py b/tests/test_brevitas_cnv.py
index a0904e37dd9e9dc00d8e0bc52193ae5faff00c76..631e8073f437052958d6c8aa22126bda88468c49 100644
--- a/tests/test_brevitas_cnv.py
+++ b/tests/test_brevitas_cnv.py
@@ -4,40 +4,13 @@ import pkg_resources as pk
 import brevitas.onnx as bo
 import numpy as np
 import torch
-from models.common import (
-    get_act_quant,
-    get_quant_conv2d,
-    get_quant_linear,
-    get_quant_type,
-    get_stats_op
-)
-from torch.nn import BatchNorm1d, BatchNorm2d, MaxPool2d, Module, ModuleList, Sequential
+from models.CNV import CNV
 
 import finn.core.onnx_exec as oxe
 from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.fold_constants import FoldConstants
 from finn.transformation.infer_shapes import InferShapes
 
-# QuantConv2d configuration
-CNV_OUT_CH_POOL = [
-    (0, 64, False),
-    (1, 64, True),
-    (2, 128, False),
-    (3, 128, True),
-    (4, 256, False),
-    (5, 256, False),
-]
-
-# Intermediate QuantLinear configuration
-INTERMEDIATE_FC_PER_OUT_CH_SCALING = True
-INTERMEDIATE_FC_FEATURES = [(256, 512), (512, 512)]
-
-# Last QuantLinear configuration
-LAST_FC_IN_FEATURES = 512
-LAST_FC_PER_OUT_CH_SCALING = False
-
-# MaxPool2d configuration
-POOL_SIZE = 2
-
 export_onnx_path = "test_output_cnv.onnx"
 # TODO get from config instead, hardcoded to Docker path for now
 trained_cnv_checkpoint = (
@@ -45,92 +18,7 @@ trained_cnv_checkpoint = (
 )
 
 
-class CNV(Module):
-    def __init__(
-        self,
-        num_classes=10,
-        weight_bit_width=None,
-        act_bit_width=None,
-        in_bit_width=None,
-        in_ch=3,
-    ):
-        super(CNV, self).__init__()
-
-        weight_quant_type = get_quant_type(weight_bit_width)
-        act_quant_type = get_quant_type(act_bit_width)
-        in_quant_type = get_quant_type(in_bit_width)
-        stats_op = get_stats_op(weight_quant_type)
-
-        self.conv_features = ModuleList()
-        self.linear_features = ModuleList()
-        self.conv_features.append(get_act_quant(in_bit_width, in_quant_type))
-
-        for i, out_ch, is_pool_enabled in CNV_OUT_CH_POOL:
-            self.conv_features.append(
-                get_quant_conv2d(
-                    in_ch=in_ch,
-                    out_ch=out_ch,
-                    bit_width=weight_bit_width,
-                    quant_type=weight_quant_type,
-                    stats_op=stats_op,
-                )
-            )
-            in_ch = out_ch
-            if is_pool_enabled:
-                self.conv_features.append(MaxPool2d(kernel_size=2))
-            if i == 5:
-                self.conv_features.append(Sequential())
-            self.conv_features.append(BatchNorm2d(in_ch))
-            self.conv_features.append(get_act_quant(act_bit_width, act_quant_type))
-
-        for in_features, out_features in INTERMEDIATE_FC_FEATURES:
-            self.linear_features.append(
-                get_quant_linear(
-                    in_features=in_features,
-                    out_features=out_features,
-                    per_out_ch_scaling=INTERMEDIATE_FC_PER_OUT_CH_SCALING,
-                    bit_width=weight_bit_width,
-                    quant_type=weight_quant_type,
-                    stats_op=stats_op,
-                )
-            )
-            self.linear_features.append(BatchNorm1d(out_features))
-            self.linear_features.append(get_act_quant(act_bit_width, act_quant_type))
-        self.fc = get_quant_linear(
-            in_features=LAST_FC_IN_FEATURES,
-            out_features=num_classes,
-            per_out_ch_scaling=LAST_FC_PER_OUT_CH_SCALING,
-            bit_width=weight_bit_width,
-            quant_type=weight_quant_type,
-            stats_op=stats_op,
-        )
-
-    def forward(self, x):
-        x = 2.0 * x - torch.tensor([1.0])
-        for mod in self.conv_features:
-            x = mod(x)
-        x = x.view(1, 256)
-        for mod in self.linear_features:
-            x = mod(x)
-        out = self.fc(x)
-        return out
-
-
-def test_brevitas_trained_cnv_pytorch():
-    # load pretrained weights into CNV-w1a1
-    cnv = CNV(weight_bit_width=1, act_bit_width=1, in_bit_width=1, in_ch=3).eval()
-    checkpoint = torch.load(trained_cnv_checkpoint, map_location="cpu")
-    cnv.load_state_dict(checkpoint["state_dict"])
-    fn = pk.resource_filename("finn", "data/cifar10/cifar10-test-data-class3.npz")
-    input_tensor = np.load(fn)["arr_0"]
-    input_tensor = torch.from_numpy(input_tensor).float()
-    assert input_tensor.shape == (1, 3, 32, 32)
-    # do forward pass in PyTorch/Brevitas
-    cnv.forward(input_tensor).detach().numpy()
-    # TODO verify produced answer
-
-
-def test_brevitas_cnv_export():
+def test_brevitas_cnv_w1a1_export():
     cnv = CNV(weight_bit_width=1, act_bit_width=1, in_bit_width=1, in_ch=3).eval()
     bo.export_finn_onnx(cnv, (1, 3, 32, 32), export_onnx_path)
     model = ModelWrapper(export_onnx_path)
@@ -139,15 +27,17 @@ def test_brevitas_cnv_export():
     conv0_wname = model.graph.node[3].input[1]
     assert list(model.get_initializer(conv0_wname).shape) == [64, 3, 3, 3]
     assert model.graph.node[4].op_type == "Mul"
+    os.remove(export_onnx_path)
 
 
-def test_brevitas_cnv_export_exec():
+def test_brevitas_cnv_w1a1_export_exec():
     cnv = CNV(weight_bit_width=1, act_bit_width=1, in_bit_width=1, in_ch=3).eval()
     checkpoint = torch.load(trained_cnv_checkpoint, map_location="cpu")
     cnv.load_state_dict(checkpoint["state_dict"])
     bo.export_finn_onnx(cnv, (1, 3, 32, 32), export_onnx_path)
     model = ModelWrapper(export_onnx_path)
     model = model.transform(InferShapes())
+    model = model.transform(FoldConstants())
     model.save(export_onnx_path)
     fn = pk.resource_filename("finn", "data/cifar10/cifar10-test-data-class3.npz")
     input_tensor = np.load(fn)["arr_0"].astype(np.float32)
@@ -161,3 +51,17 @@ def test_brevitas_cnv_export_exec():
     expected = cnv.forward(input_tensor).detach().numpy()
     assert np.isclose(produced, expected, atol=1e-3).all()
     os.remove(export_onnx_path)
+
+
+def test_brevitas_trained_cnv_w1a1_pytorch():
+    # load pretrained weights into CNV-w1a1
+    cnv = CNV(weight_bit_width=1, act_bit_width=1, in_bit_width=1, in_ch=3).eval()
+    checkpoint = torch.load(trained_cnv_checkpoint, map_location="cpu")
+    cnv.load_state_dict(checkpoint["state_dict"])
+    fn = pk.resource_filename("finn", "data/cifar10/cifar10-test-data-class3.npz")
+    input_tensor = np.load(fn)["arr_0"]
+    input_tensor = torch.from_numpy(input_tensor).float()
+    assert input_tensor.shape == (1, 3, 32, 32)
+    # do forward pass in PyTorch/Brevitas
+    cnv.forward(input_tensor).detach().numpy()
+    # TODO verify produced answer