Skip to content
Snippets Groups Projects
Commit 43323e6d authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Test] add a test for PyTorch/Brevitas fwd pass for LFC-w1a1

parent 22dbe6c8
No related branches found
No related tags found
No related merge requests found
......@@ -117,6 +117,46 @@ def test_brevitas_to_onnx_export():
os.remove(export_onnx_path)
def test_brevitas_trained_lfc_pytorch():
# load pretrained weights into LFC-w1a1
lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1).eval()
checkpoint = torch.load(trained_lfc_checkpoint, map_location="cpu")
lfc.load_state_dict(checkpoint["state_dict"])
# download some MNIST test data
try:
os.remove("/tmp/" + mnist_onnx_filename)
except OSError:
pass
dl_ret = wget.download(mnist_onnx_url_base + "/" + mnist_onnx_filename, out="/tmp")
shutil.unpack_archive(dl_ret, mnist_onnx_local_dir)
# load one of the test vectors
input_tensor = onnx.TensorProto()
with open(mnist_onnx_local_dir + "/mnist/test_data_set_0/input_0.pb", "rb") as f:
input_tensor.ParseFromString(f.read())
input_tensor = torch.from_numpy(nph.to_array(input_tensor)).float()
assert input_tensor.shape == (1, 1, 28, 28)
# do forward pass in PyTorch/Brevitas
produced = lfc.forward(input_tensor).detach().numpy()
expected = [
[
3.3253,
-2.5652,
9.2157,
-1.4251,
1.4251,
-3.3728,
0.2850,
-0.5700,
7.0781,
-1.2826,
]
]
assert np.isclose(produced, expected, atol=1e-4).all()
# remove the downloaded model and extracted files
os.remove(dl_ret)
shutil.rmtree(mnist_onnx_local_dir)
def test_brevitas_to_onnx_export_and_exec():
lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1)
checkpoint = torch.load(trained_lfc_checkpoint, map_location="cpu")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment