Skip to content
Snippets Groups Projects
Commit 137c42b1 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Test] add test_brevitas_trained_lfc_w1a2_pytorch

parent 204bd617
No related branches found
No related tags found
No related merge requests found
......@@ -19,6 +19,10 @@ trained_lfc_w1a1_checkpoint = (
"/workspace/brevitas_cnv_lfc/pretrained_models/LFC_1W1A/checkpoints/best.tar"
)
trained_lfc_w1a2_checkpoint = (
"/workspace/brevitas_cnv_lfc/pretrained_models/LFC_1W2A/checkpoints/best.tar"
)
def test_brevitas_trained_lfc_w1a1_pytorch():
# load pretrained weights into LFC-w1a1
......@@ -49,6 +53,35 @@ def test_brevitas_trained_lfc_w1a1_pytorch():
assert np.isclose(produced, expected, atol=1e-4).all()
def test_brevitas_trained_lfc_w1a2_pytorch():
# load pretrained weights into LFC-w1a2
lfc = LFC(weight_bit_width=1, act_bit_width=2, in_bit_width=2).eval()
checkpoint = torch.load(trained_lfc_w1a2_checkpoint, map_location="cpu")
lfc.load_state_dict(checkpoint["state_dict"])
# load one of the test vectors
raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb")
input_tensor = onnx.load_tensor_from_string(raw_i)
input_tensor = torch.from_numpy(nph.to_array(input_tensor)).float()
assert input_tensor.shape == (1, 1, 28, 28)
# do forward pass in PyTorch/Brevitas
produced = lfc.forward(input_tensor).detach().numpy()
expected = [
[
4.598069,
-6.3698025,
10.75695,
0.3796571,
1.4764442,
-5.4417515,
-1.8982856,
-5.610488,
6.116698,
0.21092065,
]
]
assert np.isclose(produced, expected, atol=1e-4).all()
def test_brevitas_to_onnx_export_and_exec_lfc_w1a1():
lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1)
checkpoint = torch.load(trained_lfc_w1a1_checkpoint, map_location="cpu")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment