Skip to content
Snippets Groups Projects
Commit 0de0d0ef authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Test] fix LFC io buf names, handle redownload better

parent 4537a319
No related branches found
No related tags found
No related merge requests found
......@@ -15,6 +15,10 @@ mnist_onnx_local_dir = "/tmp/mnist_onnx"
def test_mnist_onnx_download_extract_run():
try:
os.remove("/tmp/" + mnist_onnx_filename)
except OSError:
pass
dl_ret = wget.download(mnist_onnx_url_base + "/" + mnist_onnx_filename, out="/tmp")
shutil.unpack_archive(dl_ret, mnist_onnx_local_dir)
with open(mnist_onnx_local_dir + "/mnist/model.onnx", "rb") as f:
......@@ -32,8 +36,7 @@ def test_mnist_onnx_download_extract_run():
input_dict = {"Input3": np_helper.to_array(input_tensor)}
output_dict = oxe.execute_onnx(model, input_dict)
assert np.isclose(
np_helper.to_array(output_tensor), output_dict["Plus214_Output_0"],
atol=1e-3
np_helper.to_array(output_tensor), output_dict["Plus214_Output_0"], atol=1e-3
).all()
# remove the downloaded model and extracted files
os.remove(dl_ret)
......
......@@ -121,6 +121,10 @@ def test_brevitas_to_onnx_export_and_exec():
lfc.load_state_dict(checkpoint["state_dict"])
bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
model = onnx.load(export_onnx_path)
try:
os.remove("/tmp/" + mnist_onnx_filename)
except OSError:
pass
dl_ret = wget.download(mnist_onnx_url_base + "/" + mnist_onnx_filename, out="/tmp")
shutil.unpack_archive(dl_ret, mnist_onnx_local_dir)
# load one of the test vectors
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment