From 74c7642e90740b7c29fa790837d4c64fb9920d66 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Tue, 20 Apr 2021 14:05:54 +0100
Subject: [PATCH] [Test] rsync mnist dataset into target for extw end2end test

---
 tests/end2end/test_ext_weights.py | 39 ++++++++++++++++++++-----------
 1 file changed, 25 insertions(+), 14 deletions(-)

diff --git a/tests/end2end/test_ext_weights.py b/tests/end2end/test_ext_weights.py
index 0407395ed..aa0ce7a6c 100644
--- a/tests/end2end/test_ext_weights.py
+++ b/tests/end2end/test_ext_weights.py
@@ -44,6 +44,14 @@ onnx_zip_url = "https://github.com/Xilinx/finn-examples"
 onnx_zip_url += "/releases/download/v0.0.1a/onnx-models-bnn-pynq.zip"
 onnx_zip_local = build_dir + "/onnx-models-bnn-pynq.zip"
 onnx_dir_local = build_dir + "/onnx-models-bnn-pynq"
+mnist_url = "https://raw.githubusercontent.com/fgnt/mnist/master"
+mnist_local = build_dir + "/mnist"
+mnist_files = [
+    "train-images-idx3-ubyte.gz",
+    "train-labels-idx1-ubyte.gz",
+    "t10k-images-idx3-ubyte.gz",
+    "t10k-labels-idx1-ubyte.gz",
+]
 
 
 def get_checkpoint_name(step):
@@ -98,6 +106,22 @@ def test_end2end_ext_weights_build():
     shutil.copytree(output_dir + "/deploy", get_checkpoint_name("build"))
 
 
+@pytest.mark.board
+def test_end2end_ext_weights_dataset():
+    # make sure we have local copies of mnist dataset files
+    subprocess.check_output(["mkdir", "-p", mnist_local])
+    for f in mnist_files:
+        if not os.path.isfile(mnist_local + "/" + f):
+            wget.download(mnist_url + "/" + f, out=mnist_local + "/" + f)
+        assert os.path.isfile(mnist_local + "/" + f)
+    # rsync to board
+    build_env = get_build_env(build_kind, target_clk_ns)
+    mnist_target = "%s@%s:%s" % (build_env["username"], build_env["ip"], "/tmp/")
+
+    rsync_dataset_cmd = ["rsync", "-rv", mnist_local + "/", mnist_target]
+    subprocess.check_output(rsync_dataset_cmd)
+
+
 def test_end2end_ext_weights_run_on_hw():
     build_env = get_build_env(build_kind, target_clk_ns)
     deploy_dir = get_checkpoint_name("build")
@@ -124,22 +148,9 @@ echo %s | sudo -S python3.6 validate.py --dataset mnist --bitfile %s
         build_env["ip"],
         build_env["target_dir"],
     )
-    rsync_res = subprocess.run(
-        [
-            "sshpass",
-            "-p",
-            build_env["password"],
-            "rsync",
-            "-avz",
-            deploy_dir,
-            remote_target,
-        ]
-    )
+    rsync_res = subprocess.run(["rsync", "-avz", deploy_dir, remote_target])
     assert rsync_res.returncode == 0
     remote_verif_cmd = [
-        "sshpass",
-        "-p",
-        build_env["password"],
         "ssh",
         "%s@%s" % (build_env["username"], build_env["ip"]),
         "sh",
-- 
GitLab