diff --git a/tests/test_brevitas_cnv.py b/tests/test_brevitas_cnv.py index 55ce1171608fa18c249c0dc110b1e605d082b9cd..49a8db86f48384a730c879e9c308f46c7a9f8016 100644 --- a/tests/test_brevitas_cnv.py +++ b/tests/test_brevitas_cnv.py @@ -1,5 +1,6 @@ import pkg_resources as pk +import brevitas.onnx as bo import numpy as np import torch from models.common import ( @@ -11,6 +12,10 @@ from models.common import ( ) from torch.nn import BatchNorm1d, BatchNorm2d, MaxPool2d, Module, ModuleList, Sequential +import finn.core.onnx_exec as oxe +import finn.transformation.infer_shapes as si +from finn.core.modelwrapper import ModelWrapper + # QuantConv2d configuration CNV_OUT_CH_POOL = [ (0, 64, False), @@ -32,6 +37,7 @@ LAST_FC_PER_OUT_CH_SCALING = False # MaxPool2d configuration POOL_SIZE = 2 +export_onnx_path = "test_output_cnv.onnx" # TODO get from config instead, hardcoded to Docker path for now trained_cnv_checkpoint = ( "/workspace/brevitas_cnv_lfc/pretrained_models/CNV_1W1A/checkpoints/best.tar" @@ -121,3 +127,35 @@ def test_brevitas_trained_cnv_pytorch(): # do forward pass in PyTorch/Brevitas cnv.forward(input_tensor).detach().numpy() # TODO verify produced answer + + +def test_brevitas_cnv_export(): + cnv = CNV(weight_bit_width=1, act_bit_width=1, in_bit_width=1, in_ch=3).eval() + bo.export_finn_onnx(cnv, (1, 3, 32, 32), export_onnx_path) + model = ModelWrapper(export_onnx_path) + assert model.graph.node[2].op_type == "Sign" + assert model.graph.node[3].op_type == "Conv" + conv0_wname = model.graph.node[3].input[1] + assert list(model.get_initializer(conv0_wname).shape) == [64, 3, 3, 3] + assert model.graph.node[4].op_type == "Mul" + + +def test_brevitas_cnv_export_exec(): + cnv = CNV(weight_bit_width=1, act_bit_width=1, in_bit_width=1, in_ch=3).eval() + checkpoint = torch.load(trained_cnv_checkpoint, map_location="cpu") + cnv.load_state_dict(checkpoint["state_dict"]) + bo.export_finn_onnx(cnv, (1, 3, 32, 32), export_onnx_path) + model = ModelWrapper(export_onnx_path) + model = model.transform_single(si.infer_shapes) + model.save(export_onnx_path) + fn = pk.resource_filename("finn", "data/cifar10/cifar10-test-data-class3.npz") + input_tensor = np.load(fn)["arr_0"] + assert input_tensor.shape == (1, 3, 32, 32) + # run using FINN-based execution + input_dict = {"0": input_tensor} + output_dict = oxe.execute_onnx(model, input_dict) + produced = output_dict[list(output_dict.keys())[0]] + # do forward pass in PyTorch/Brevitas + input_tensor = torch.from_numpy(input_tensor).float() + expected = cnv.forward(input_tensor).detach().numpy() + assert np.isclose(produced, expected, atol=1e-3).all()