Skip to content
Snippets Groups Projects
Commit 46416756 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Test] refactor tests to use local ONNX conv MNIST model

parent 75224f18
No related branches found
No related tags found
No related merge requests found
import hashlib
import os
import shutil
from pkgutil import get_data
import numpy as np
import onnx
import onnx.numpy_helper as np_helper
import wget
import finn.core.onnx_exec as oxe
import finn.transformation.infer_shapes as si
from finn.core.modelwrapper import ModelWrapper
mnist_onnx_url_base = "https://onnxzoo.blob.core.windows.net/models/opset_8/mnist"
mnist_onnx_filename = "mnist.tar.gz"
mnist_onnx_local_dir = "/tmp/mnist_onnx"
def test_mnist_onnx_download_extract_run():
    """Execute the packaged MNIST conv ONNX model with FINN and compare
    against the bundled reference output.

    Loads the model from the local ``finn`` package data (no network
    download), runs shape inference, executes test vector 0 through the
    FINN ONNX executor, and asserts the result matches the reference
    output tensor to within atol=1e-3.
    """
    # load the onnx model from local package data
    raw_m = get_data("finn", "data/onnx/mnist-conv/model.onnx")
    model = ModelWrapper(raw_m)
    model = model.transform_single(si.infer_shapes)
    # load one of the test vectors bundled alongside the model
    raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb")
    raw_o = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/output_0.pb")
    input_tensor = onnx.load_tensor_from_string(raw_i)
    output_tensor = onnx.load_tensor_from_string(raw_o)
    # run using FINN-based execution; tensor names ("Input3",
    # "Plus214_Output_0") come from the MNIST conv model's graph
    input_dict = {"Input3": np_helper.to_array(input_tensor)}
    output_dict = oxe.execute_onnx(model, input_dict)
    assert np.isclose(
        np_helper.to_array(output_tensor), output_dict["Plus214_Output_0"], atol=1e-3
    ).all()
