diff --git a/tests/test_brevitas_cnv.py b/tests/test_brevitas_cnv.py index 631e8073f437052958d6c8aa22126bda88468c49..6c2eeaa7e3c78c44ab0611d0f40223c6f41969bd 100644 --- a/tests/test_brevitas_cnv.py +++ b/tests/test_brevitas_cnv.py @@ -63,5 +63,22 @@ def test_brevitas_trained_cnv_w1a1_pytorch(): input_tensor = torch.from_numpy(input_tensor).float() assert input_tensor.shape == (1, 3, 32, 32) # do forward pass in PyTorch/Brevitas - cnv.forward(input_tensor).detach().numpy() - # TODO verify produced answer + produced = cnv.forward(input_tensor).detach().numpy() + expected = np.asarray( + [ + [ + 3.7939777, + -2.3108773, + 0.06898145, + 0.55185133, + 0.37939775, + -1.9659703, + -0.3104164, + -2.828238, + 2.6902752, + 0.48286998, + ] + ], + dtype=np.float32, + ) + assert np.isclose(produced, expected, atol=1e-3).all()