Skip to content
Snippets Groups Projects
Commit 523d2b57 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Test] move tests to ModelWrapper and FINN shape inference

parent d5037489
No related branches found
No related tags found
No related merge requests found
......@@ -5,10 +5,11 @@ import shutil
import numpy as np
import onnx
import onnx.numpy_helper as np_helper
import onnx.shape_inference as si
import wget
import finn.core.onnx_exec as oxe
import finn.transformation.infer_shapes as si
from finn.core.modelwrapper import ModelWrapper
mnist_onnx_url_base = "https://onnxzoo.blob.core.windows.net/models/opset_8/mnist"
mnist_onnx_filename = "mnist.tar.gz"
......@@ -25,10 +26,8 @@ def test_mnist_onnx_download_extract_run():
with open(mnist_onnx_local_dir + "/mnist/model.onnx", "rb") as f:
assert hashlib.md5(f.read()).hexdigest() == "d7cd24a0a76cd492f31065301d468c3d"
# load the onnx model
model = onnx.load(mnist_onnx_local_dir + "/mnist/model.onnx")
# call ONNX shape inference to make sure we have value_info fields for all
# the intermediate tensors in the graph
model = si.infer_shapes(model)
model = ModelWrapper(mnist_onnx_local_dir + "/mnist/model.onnx")
model = model.transform_single(si.infer_shapes)
# load one of the test vectors
input_tensor = onnx.TensorProto()
output_tensor = onnx.TensorProto()
......@@ -38,7 +37,7 @@ def test_mnist_onnx_download_extract_run():
output_tensor.ParseFromString(f.read())
# run using FINN-based execution
input_dict = {"Input3": np_helper.to_array(input_tensor)}
output_dict = oxe.execute_onnx(model, input_dict)
output_dict = oxe.execute_onnx(model.model, input_dict)
assert np.isclose(
np_helper.to_array(output_tensor), output_dict["Plus214_Output_0"], atol=1e-3
).all()
......
......@@ -7,7 +7,6 @@ import brevitas.onnx as bo
import numpy as np
import onnx
import onnx.numpy_helper as nph
import onnx.shape_inference as si
import torch
import wget
from models.common import get_act_quant, get_quant_linear, get_quant_type, get_stats_op
......@@ -15,6 +14,7 @@ from torch.nn import BatchNorm1d, Dropout, Module, ModuleList
import finn.core.onnx_exec as oxe
import finn.transformation.batchnorm_to_affine as tx
import finn.transformation.infer_shapes as si
from finn.core.modelwrapper import ModelWrapper
FC_OUT_FEATURES = [1024, 1024, 1024]
......@@ -96,7 +96,7 @@ def test_batchnorm_to_affine():
lfc.load_state_dict(checkpoint["state_dict"])
bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
model = ModelWrapper(export_onnx_path)
model.model = si.infer_shapes(model.model)
model = model.transform_single(si.infer_shapes)
new_model = model.transform_single(tx.batchnorm_to_affine)
try:
os.remove("/tmp/" + mnist_onnx_filename)
......
......@@ -7,13 +7,14 @@ import brevitas.onnx as bo
import numpy as np
import onnx
import onnx.numpy_helper as nph
import onnx.shape_inference as si
import torch
import wget
from models.common import get_act_quant, get_quant_linear, get_quant_type, get_stats_op
from torch.nn import BatchNorm1d, Dropout, Module, ModuleList
import finn.core.onnx_exec as oxe
import finn.transformation.infer_shapes as si
from finn.core.modelwrapper import ModelWrapper
FC_OUT_FEATURES = [1024, 1024, 1024]
INTERMEDIATE_FC_PER_OUT_CH_SCALING = True
......@@ -90,7 +91,7 @@ class LFC(Module):
def test_brevitas_to_onnx_export():
lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1)
bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
model = onnx.load(export_onnx_path)
model = ModelWrapper(export_onnx_path)
# TODO the following way of testing is highly sensitive to small changes
# in PyTorch ONNX export: the order, names, count... of nodes could
# easily change between different versions, and break this test.
......@@ -163,10 +164,8 @@ def test_brevitas_to_onnx_export_and_exec():
checkpoint = torch.load(trained_lfc_checkpoint, map_location="cpu")
lfc.load_state_dict(checkpoint["state_dict"])
bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
model = onnx.load(export_onnx_path)
# call ONNX shape inference to make sure we have value_info fields for all
# the intermediate tensors in the graph
model = si.infer_shapes(model)
model = ModelWrapper(export_onnx_path)
model = model.transform_single(si.infer_shapes)
try:
os.remove("/tmp/" + mnist_onnx_filename)
except OSError:
......@@ -179,7 +178,7 @@ def test_brevitas_to_onnx_export_and_exec():
input_tensor.ParseFromString(f.read())
# run using FINN-based execution
input_dict = {"0": nph.to_array(input_tensor)}
output_dict = oxe.execute_onnx(model, input_dict)
output_dict = oxe.execute_onnx(model.model, input_dict)
produced = output_dict[list(output_dict.keys())[0]]
# run using PyTorch/Brevitas
input_tensor = torch.from_numpy(nph.to_array(input_tensor)).float()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment