Skip to content
Snippets Groups Projects
Commit 204bd617 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Test] rename to test_brevitas_trained_lfc_w1a1_pytorch

since w1a2 is coming
parent f5fef066
No related branches found
No related tags found
No related merge requests found
......@@ -15,15 +15,15 @@ from finn.core.modelwrapper import ModelWrapper
export_onnx_path = "test_output_lfc.onnx"
# TODO get from config instead, hardcoded to Docker path for now
trained_lfc_checkpoint = (
trained_lfc_w1a1_checkpoint = (
"/workspace/brevitas_cnv_lfc/pretrained_models/LFC_1W1A/checkpoints/best.tar"
)
def test_brevitas_trained_lfc_pytorch():
def test_brevitas_trained_lfc_w1a1_pytorch():
# load pretrained weights into LFC-w1a1
lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1).eval()
checkpoint = torch.load(trained_lfc_checkpoint, map_location="cpu")
checkpoint = torch.load(trained_lfc_w1a1_checkpoint, map_location="cpu")
lfc.load_state_dict(checkpoint["state_dict"])
# load one of the test vectors
raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb")
......@@ -49,9 +49,9 @@ def test_brevitas_trained_lfc_pytorch():
assert np.isclose(produced, expected, atol=1e-4).all()
def test_brevitas_to_onnx_export_and_exec():
def test_brevitas_to_onnx_export_and_exec_lfc_w1a1():
lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1)
checkpoint = torch.load(trained_lfc_checkpoint, map_location="cpu")
checkpoint = torch.load(trained_lfc_w1a1_checkpoint, map_location="cpu")
lfc.load_state_dict(checkpoint["state_dict"])
bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
model = ModelWrapper(export_onnx_path)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment