From 1d37f8fd6d86db82a247061831032cad468574ff Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <maltanar@gmail.com> Date: Mon, 14 Sep 2020 21:12:45 +0200 Subject: [PATCH] [Test] allow returning top-k from get_golden_io_pair --- tests/end2end/test_end2end_bnn_pynq.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/end2end/test_end2end_bnn_pynq.py b/tests/end2end/test_end2end_bnn_pynq.py index c3fe1841e..c47076089 100644 --- a/tests/end2end/test_end2end_bnn_pynq.py +++ b/tests/end2end/test_end2end_bnn_pynq.py @@ -63,6 +63,7 @@ from finn.util.test import ( get_example_input, get_trained_network_and_ishape, execute_parent, + get_topk, ) from finn.transformation.fpgadataflow.annotate_resources import AnnotateResources from finn.transformation.infer_data_layouts import InferDataLayouts @@ -194,11 +195,13 @@ def get_folding_function(topology, wbits, abits): raise Exception("Unknown topology/quantization combo for predefined folding") -def get_golden_io_pair(topology, wbits, abits): +def get_golden_io_pair(topology, wbits, abits, return_topk=None): (model, ishape) = get_trained_network_and_ishape(topology, wbits, abits) input_tensor_npy = get_example_input(topology) input_tensor_torch = torch.from_numpy(input_tensor_npy).float() output_tensor_npy = model.forward(input_tensor_torch).detach().numpy() + if return_topk is not None: + output_tensor_npy = get_topk(output_tensor_npy, k=return_topk) return (input_tensor_npy, output_tensor_npy) @@ -323,7 +326,7 @@ class TestEnd2End: model.save(cppsim_chkpt) parent_chkpt = get_checkpoint_name(topology, wbits, abits, "dataflow_parent") (input_tensor_npy, output_tensor_npy) = get_golden_io_pair( - topology, wbits, abits + topology, wbits, abits, return_topk=1 ) y = execute_parent(parent_chkpt, cppsim_chkpt, input_tensor_npy) assert np.isclose(y, output_tensor_npy).all() @@ -366,7 +369,7 @@ class TestEnd2End: model.save(rtlsim_chkpt) parent_chkpt = get_checkpoint_name(topology, wbits, abits, "dataflow_parent") (input_tensor_npy, output_tensor_npy) = get_golden_io_pair( - topology, wbits, abits + topology, wbits, abits, return_topk=1 ) y = execute_parent(parent_chkpt, rtlsim_chkpt, input_tensor_npy) model = ModelWrapper(rtlsim_chkpt) @@ -438,7 +441,7 @@ class TestEnd2End: if cfg["ip"] == "": pytest.skip("PYNQ board IP address not specified") (input_tensor_npy, output_tensor_npy) = get_golden_io_pair( - topology, wbits, abits + topology, wbits, abits, return_topk=1 ) parent_model = load_test_checkpoint_or_skip( get_checkpoint_name(topology, wbits, abits, "dataflow_parent") -- GitLab