diff --git a/tests/transformation/test_topk_insert.py b/tests/transformation/test_topk_insert.py index ded98f32aed636191e5cfbec4362a4e55a9c313a..ac32c30edbbf466b2b441bcc92975a7d50f42bda 100644 --- a/tests/transformation/test_topk_insert.py +++ b/tests/transformation/test_topk_insert.py @@ -19,39 +19,40 @@ import pytest export_onnx_path = "test_output_lfc.onnx" -@pytest.mark.parametrize("k", [1,5,10]) + +@pytest.mark.parametrize("k", [1, 5, 10]) def test_topk_insert(k): tfc = get_test_model_trained("TFC", 1, 1) bo.export_finn_onnx(tfc, (1, 1, 28, 28), export_onnx_path) model = ModelWrapper(export_onnx_path) - #do transformations (no topk) + # do transformations (no topk) model = model.transform(InferShapes()) model = model.transform(FoldConstants()) model = model.transform(GiveUniqueNodeNames()) model = model.transform(GiveReadableTensorNames()) model = model.transform(InferDataTypes()) - #verification: generate random input, run through net, streamline, run again, check that output is top-k + # verification: generate random input, run through net, streamline, + # run again, check that output is top-k raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb") input_tensor = onnx.load_tensor_from_string(raw_i) input_brevitas = torch.from_numpy(nph.to_array(input_tensor)).float() output_golden = tfc.forward(input_brevitas).detach().numpy() output_golden_topk = np.flip(output_golden.flatten().argsort())[:k] + output_golden_topk = output_golden_topk.flatten() input_dict = {"global_in": nph.to_array(input_tensor)} - output_dict = oxe.execute_onnx(model, input_dict) - output_pysim = output_dict[list(output_dict.keys())[0]] - #insert top-k + # insert top-k model = model.transform(InsertTopK(k)) model = model.transform(GiveUniqueNodeNames()) model = model.transform(GiveReadableTensorNames()) model = model.transform(InferShapes()) - #verify output of top-k + # verify output of top-k output_dict_topk = oxe.execute_onnx(model, input_dict) output_pysim_topk = output_dict_topk[list(output_dict_topk.keys())[0]] + output_pysim_topk = output_pysim_topk.astype(np.int).flatten() - assert np.array_equal(output_golden_topk.flatten(), output_pysim_topk.astype(np.int).flatten()) - + assert np.array_equal(output_golden_topk, output_pysim_topk)