From fa32c273920a8d84cb2612aed798b1a1d933533c Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <maltanar@gmail.com> Date: Thu, 21 Nov 2019 22:30:39 +0000 Subject: [PATCH] [Test] update test_brevitas_cnv to use new transform interface --- tests/test_brevitas_cnv.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_brevitas_cnv.py b/tests/test_brevitas_cnv.py index c777d0bd5..a0904e37d 100644 --- a/tests/test_brevitas_cnv.py +++ b/tests/test_brevitas_cnv.py @@ -14,8 +14,8 @@ from models.common import ( from torch.nn import BatchNorm1d, BatchNorm2d, MaxPool2d, Module, ModuleList, Sequential import finn.core.onnx_exec as oxe -import finn.transformation.infer_shapes as si from finn.core.modelwrapper import ModelWrapper +from finn.transformation.infer_shapes import InferShapes # QuantConv2d configuration CNV_OUT_CH_POOL = [ @@ -147,7 +147,7 @@ def test_brevitas_cnv_export_exec(): cnv.load_state_dict(checkpoint["state_dict"]) bo.export_finn_onnx(cnv, (1, 3, 32, 32), export_onnx_path) model = ModelWrapper(export_onnx_path) - model = model.transform_single(si.infer_shapes) + model = model.transform(InferShapes()) model.save(export_onnx_path) fn = pk.resource_filename("finn", "data/cifar10/cifar10-test-data-class3.npz") input_tensor = np.load(fn)["arr_0"].astype(np.float32) -- GitLab