From 4537a31904491e5a0a63430b8b406c2f3e76476b Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <yamanu@xilinx.com> Date: Wed, 16 Oct 2019 16:42:09 +0100 Subject: [PATCH] [Test] add skeleton for exported Brevitas ONNX exec test --- tests/test_brevitas_export.py | 58 ++++++++++++++++++++++++++++++++++- 1 file changed, 57 insertions(+), 1 deletion(-) diff --git a/tests/test_brevitas_export.py b/tests/test_brevitas_export.py index dbdaba600..ea5287025 100644 --- a/tests/test_brevitas_export.py +++ b/tests/test_brevitas_export.py @@ -1,19 +1,34 @@ import os +import shutil from functools import reduce from operator import mul import brevitas.onnx as bo +import numpy as np import onnx import onnx.numpy_helper as nph +import torch +import wget from models.common import get_act_quant, get_quant_linear, get_quant_type, get_stats_op from torch.nn import BatchNorm1d, Dropout, Module, ModuleList +import finn.core.onnx_exec as oxe + FC_OUT_FEATURES = [1024, 1024, 1024] INTERMEDIATE_FC_PER_OUT_CH_SCALING = True LAST_FC_PER_OUT_CH_SCALING = False IN_DROPOUT = 0.2 HIDDEN_DROPOUT = 0.2 +mnist_onnx_url_base = "https://onnxzoo.blob.core.windows.net/models/opset_8/mnist" +mnist_onnx_filename = "mnist.tar.gz" +mnist_onnx_local_dir = "/tmp/mnist_onnx" +export_onnx_path = "test_output_lfc.onnx" +# TODO get from config instead, hardcoded to Docker path for now +trained_lfc_checkpoint = ( + "/workspace/brevitas_cnv_lfc/pretrained_models/LFC_1W1A/checkpoints/best.tar" +) + class LFC(Module): def __init__( @@ -70,7 +85,6 @@ class LFC(Module): def test_brevitas_to_onnx_export(): - export_onnx_path = "test_output_lfc.onnx" lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1) bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path) model = onnx.load(export_onnx_path) @@ -99,3 +113,45 @@ def test_brevitas_to_onnx_export(): int_weights_onnx = nph.to_array(model.graph.initializer[init_ind]) assert (int_weights_onnx == int_weights_pytorch).all() os.remove(export_onnx_path) + + +def test_brevitas_to_onnx_export_and_exec(): + lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1) + checkpoint = torch.load(trained_lfc_checkpoint, map_location="cpu") + lfc.load_state_dict(checkpoint["state_dict"]) + bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path) + model = onnx.load(export_onnx_path) + dl_ret = wget.download(mnist_onnx_url_base + "/" + mnist_onnx_filename, out="/tmp") + shutil.unpack_archive(dl_ret, mnist_onnx_local_dir) + # load one of the test vectors + input_tensor = onnx.TensorProto() + output_tensor = onnx.TensorProto() + with open(mnist_onnx_local_dir + "/mnist/test_data_set_0/input_0.pb", "rb") as f: + input_tensor.ParseFromString(f.read()) + with open(mnist_onnx_local_dir + "/mnist/test_data_set_0/output_0.pb", "rb") as f: + output_tensor.ParseFromString(f.read()) + # run using FINN-based execution + input_dict = {"0": nph.to_array(input_tensor)} + output_dict = oxe.execute_onnx(model, input_dict) + assert np.isclose(nph.to_array(output_tensor), output_dict["53"], atol=1e-3).all() + # remove the downloaded model and extracted files + os.remove(dl_ret) + shutil.rmtree(mnist_onnx_local_dir) + os.remove(export_onnx_path) + + +class objdict(dict): + def __getattr__(self, name): + if name in self: + return self[name] + else: + raise AttributeError("No such attribute: " + name) + + def __setattr__(self, name, value): + self[name] = value + + def __delattr__(self, name): + if name in self: + del self[name] + else: + raise AttributeError("No such attribute: " + name) -- GitLab