Skip to content
Snippets Groups Projects
Commit fa32c273 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Test] update test_brevitas_cnv to use new transform interface

parent 848b5910
No related branches found
No related tags found
No related merge requests found
......@@ -14,8 +14,8 @@ from models.common import (
from torch.nn import BatchNorm1d, BatchNorm2d, MaxPool2d, Module, ModuleList, Sequential
import finn.core.onnx_exec as oxe
import finn.transformation.infer_shapes as si
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.infer_shapes import InferShapes
# QuantConv2d configuration
CNV_OUT_CH_POOL = [
......@@ -147,7 +147,7 @@ def test_brevitas_cnv_export_exec():
cnv.load_state_dict(checkpoint["state_dict"])
bo.export_finn_onnx(cnv, (1, 3, 32, 32), export_onnx_path)
model = ModelWrapper(export_onnx_path)
model = model.transform_single(si.infer_shapes)
model = model.transform(InferShapes())
model.save(export_onnx_path)
fn = pk.resource_filename("finn", "data/cifar10/cifar10-test-data-class3.npz")
input_tensor = np.load(fn)["arr_0"].astype(np.float32)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment