Skip to content
Snippets Groups Projects
Commit d1184c4b authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Test] more tests for CNV export, shape inference failing

parent fe3e4e25
No related branches found
No related tags found
No related merge requests found
import pkg_resources as pk
import brevitas.onnx as bo
import numpy as np
import torch
from models.common import (
......@@ -11,6 +12,10 @@ from models.common import (
)
from torch.nn import BatchNorm1d, BatchNorm2d, MaxPool2d, Module, ModuleList, Sequential
import finn.core.onnx_exec as oxe
import finn.transformation.infer_shapes as si
from finn.core.modelwrapper import ModelWrapper
# QuantConv2d configuration
CNV_OUT_CH_POOL = [
(0, 64, False),
......@@ -32,6 +37,7 @@ LAST_FC_PER_OUT_CH_SCALING = False
# MaxPool2d configuration
POOL_SIZE = 2
export_onnx_path = "test_output_cnv.onnx"
# TODO get from config instead, hardcoded to Docker path for now
trained_cnv_checkpoint = (
"/workspace/brevitas_cnv_lfc/pretrained_models/CNV_1W1A/checkpoints/best.tar"
......@@ -121,3 +127,35 @@ def test_brevitas_trained_cnv_pytorch():
# do forward pass in PyTorch/Brevitas
cnv.forward(input_tensor).detach().numpy()
# TODO verify produced answer
def test_brevitas_cnv_export():
cnv = CNV(weight_bit_width=1, act_bit_width=1, in_bit_width=1, in_ch=3).eval()
bo.export_finn_onnx(cnv, (1, 3, 32, 32), export_onnx_path)
model = ModelWrapper(export_onnx_path)
assert model.graph.node[2].op_type == "Sign"
assert model.graph.node[3].op_type == "Conv"
conv0_wname = model.graph.node[3].input[1]
assert list(model.get_initializer(conv0_wname).shape) == [64, 3, 3, 3]
assert model.graph.node[4].op_type == "Mul"
def test_brevitas_cnv_export_exec():
cnv = CNV(weight_bit_width=1, act_bit_width=1, in_bit_width=1, in_ch=3).eval()
checkpoint = torch.load(trained_cnv_checkpoint, map_location="cpu")
cnv.load_state_dict(checkpoint["state_dict"])
bo.export_finn_onnx(cnv, (1, 3, 32, 32), export_onnx_path)
model = ModelWrapper(export_onnx_path)
model = model.transform_single(si.infer_shapes)
model.save(export_onnx_path)
fn = pk.resource_filename("finn", "data/cifar10/cifar10-test-data-class3.npz")
input_tensor = np.load(fn)["arr_0"]
assert input_tensor.shape == (1, 3, 32, 32)
# run using FINN-based execution
input_dict = {"0": input_tensor}
output_dict = oxe.execute_onnx(model, input_dict)
produced = output_dict[list(output_dict.keys())[0]]
# do forward pass in PyTorch/Brevitas
input_tensor = torch.from_numpy(input_tensor).float()
expected = cnv.forward(input_tensor).detach().numpy()
assert np.isclose(produced, expected, atol=1e-3).all()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment