From 7ed029428def5705b0d4e9fd7de19903ef004499 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Thu, 21 Nov 2019 22:53:22 +0000
Subject: [PATCH] [Test] add PyTorch golden value for cnv-w1a1

---
 tests/test_brevitas_cnv.py | 21 +++++++++++++++++++--
 1 file changed, 19 insertions(+), 2 deletions(-)

diff --git a/tests/test_brevitas_cnv.py b/tests/test_brevitas_cnv.py
index 631e8073f..6c2eeaa7e 100644
--- a/tests/test_brevitas_cnv.py
+++ b/tests/test_brevitas_cnv.py
@@ -63,5 +63,22 @@ def test_brevitas_trained_cnv_w1a1_pytorch():
     input_tensor = torch.from_numpy(input_tensor).float()
     assert input_tensor.shape == (1, 3, 32, 32)
     # do forward pass in PyTorch/Brevitas
-    cnv.forward(input_tensor).detach().numpy()
-    # TODO verify produced answer
+    produced = cnv.forward(input_tensor).detach().numpy()
+    expected = np.asarray(
+        [
+            [
+                3.7939777,
+                -2.3108773,
+                0.06898145,
+                0.55185133,
+                0.37939775,
+                -1.9659703,
+                -0.3104164,
+                -2.828238,
+                2.6902752,
+                0.48286998,
+            ]
+        ],
+        dtype=np.float32,
+    )
+    assert np.isclose(produced, expected, atol=1e-3).all()
-- 
GitLab