From b6493935727c648e2242a7d64eee95d8bdf3c733 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Sun, 20 Sep 2020 23:08:53 +0200
Subject: [PATCH] [Driver] generate validation code for mnist/cifar-10 end2end
 examples

---
 .../fpgadataflow/make_pynq_driver.py          |  7 +++
 .../transformation/fpgadataflow/templates.py  | 52 +++++++++++++++++++
 2 files changed, 59 insertions(+)

diff --git a/src/finn/transformation/fpgadataflow/make_pynq_driver.py b/src/finn/transformation/fpgadataflow/make_pynq_driver.py
index 0e50213ee..813b40698 100644
--- a/src/finn/transformation/fpgadataflow/make_pynq_driver.py
+++ b/src/finn/transformation/fpgadataflow/make_pynq_driver.py
@@ -124,6 +124,13 @@ class MakePYNQDriver(Transformation):
 
         with open(driver_py, "w") as f:
             f.write(driver)
+
+        # add validate.py to run full top-1 test (only for suitable networks)
+        validate_py = pynq_driver_dir + "/validate.py"
+        validate_src = templates.pynq_validation_template
+        with open(validate_py, "w") as f:
+            f.write(validate_src)
+
         # copy all the dependencies into the driver folder
         shutil.copytree(
             get_finn_root() + "/src/finn/util", pynq_driver_dir + "/finn/util"
diff --git a/src/finn/transformation/fpgadataflow/templates.py b/src/finn/transformation/fpgadataflow/templates.py
index 66580c70d..2b3789dc2 100644
--- a/src/finn/transformation/fpgadataflow/templates.py
+++ b/src/finn/transformation/fpgadataflow/templates.py
@@ -436,3 +436,55 @@ open_project $VITIS_PROJ_PATH$/_x/link/vivado/vpl/prj/prj.xpr
 open_run impl_1
 report_utilization -hierarchical -hierarchical_depth 5 -file $VITIS_PROJ_PATH$/synth_report.xml -format xml
 """
+
+pynq_validation_template = """
+import argparse
+from driver import FINNAccelDriver
+import numpy as np
+
+if __name__ == "__main__":
+  parser = argparse.ArgumentParser(description='Validate top-1 accuracy for FINN accelerator')
+  parser.add_argument('--batchsize', help='number of samples for inference', type=int, default=100)
+  parser.add_argument('--dataset', help='dataset to use (mnist of cifar10)', required=True)
+  # parse arguments
+  args = parser.parse_args()
+  bsize = args.batchsize
+  dataset = args.dataset
+
+  if dataset == "mnist":
+    from dataset_loading import mnist
+    trainx, trainy, testx, testy, valx, valy = mnist.load_mnist_data("/tmp", download=True, one_hot=False)
+  elif dataset == "cifar10":
+    from dataset_loading import cifar
+    trainx, trainy, testx, testy, valx, valy = cifar.load_cifar_data("/tmp", download=True, one_hot=False)
+  else:
+    raise Exception("Unrecognized dataset")
+
+  test_imgs = testx
+  test_labels = testy
+
+  ok = 0
+  nok = 0
+  total = test_imgs.shape[0]
+  driver = FINNAccelDriver(bsize, "resizer.bit", "zynq-iodma")
+
+  n_batches = int(total / bsize)
+
+  test_imgs = test_imgs.reshape(n_batches, bsize, -1)
+  test_labels = test_labels.reshape(n_batches, bsize)
+
+  for i in range(n_batches):
+    ibuf_normal = test_imgs[i].reshape(driver.ibuf_packed_device.shape)
+    exp = test_labels[i]
+    driver.copy_input_data_to_device(ibuf_normal)
+    driver.execute()
+    obuf_normal = np.empty_like(driver.obuf_packed_device)
+    driver.copy_output_data_from_device(obuf_normal)
+    ret = np.bincount(obuf_normal.flatten() == exp.flatten())
+    nok += ret[0]
+    ok += ret[1]
+    print("batch %d / %d : total OK %d NOK %d" % (i, n_batches, ok, nok))
+
+  acc = 100.0 * ok / (total)
+  print("Final accuracy: %f" % acc)
+"""
-- 
GitLab