Skip to content
Snippets Groups Projects
Commit 1d37f8fd authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Test] allow returning top-k from get_golden_io_pair

parent 4398559d
No related branches found
No related tags found
No related merge requests found
......@@ -63,6 +63,7 @@ from finn.util.test import (
get_example_input,
get_trained_network_and_ishape,
execute_parent,
get_topk,
)
from finn.transformation.fpgadataflow.annotate_resources import AnnotateResources
from finn.transformation.infer_data_layouts import InferDataLayouts
......@@ -194,11 +195,13 @@ def get_folding_function(topology, wbits, abits):
raise Exception("Unknown topology/quantization combo for predefined folding")
def get_golden_io_pair(topology, wbits, abits):
def get_golden_io_pair(topology, wbits, abits, return_topk=None):
(model, ishape) = get_trained_network_and_ishape(topology, wbits, abits)
input_tensor_npy = get_example_input(topology)
input_tensor_torch = torch.from_numpy(input_tensor_npy).float()
output_tensor_npy = model.forward(input_tensor_torch).detach().numpy()
if return_topk is not None:
output_tensor_npy = get_topk(output_tensor_npy, k=return_topk)
return (input_tensor_npy, output_tensor_npy)
......@@ -323,7 +326,7 @@ class TestEnd2End:
model.save(cppsim_chkpt)
parent_chkpt = get_checkpoint_name(topology, wbits, abits, "dataflow_parent")
(input_tensor_npy, output_tensor_npy) = get_golden_io_pair(
topology, wbits, abits
topology, wbits, abits, return_topk=1
)
y = execute_parent(parent_chkpt, cppsim_chkpt, input_tensor_npy)
assert np.isclose(y, output_tensor_npy).all()
......@@ -366,7 +369,7 @@ class TestEnd2End:
model.save(rtlsim_chkpt)
parent_chkpt = get_checkpoint_name(topology, wbits, abits, "dataflow_parent")
(input_tensor_npy, output_tensor_npy) = get_golden_io_pair(
topology, wbits, abits
topology, wbits, abits, return_topk=1
)
y = execute_parent(parent_chkpt, rtlsim_chkpt, input_tensor_npy)
model = ModelWrapper(rtlsim_chkpt)
......@@ -438,7 +441,7 @@ class TestEnd2End:
if cfg["ip"] == "":
pytest.skip("PYNQ board IP address not specified")
(input_tensor_npy, output_tensor_npy) = get_golden_io_pair(
topology, wbits, abits
topology, wbits, abits, return_topk=1
)
parent_model = load_test_checkpoint_or_skip(
get_checkpoint_name(topology, wbits, abits, "dataflow_parent")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment