diff --git a/tests/test_brevitas_cnv.py b/tests/test_brevitas_cnv.py
index c777d0bd57d6d152e900dd6977979020c840926f..a0904e37dd9e9dc00d8e0bc52193ae5faff00c76 100644
--- a/tests/test_brevitas_cnv.py
+++ b/tests/test_brevitas_cnv.py
@@ -14,8 +14,8 @@ from models.common import (
 from torch.nn import BatchNorm1d, BatchNorm2d, MaxPool2d, Module, ModuleList, Sequential
 
 import finn.core.onnx_exec as oxe
-import finn.transformation.infer_shapes as si
 from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.infer_shapes import InferShapes
 
 # QuantConv2d configuration
 CNV_OUT_CH_POOL = [
@@ -147,7 +147,7 @@ def test_brevitas_cnv_export_exec():
     cnv.load_state_dict(checkpoint["state_dict"])
     bo.export_finn_onnx(cnv, (1, 3, 32, 32), export_onnx_path)
     model = ModelWrapper(export_onnx_path)
-    model = model.transform_single(si.infer_shapes)
+    model = model.transform(InferShapes())
     model.save(export_onnx_path)
     fn = pk.resource_filename("finn", "data/cifar10/cifar10-test-data-class3.npz")
     input_tensor = np.load(fn)["arr_0"].astype(np.float32)