From 7ed029428def5705b0d4e9fd7de19903ef004499 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <maltanar@gmail.com> Date: Thu, 21 Nov 2019 22:53:22 +0000 Subject: [PATCH] [Test] add PyTorch golden value for cnv-w1a1 --- tests/test_brevitas_cnv.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/tests/test_brevitas_cnv.py b/tests/test_brevitas_cnv.py index 631e8073f..6c2eeaa7e 100644 --- a/tests/test_brevitas_cnv.py +++ b/tests/test_brevitas_cnv.py @@ -63,5 +63,22 @@ def test_brevitas_trained_cnv_w1a1_pytorch(): input_tensor = torch.from_numpy(input_tensor).float() assert input_tensor.shape == (1, 3, 32, 32) # do forward pass in PyTorch/Brevitas - cnv.forward(input_tensor).detach().numpy() - # TODO verify produced answer + produced = cnv.forward(input_tensor).detach().numpy() + expected = np.asarray( + [ + [ + 3.7939777, + -2.3108773, + 0.06898145, + 0.55185133, + 0.37939775, + -1.9659703, + -0.3104164, + -2.828238, + 2.6902752, + 0.48286998, + ] + ], + dtype=np.float32, + ) + assert np.isclose(produced, expected, atol=1e-3).all() -- GitLab