diff --git a/tests/end2end/test_end2end_bnn_pynq.py b/tests/end2end/test_end2end_bnn_pynq.py index c3fe1841e63d6ea743e9d81abe52340256e47366..c470760895d6233ea8633368489c1e49c2e6904e 100644 --- a/tests/end2end/test_end2end_bnn_pynq.py +++ b/tests/end2end/test_end2end_bnn_pynq.py @@ -63,6 +63,7 @@ from finn.util.test import ( get_example_input, get_trained_network_and_ishape, execute_parent, + get_topk, ) from finn.transformation.fpgadataflow.annotate_resources import AnnotateResources from finn.transformation.infer_data_layouts import InferDataLayouts @@ -194,11 +195,13 @@ def get_folding_function(topology, wbits, abits): raise Exception("Unknown topology/quantization combo for predefined folding") -def get_golden_io_pair(topology, wbits, abits): +def get_golden_io_pair(topology, wbits, abits, return_topk=None): (model, ishape) = get_trained_network_and_ishape(topology, wbits, abits) input_tensor_npy = get_example_input(topology) input_tensor_torch = torch.from_numpy(input_tensor_npy).float() output_tensor_npy = model.forward(input_tensor_torch).detach().numpy() + if return_topk is not None: + output_tensor_npy = get_topk(output_tensor_npy, k=return_topk) return (input_tensor_npy, output_tensor_npy) @@ -323,7 +326,7 @@ class TestEnd2End: model.save(cppsim_chkpt) parent_chkpt = get_checkpoint_name(topology, wbits, abits, "dataflow_parent") (input_tensor_npy, output_tensor_npy) = get_golden_io_pair( - topology, wbits, abits + topology, wbits, abits, return_topk=1 ) y = execute_parent(parent_chkpt, cppsim_chkpt, input_tensor_npy) assert np.isclose(y, output_tensor_npy).all() @@ -366,7 +369,7 @@ class TestEnd2End: model.save(rtlsim_chkpt) parent_chkpt = get_checkpoint_name(topology, wbits, abits, "dataflow_parent") (input_tensor_npy, output_tensor_npy) = get_golden_io_pair( - topology, wbits, abits + topology, wbits, abits, return_topk=1 ) y = execute_parent(parent_chkpt, rtlsim_chkpt, input_tensor_npy) model = ModelWrapper(rtlsim_chkpt) @@ -438,7 +441,7 @@ class TestEnd2End: if cfg["ip"] == "": pytest.skip("PYNQ board IP address not specified") (input_tensor_npy, output_tensor_npy) = get_golden_io_pair( - topology, wbits, abits + topology, wbits, abits, return_topk=1 ) parent_model = load_test_checkpoint_or_skip( get_checkpoint_name(topology, wbits, abits, "dataflow_parent")