Skip to content
Snippets Groups Projects
Commit 4537a319 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Test] add skeleton for exported Brevitas ONNX exec test

parent de00f7fc
No related branches found
No related tags found
No related merge requests found
import os
import shutil
from functools import reduce
from operator import mul
import brevitas.onnx as bo
import numpy as np
import onnx
import onnx.numpy_helper as nph
import torch
import wget
from models.common import get_act_quant, get_quant_linear, get_quant_type, get_stats_op
from torch.nn import BatchNorm1d, Dropout, Module, ModuleList
import finn.core.onnx_exec as oxe
FC_OUT_FEATURES = [1024, 1024, 1024]
INTERMEDIATE_FC_PER_OUT_CH_SCALING = True
LAST_FC_PER_OUT_CH_SCALING = False
IN_DROPOUT = 0.2
HIDDEN_DROPOUT = 0.2
mnist_onnx_url_base = "https://onnxzoo.blob.core.windows.net/models/opset_8/mnist"
mnist_onnx_filename = "mnist.tar.gz"
mnist_onnx_local_dir = "/tmp/mnist_onnx"
export_onnx_path = "test_output_lfc.onnx"
# TODO get from config instead, hardcoded to Docker path for now
trained_lfc_checkpoint = (
"/workspace/brevitas_cnv_lfc/pretrained_models/LFC_1W1A/checkpoints/best.tar"
)
class LFC(Module):
def __init__(
......@@ -70,7 +85,6 @@ class LFC(Module):
def test_brevitas_to_onnx_export():
export_onnx_path = "test_output_lfc.onnx"
lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1)
bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
model = onnx.load(export_onnx_path)
......@@ -99,3 +113,45 @@ def test_brevitas_to_onnx_export():
int_weights_onnx = nph.to_array(model.graph.initializer[init_ind])
assert (int_weights_onnx == int_weights_pytorch).all()
os.remove(export_onnx_path)
def test_brevitas_to_onnx_export_and_exec():
lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1)
checkpoint = torch.load(trained_lfc_checkpoint, map_location="cpu")
lfc.load_state_dict(checkpoint["state_dict"])
bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
model = onnx.load(export_onnx_path)
dl_ret = wget.download(mnist_onnx_url_base + "/" + mnist_onnx_filename, out="/tmp")
shutil.unpack_archive(dl_ret, mnist_onnx_local_dir)
# load one of the test vectors
input_tensor = onnx.TensorProto()
output_tensor = onnx.TensorProto()
with open(mnist_onnx_local_dir + "/mnist/test_data_set_0/input_0.pb", "rb") as f:
input_tensor.ParseFromString(f.read())
with open(mnist_onnx_local_dir + "/mnist/test_data_set_0/output_0.pb", "rb") as f:
output_tensor.ParseFromString(f.read())
# run using FINN-based execution
input_dict = {"0": nph.to_array(input_tensor)}
output_dict = oxe.execute_onnx(model, input_dict)
assert np.isclose(nph.to_array(output_tensor), output_dict["53"], atol=1e-3).all()
# remove the downloaded model and extracted files
os.remove(dl_ret)
shutil.rmtree(mnist_onnx_local_dir)
os.remove(export_onnx_path)
class objdict(dict):
def __getattr__(self, name):
if name in self:
return self[name]
else:
raise AttributeError("No such attribute: " + name)
def __setattr__(self, name, value):
self[name] = value
def __delattr__(self, name):
if name in self:
del self[name]
else:
raise AttributeError("No such attribute: " + name)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment