From 204bd6176812aecec03be81624b1a9e98ab984a8 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <maltanar@gmail.com> Date: Sat, 9 Nov 2019 13:46:50 +0000 Subject: [PATCH] [Test] rename to test_brevitas_trained_lfc_w1a1_pytorch since w1a2 is coming --- tests/test_brevitas_export.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_brevitas_export.py b/tests/test_brevitas_export.py index 80850edb7..04d7776df 100644 --- a/tests/test_brevitas_export.py +++ b/tests/test_brevitas_export.py @@ -15,15 +15,15 @@ from finn.core.modelwrapper import ModelWrapper export_onnx_path = "test_output_lfc.onnx" # TODO get from config instead, hardcoded to Docker path for now -trained_lfc_checkpoint = ( +trained_lfc_w1a1_checkpoint = ( "/workspace/brevitas_cnv_lfc/pretrained_models/LFC_1W1A/checkpoints/best.tar" ) -def test_brevitas_trained_lfc_pytorch(): +def test_brevitas_trained_lfc_w1a1_pytorch(): # load pretrained weights into LFC-w1a1 lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1).eval() - checkpoint = torch.load(trained_lfc_checkpoint, map_location="cpu") + checkpoint = torch.load(trained_lfc_w1a1_checkpoint, map_location="cpu") lfc.load_state_dict(checkpoint["state_dict"]) # load one of the test vectors raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb") @@ -49,9 +49,9 @@ def test_brevitas_trained_lfc_pytorch(): assert np.isclose(produced, expected, atol=1e-4).all() -def test_brevitas_to_onnx_export_and_exec(): +def test_brevitas_to_onnx_export_and_exec_lfc_w1a1(): lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1) - checkpoint = torch.load(trained_lfc_checkpoint, map_location="cpu") + checkpoint = torch.load(trained_lfc_w1a1_checkpoint, map_location="cpu") lfc.load_state_dict(checkpoint["state_dict"]) bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path) model = ModelWrapper(export_onnx_path) -- GitLab