diff --git a/tests/test_brevitas_export.py b/tests/test_brevitas_export.py index f513a79da10d9d93cc34d659157ea39f7c13dabb..821306ec673d5eed6d8b13360cd5a74254fb0c4a 100644 --- a/tests/test_brevitas_export.py +++ b/tests/test_brevitas_export.py @@ -117,6 +117,46 @@ def test_brevitas_to_onnx_export(): os.remove(export_onnx_path) +def test_brevitas_trained_lfc_pytorch(): + # load pretrained weights into LFC-w1a1 + lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1).eval() + checkpoint = torch.load(trained_lfc_checkpoint, map_location="cpu") + lfc.load_state_dict(checkpoint["state_dict"]) + # download some MNIST test data + try: + os.remove("/tmp/" + mnist_onnx_filename) + except OSError: + pass + dl_ret = wget.download(mnist_onnx_url_base + "/" + mnist_onnx_filename, out="/tmp") + shutil.unpack_archive(dl_ret, mnist_onnx_local_dir) + # load one of the test vectors + input_tensor = onnx.TensorProto() + with open(mnist_onnx_local_dir + "/mnist/test_data_set_0/input_0.pb", "rb") as f: + input_tensor.ParseFromString(f.read()) + input_tensor = torch.from_numpy(nph.to_array(input_tensor)).float() + assert input_tensor.shape == (1, 1, 28, 28) + # do forward pass in PyTorch/Brevitas + produced = lfc.forward(input_tensor).detach().numpy() + expected = [ + [ + 3.3253, + -2.5652, + 9.2157, + -1.4251, + 1.4251, + -3.3728, + 0.2850, + -0.5700, + 7.0781, + -1.2826, + ] + ] + assert np.isclose(produced, expected, atol=1e-4).all() + # remove the downloaded model and extracted files + os.remove(dl_ret) + shutil.rmtree(mnist_onnx_local_dir) + + def test_brevitas_to_onnx_export_and_exec(): lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1) checkpoint = torch.load(trained_lfc_checkpoint, map_location="cpu")