Skip to content
Snippets Groups Projects
Commit 28fcd28d authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Refactor] use test loaders for cnv export and exec testing

parent fc9bee01
No related branches found
No related tags found
No related merge requests found
......@@ -23,10 +23,10 @@ def get_test_model_trained(netname, wbits, abits):
fc = model_def_fxn(weight_bit_width=wbits, act_bit_width=abits, in_bit_width=abits)
checkpoint = torch.load(checkpoint_loc, map_location="cpu")
fc.load_state_dict(checkpoint["state_dict"])
return fc
return fc.eval()
def get_test_model_untrained(netname, wbits, abits):
model_def_fxn = get_test_model_def_fxn(netname)
fc = model_def_fxn(weight_bit_width=wbits, act_bit_width=abits, in_bit_width=abits)
return fc
return fc.eval()
......@@ -4,22 +4,18 @@ import pkg_resources as pk
import brevitas.onnx as bo
import numpy as np
import torch
from models.CNV import CNV
import finn.core.onnx_exec as oxe
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.fold_constants import FoldConstants
from finn.transformation.infer_shapes import InferShapes
from finn.util.test import get_test_model_trained, get_test_model_untrained
export_onnx_path = "test_output_cnv.onnx"
# TODO get from config instead, hardcoded to Docker path for now
trained_cnv_checkpoint = (
"/workspace/brevitas_cnv_lfc/pretrained_models/CNV_1W1A/checkpoints/best.tar"
)
def test_brevitas_cnv_w1a1_export():
cnv = CNV(weight_bit_width=1, act_bit_width=1, in_bit_width=1, in_ch=3).eval()
cnv = get_test_model_untrained("CNV", 1, 1)
bo.export_finn_onnx(cnv, (1, 3, 32, 32), export_onnx_path)
model = ModelWrapper(export_onnx_path)
assert model.graph.node[2].op_type == "Sign"
......@@ -31,9 +27,7 @@ def test_brevitas_cnv_w1a1_export():
def test_brevitas_cnv_w1a1_export_exec():
cnv = CNV(weight_bit_width=1, act_bit_width=1, in_bit_width=1, in_ch=3).eval()
checkpoint = torch.load(trained_cnv_checkpoint, map_location="cpu")
cnv.load_state_dict(checkpoint["state_dict"])
cnv = get_test_model_trained("CNV", 1, 1)
bo.export_finn_onnx(cnv, (1, 3, 32, 32), export_onnx_path)
model = ModelWrapper(export_onnx_path)
model = model.transform(InferShapes())
......@@ -53,11 +47,9 @@ def test_brevitas_cnv_w1a1_export_exec():
os.remove(export_onnx_path)
def test_brevitas_trained_cnv_w1a1_pytorch():
def test_brevitas_cnv_w1a1_pytorch():
# load pretrained weights into CNV-w1a1
cnv = CNV(weight_bit_width=1, act_bit_width=1, in_bit_width=1, in_ch=3).eval()
checkpoint = torch.load(trained_cnv_checkpoint, map_location="cpu")
cnv.load_state_dict(checkpoint["state_dict"])
cnv = get_test_model_trained("CNV", 1, 1)
fn = pk.resource_filename("finn", "data/cifar10/cifar10-test-data-class3.npz")
input_tensor = np.load(fn)["arr_0"]
input_tensor = torch.from_numpy(input_tensor).float()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment