From b6493935727c648e2242a7d64eee95d8bdf3c733 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <maltanar@gmail.com> Date: Sun, 20 Sep 2020 23:08:53 +0200 Subject: [PATCH] [Driver] generate validation code for mnist/cifar-10 end2end examples --- .../fpgadataflow/make_pynq_driver.py | 7 +++ .../transformation/fpgadataflow/templates.py | 52 +++++++++++++++++++ 2 files changed, 59 insertions(+) diff --git a/src/finn/transformation/fpgadataflow/make_pynq_driver.py b/src/finn/transformation/fpgadataflow/make_pynq_driver.py index 0e50213ee..813b40698 100644 --- a/src/finn/transformation/fpgadataflow/make_pynq_driver.py +++ b/src/finn/transformation/fpgadataflow/make_pynq_driver.py @@ -124,6 +124,13 @@ class MakePYNQDriver(Transformation): with open(driver_py, "w") as f: f.write(driver) + + # add validate.py to run full top-1 test (only for suitable networks) + validate_py = pynq_driver_dir + "/validate.py" + validate_src = templates.pynq_validation_template + with open(validate_py, "w") as f: + f.write(validate_src) + # copy all the dependencies into the driver folder shutil.copytree( get_finn_root() + "/src/finn/util", pynq_driver_dir + "/finn/util" diff --git a/src/finn/transformation/fpgadataflow/templates.py b/src/finn/transformation/fpgadataflow/templates.py index 66580c70d..2b3789dc2 100644 --- a/src/finn/transformation/fpgadataflow/templates.py +++ b/src/finn/transformation/fpgadataflow/templates.py @@ -436,3 +436,55 @@ open_project $VITIS_PROJ_PATH$/_x/link/vivado/vpl/prj/prj.xpr open_run impl_1 report_utilization -hierarchical -hierarchical_depth 5 -file $VITIS_PROJ_PATH$/synth_report.xml -format xml """ + +pynq_validation_template = """ +import argparse +from driver import FINNAccelDriver +import numpy as np + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Validate top-1 accuracy for FINN accelerator') + parser.add_argument('--batchsize', help='number of samples for inference', type=int, default=100) + parser.add_argument('--dataset', help='dataset to use (mnist of cifar10)', required=True) + # parse arguments + args = parser.parse_args() + bsize = args.batchsize + dataset = args.dataset + + if dataset == "mnist": + from dataset_loading import mnist + trainx, trainy, testx, testy, valx, valy = mnist.load_mnist_data("/tmp", download=True, one_hot=False) + elif dataset == "cifar10": + from dataset_loading import cifar + trainx, trainy, testx, testy, valx, valy = cifar.load_cifar_data("/tmp", download=True, one_hot=False) + else: + raise Exception("Unrecognized dataset") + + test_imgs = testx + test_labels = testy + + ok = 0 + nok = 0 + total = test_imgs.shape[0] + driver = FINNAccelDriver(bsize, "resizer.bit", "zynq-iodma") + + n_batches = int(total / bsize) + + test_imgs = test_imgs.reshape(n_batches, bsize, -1) + test_labels = test_labels.reshape(n_batches, bsize) + + for i in range(n_batches): + ibuf_normal = test_imgs[i].reshape(driver.ibuf_packed_device.shape) + exp = test_labels[i] + driver.copy_input_data_to_device(ibuf_normal) + driver.execute() + obuf_normal = np.empty_like(driver.obuf_packed_device) + driver.copy_output_data_from_device(obuf_normal) + ret = np.bincount(obuf_normal.flatten() == exp.flatten()) + nok += ret[0] + ok += ret[1] + print("batch %d / %d : total OK %d NOK %d" % (i, n_batches, ok, nok)) + + acc = 100.0 * ok / (total) + print("Final accuracy: %f" % acc) +""" -- GitLab