Skip to content
Snippets Groups Projects
Commit 65720c2f authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Test] add MNIST top-1 accuracy measurement using stitched rtlsim

parent a5de4f75
No related branches found
No related tags found
No related merge requests found
......@@ -65,6 +65,7 @@ RUN apt update; apt install nano
RUN pip install pytest-dependency
RUN pip install pytest-xdist
RUN pip install pytest-parallel
RUN pip install mnist
ENV PYTHONPATH "${PYTHONPATH}:/workspace/finn/src"
ENV PYTHONPATH "${PYTHONPATH}:/workspace/pyverilator"
......
......@@ -63,6 +63,7 @@ RUN pip install sphinx_rtd_theme==0.5.0
RUN pip install pytest-xdist==2.0.0
RUN pip install pytest-parallel==0.1.0
RUN pip install netron==4.4.7
RUN pip install mnist
# switch user
RUN groupadd -g $GID $GNAME
......
......@@ -92,11 +92,14 @@ from finn.util.pytorch import ToTensor
from finn.transformation.merge_onnx_models import MergeONNXModels
from finn.transformation.insert_topk import InsertTopK
from finn.core.datatype import DataType
import mnist
build_dir = "/tmp/" + os.environ["FINN_INST_NAME"]
target_clk_ns = 10
mem_mode = "decoupled"
rtlsim_trace = False
mnist_test_imgs = mnist.test_images()
mnist_test_labels = mnist.test_labels()
def get_checkpoint_name(topology, wbits, abits, step):
......@@ -393,6 +396,31 @@ class TestEnd2End:
warnings.warn("Estimated & rtlsim performance: " + str(perf))
assert np.isclose(y, output_tensor_npy).all()
@pytest.mark.slow
@pytest.mark.parametrize("kind", ["zynq"])
def test_rtlsim_top1(self, topology, wbits, abits, kind):
if "fc" not in topology:
pytest.skip("Top-1 rtlsim test currently for MNIST only")
rtlsim_chkpt = get_checkpoint_name(
topology, wbits, abits, "ipstitch_rtlsim_" + kind
)
parent_chkpt = get_checkpoint_name(topology, wbits, abits, "dataflow_parent")
load_test_checkpoint_or_skip(rtlsim_chkpt)
ok = 0
nok = 0
for i in range(10000):
tdata = mnist_test_imgs[i].reshape(1, 1, 28, 28).astype(np.float32)
exp = mnist_test_labels[i].item()
y = execute_parent(parent_chkpt, rtlsim_chkpt, tdata)
ret = y.item()
if ret == exp:
ok += 1
else:
nok += 1
acc_top1 = ok * 100.0 / (ok + nok)
warnings.warn("Final OK %d NOK %d top-1 %f" % (ok, nok, acc_top1))
assert acc_top1 > 90.0
@pytest.mark.slow
@pytest.mark.vivado
@pytest.mark.parametrize("kind", ["zynq", "alveo"])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment