Skip to content
Snippets Groups Projects
Commit 7ed02942 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Test] add PyTorch golden value for cnv-w1a1

parent 441a25a6
No related branches found
No related tags found
No related merge requests found
......@@ -63,5 +63,22 @@ def test_brevitas_trained_cnv_w1a1_pytorch():
input_tensor = torch.from_numpy(input_tensor).float()
assert input_tensor.shape == (1, 3, 32, 32)
# do forward pass in PyTorch/Brevitas
cnv.forward(input_tensor).detach().numpy()
# TODO verify produced answer
produced = cnv.forward(input_tensor).detach().numpy()
expected = np.asarray(
[
[
3.7939777,
-2.3108773,
0.06898145,
0.55185133,
0.37939775,
-1.9659703,
-0.3104164,
-2.828238,
2.6902752,
0.48286998,
]
],
dtype=np.float32,
)
assert np.isclose(produced, expected, atol=1e-3).all()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment