diff --git a/tests/test_brevitas_cnv.py b/tests/test_brevitas_cnv.py index c777d0bd57d6d152e900dd6977979020c840926f..a0904e37dd9e9dc00d8e0bc52193ae5faff00c76 100644 --- a/tests/test_brevitas_cnv.py +++ b/tests/test_brevitas_cnv.py @@ -14,8 +14,8 @@ from models.common import ( from torch.nn import BatchNorm1d, BatchNorm2d, MaxPool2d, Module, ModuleList, Sequential import finn.core.onnx_exec as oxe -import finn.transformation.infer_shapes as si from finn.core.modelwrapper import ModelWrapper +from finn.transformation.infer_shapes import InferShapes # QuantConv2d configuration CNV_OUT_CH_POOL = [ @@ -147,7 +147,7 @@ def test_brevitas_cnv_export_exec(): cnv.load_state_dict(checkpoint["state_dict"]) bo.export_finn_onnx(cnv, (1, 3, 32, 32), export_onnx_path) model = ModelWrapper(export_onnx_path) - model = model.transform_single(si.infer_shapes) + model = model.transform(InferShapes()) model.save(export_onnx_path) fn = pk.resource_filename("finn", "data/cifar10/cifar10-test-data-class3.npz") input_tensor = np.load(fn)["arr_0"].astype(np.float32)