From fa32c273920a8d84cb2612aed798b1a1d933533c Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Thu, 21 Nov 2019 22:30:39 +0000
Subject: [PATCH] [Test] update test_brevitas_cnv to use new transform
 interface

---
 tests/test_brevitas_cnv.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/test_brevitas_cnv.py b/tests/test_brevitas_cnv.py
index c777d0bd5..a0904e37d 100644
--- a/tests/test_brevitas_cnv.py
+++ b/tests/test_brevitas_cnv.py
@@ -14,8 +14,8 @@ from models.common import (
 from torch.nn import BatchNorm1d, BatchNorm2d, MaxPool2d, Module, ModuleList, Sequential
 
 import finn.core.onnx_exec as oxe
-import finn.transformation.infer_shapes as si
 from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.infer_shapes import InferShapes
 
 # QuantConv2d configuration
 CNV_OUT_CH_POOL = [
@@ -147,7 +147,7 @@ def test_brevitas_cnv_export_exec():
     cnv.load_state_dict(checkpoint["state_dict"])
     bo.export_finn_onnx(cnv, (1, 3, 32, 32), export_onnx_path)
     model = ModelWrapper(export_onnx_path)
-    model = model.transform_single(si.infer_shapes)
+    model = model.transform(InferShapes())
     model.save(export_onnx_path)
     fn = pk.resource_filename("finn", "data/cifar10/cifar10-test-data-class3.npz")
     input_tensor = np.load(fn)["arr_0"].astype(np.float32)
-- 
GitLab