import os
import shutil
from functools import reduce
from operator import mul
from pkgutil import get_data
import brevitas.onnx as bo
import numpy as np
import onnx
import onnx.numpy_helper as nph
import torch
import wget
from models.common import get_act_quant, get_quant_linear, get_quant_type, get_stats_op
from torch.nn import BatchNorm1d, Dropout, Module, ModuleList
......@@ -23,9 +22,6 @@ LAST_FC_PER_OUT_CH_SCALING = False
IN_DROPOUT = 0.2
HIDDEN_DROPOUT = 0.2
mnist_onnx_url_base = "https://onnxzoo.blob.core.windows.net/models/opset_8/mnist"
mnist_onnx_filename = "mnist.tar.gz"
mnist_onnx_local_dir = "/tmp/mnist_onnx"
export_onnx_path = "test_output_lfc.onnx"
transformed_onnx_path = "test_output_lfc_transformed.onnx"
# TODO get from config instead, hardcoded to Docker path for now
......@@ -98,21 +94,11 @@ def test_batchnorm_to_affine():
model = ModelWrapper(export_onnx_path)
model = model.transform_single(si.infer_shapes)
new_model = model.transform_single(tx.batchnorm_to_affine)
try:
os.remove("/tmp/" + mnist_onnx_filename)
except OSError:
pass
dl_ret = wget.download(mnist_onnx_url_base + "/" + mnist_onnx_filename, out="/tmp")
shutil.unpack_archive(dl_ret, mnist_onnx_local_dir)
# load one of the test vectors
input_tensor = onnx.TensorProto()
with open(mnist_onnx_local_dir + "/mnist/test_data_set_0/input_0.pb", "rb") as f:
input_tensor.ParseFromString(f.read())
raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb")
input_tensor = onnx.load_tensor_from_string(raw_i)
input_dict = {"0": nph.to_array(input_tensor)}
output_original = oxe.execute_onnx(model, input_dict)["53"]
output_transformed = oxe.execute_onnx(new_model, input_dict)["53"]
assert np.isclose(output_transformed, output_original, atol=1e-3).all()
# remove the downloaded model and extracted files
os.remove(dl_ret)
shutil.rmtree(mnist_onnx_local_dir)
os.remove(export_onnx_path)
import os
import shutil
from functools import reduce
from operator import mul
from pkgutil import get_data
import brevitas.onnx as bo
import numpy as np
import onnx
import onnx.numpy_helper as nph
import torch
import wget
from models.common import get_act_quant, get_quant_linear, get_quant_type, get_stats_op
from torch.nn import BatchNorm1d, Dropout, Module, ModuleList
......@@ -22,9 +21,6 @@ LAST_FC_PER_OUT_CH_SCALING = False
IN_DROPOUT = 0.2
HIDDEN_DROPOUT = 0.2
mnist_onnx_url_base = "https://onnxzoo.blob.core.windows.net/models/opset_8/mnist"
mnist_onnx_filename = "mnist.tar.gz"
mnist_onnx_local_dir = "/tmp/mnist_onnx"
export_onnx_path = "test_output_lfc.onnx"
# TODO get from config instead, hardcoded to Docker path for now
trained_lfc_checkpoint = (
......@@ -124,17 +120,9 @@ def test_brevitas_trained_lfc_pytorch():
lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1).eval()
checkpoint = torch.load(trained_lfc_checkpoint, map_location="cpu")
lfc.load_state_dict(checkpoint["state_dict"])
# download some MNIST test data
try:
os.remove("/tmp/" + mnist_onnx_filename)
except OSError:
pass
dl_ret = wget.download(mnist_onnx_url_base + "/" + mnist_onnx_filename, out="/tmp")
shutil.unpack_archive(dl_ret, mnist_onnx_local_dir)
# load one of the test vectors
input_tensor = onnx.TensorProto()
with open(mnist_onnx_local_dir + "/mnist/test_data_set_0/input_0.pb", "rb") as f:
input_tensor.ParseFromString(f.read())
raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb")
input_tensor = onnx.load_tensor_from_string(raw_i)
input_tensor = torch.from_numpy(nph.to_array(input_tensor)).float()
assert input_tensor.shape == (1, 1, 28, 28)
# do forward pass in PyTorch/Brevitas
......@@ -154,9 +142,6 @@ def test_brevitas_trained_lfc_pytorch():
]
]
assert np.isclose(produced, expected, atol=1e-4).all()
# remove the downloaded model and extracted files
os.remove(dl_ret)
shutil.rmtree(mnist_onnx_local_dir)
def test_brevitas_to_onnx_export_and_exec():
......@@ -166,16 +151,9 @@ def test_brevitas_to_onnx_export_and_exec():
bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
model = ModelWrapper(export_onnx_path)
model = model.transform_single(si.infer_shapes)
try:
os.remove("/tmp/" + mnist_onnx_filename)
except OSError:
pass
dl_ret = wget.download(mnist_onnx_url_base + "/" + mnist_onnx_filename, out="/tmp")
shutil.unpack_archive(dl_ret, mnist_onnx_local_dir)
# load one of the test vectors
input_tensor = onnx.TensorProto()
with open(mnist_onnx_local_dir + "/mnist/test_data_set_0/input_0.pb", "rb") as f:
input_tensor.ParseFromString(f.read())
raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb")
input_tensor = onnx.load_tensor_from_string(raw_i)
# run using FINN-based execution
input_dict = {"0": nph.to_array(input_tensor)}
output_dict = oxe.execute_onnx(model, input_dict)
......@@ -187,6 +165,4 @@ def test_brevitas_to_onnx_export_and_exec():
expected = lfc.forward(input_tensor).detach().numpy()
assert np.isclose(produced, expected, atol=1e-3).all()
# remove the downloaded model and extracted files
os.remove(dl_ret)
shutil.rmtree(mnist_onnx_local_dir)
os.remove(export_onnx_path)
import hashlib
import os
import shutil
import wget
from pkgutil import get_data
import finn.transformation.general as tg
from finn.core.modelwrapper import ModelWrapper
mnist_onnx_url_base = "https://onnxzoo.blob.core.windows.net/models/opset_8/mnist"
mnist_onnx_filename = "mnist.tar.gz"
mnist_onnx_local_dir = "/tmp/mnist_onnx"
def test_give_unique_node_names():
    """Check that give_unique_node_names assigns names of the form
    ``<OpType>_<index>`` to the nodes of the local MNIST conv model.

    Loads the model from the local ``finn`` package data (no network
    download) and spot-checks the first, second, and twelfth node names.
    """
    # load the onnx model from local package data
    raw_m = get_data("finn", "data/onnx/mnist-conv/model.onnx")
    model = ModelWrapper(raw_m)
    model = model.transform_single(tg.give_unique_node_names)
    # names follow the "<op_type>_<graph position>" convention
    assert model.graph.node[0].name == "Reshape_0"
    assert model.graph.node[1].name == "Conv_1"
    assert model.graph.node[11].name == "Add_11"
import hashlib
import os
import shutil
import wget
from pkgutil import get_data
import finn.transformation.general as tg
import finn.transformation.infer_shapes as si
from finn.core.modelwrapper import ModelWrapper
mnist_onnx_url_base = "https://onnxzoo.blob.core.windows.net/models/opset_8/mnist"
mnist_onnx_filename = "mnist.tar.gz"
mnist_onnx_local_dir = "/tmp/mnist_onnx"
def test_renaming():
try:
os.remove("/tmp/" + mnist_onnx_filename)
except OSError:
pass
dl_ret = wget.download(mnist_onnx_url_base + "/" + mnist_onnx_filename, out="/tmp")
shutil.unpack_archive(dl_ret, mnist_onnx_local_dir)
with open(mnist_onnx_local_dir + "/mnist/model.onnx", "rb") as f:
assert hashlib.md5(f.read()).hexdigest() == "d7cd24a0a76cd492f31065301d468c3d"
# load the onnx model
model = ModelWrapper(mnist_onnx_local_dir + "/mnist/model.onnx")
raw_m = get_data("finn", "data/onnx/mnist-conv/model.onnx")
model = ModelWrapper(raw_m)
model = model.transform_single(si.infer_shapes)
model = model.transform_single(tg.give_unique_node_names)
model = model.transform_single(tg.give_readable_tensor_names)
......@@ -33,6 +18,3 @@ def test_renaming():
assert model.graph.node[6].op_type == "Add"
assert model.graph.node[6].name == "Add_6"
assert model.graph.node[6].input[1] == "Add_6_param0"
# remove the downloaded model and extracted files
os.remove(dl_ret)
shutil.rmtree(mnist_onnx_local_dir)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment