From 1d37f8fd6d86db82a247061831032cad468574ff Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Mon, 14 Sep 2020 21:12:45 +0200
Subject: [PATCH] [Test] allow returning top-k from get_golden_io_pair

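get_golden_io_pair() gains an optional return_topk argument: when it is set, the
golden output tensor is reduced with get_topk before being returned, and the
cppsim, rtlsim and hardware end-to-end tests now request top-1.

A minimal usage sketch mirroring the updated tests; the topology name and the
checkpoint paths below are illustrative placeholders, not part of this patch:

    import numpy as np

    # get_golden_io_pair is defined in tests/end2end/test_end2end_bnn_pynq.py;
    # "tfc", 1, 1 are example topology/quantization settings.
    (input_tensor_npy, output_tensor_npy) = get_golden_io_pair(
        "tfc", 1, 1, return_topk=1
    )
    # execute_parent runs the child checkpoint through its dataflow parent
    # model and returns the resulting output; paths here are placeholders.
    y = execute_parent(
        "dataflow_parent.onnx", "child_cppsim.onnx", input_tensor_npy
    )
    # With return_topk=1 the golden reference is already reduced to top-1,
    # so it can be compared directly against the parent-model output.
    assert np.isclose(y, output_tensor_npy).all()
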
---
 tests/end2end/test_end2end_bnn_pynq.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/tests/end2end/test_end2end_bnn_pynq.py b/tests/end2end/test_end2end_bnn_pynq.py
index c3fe1841e..c47076089 100644
--- a/tests/end2end/test_end2end_bnn_pynq.py
+++ b/tests/end2end/test_end2end_bnn_pynq.py
@@ -63,6 +63,7 @@ from finn.util.test import (
     get_example_input,
     get_trained_network_and_ishape,
     execute_parent,
+    get_topk,
 )
 from finn.transformation.fpgadataflow.annotate_resources import AnnotateResources
 from finn.transformation.infer_data_layouts import InferDataLayouts
@@ -194,11 +195,13 @@ def get_folding_function(topology, wbits, abits):
         raise Exception("Unknown topology/quantization combo for predefined folding")
 
 
-def get_golden_io_pair(topology, wbits, abits):
+def get_golden_io_pair(topology, wbits, abits, return_topk=None):
     (model, ishape) = get_trained_network_and_ishape(topology, wbits, abits)
     input_tensor_npy = get_example_input(topology)
     input_tensor_torch = torch.from_numpy(input_tensor_npy).float()
     output_tensor_npy = model.forward(input_tensor_torch).detach().numpy()
+    if return_topk is not None:
+        output_tensor_npy = get_topk(output_tensor_npy, k=return_topk)
     return (input_tensor_npy, output_tensor_npy)
 
 
@@ -323,7 +326,7 @@ class TestEnd2End:
         model.save(cppsim_chkpt)
         parent_chkpt = get_checkpoint_name(topology, wbits, abits, "dataflow_parent")
         (input_tensor_npy, output_tensor_npy) = get_golden_io_pair(
-            topology, wbits, abits
+            topology, wbits, abits, return_topk=1
         )
         y = execute_parent(parent_chkpt, cppsim_chkpt, input_tensor_npy)
         assert np.isclose(y, output_tensor_npy).all()
@@ -366,7 +369,7 @@ class TestEnd2End:
         model.save(rtlsim_chkpt)
         parent_chkpt = get_checkpoint_name(topology, wbits, abits, "dataflow_parent")
         (input_tensor_npy, output_tensor_npy) = get_golden_io_pair(
-            topology, wbits, abits
+            topology, wbits, abits, return_topk=1
         )
         y = execute_parent(parent_chkpt, rtlsim_chkpt, input_tensor_npy)
         model = ModelWrapper(rtlsim_chkpt)
@@ -438,7 +441,7 @@ class TestEnd2End:
         if cfg["ip"] == "":
             pytest.skip("PYNQ board IP address not specified")
         (input_tensor_npy, output_tensor_npy) = get_golden_io_pair(
-            topology, wbits, abits
+            topology, wbits, abits, return_topk=1
         )
         parent_model = load_test_checkpoint_or_skip(
             get_checkpoint_name(topology, wbits, abits, "dataflow_parent")
-- 
GitLab