Skip to content
Snippets Groups Projects
Commit c81a7b55 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Test] add test for subgraph execution

parent c5ce46c4
No related branches found
No related tags found
No related merge requests found
......@@ -49,19 +49,33 @@ def test_mnist_onnx_download_extract_run():
raw_o = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/output_0.pb")
input_tensor = onnx.load_tensor_from_string(raw_i)
output_tensor = onnx.load_tensor_from_string(raw_o)
# run using FINN-based execution
# run using FINN-based execution (full graph)
input_dict = {"Input3": np_helper.to_array(input_tensor)}
output_dict = oxe.execute_onnx(model, input_dict)
output_dict = oxe.execute_onnx(model, input_dict, return_full_exec_context=True)
assert np.isclose(
np_helper.to_array(output_tensor), output_dict["Plus214_Output_0"], atol=1e-3
).all()
# test subgraph execution
start_node = model.graph.node[1]
end_node = model.graph.node[3]
subgraph_i_dict = {start_node.input[0]: output_dict[start_node.input[0]]}
subgraph_o_dict = oxe.execute_onnx(
model,
subgraph_i_dict,
return_full_exec_context=True,
start_node=start_node,
end_node=end_node,
)
assert np.isclose(
subgraph_o_dict[end_node.output[0]], output_dict[end_node.output[0]], atol=1e-3
).all()
def test_onnx_exec_internal_rounding():
inp0 = onnx.helper.make_tensor_value_info("inp0", onnx.TensorProto.FLOAT, [2, 2])
inp1 = onnx.helper.make_tensor_value_info("inp1", onnx.TensorProto.FLOAT, [1])
outp = onnx.helper.make_tensor_value_info("outp", onnx.TensorProto.FLOAT, [2, 2])
mul_node = onnx.helper.make_node("Mul", inputs=["inp0", "inp1"], outputs=["outp"],)
mul_node = onnx.helper.make_node("Mul", inputs=["inp0", "inp1"], outputs=["outp"])
graph = onnx.helper.make_graph(
nodes=[mul_node], name="mul_graph", inputs=[inp0, inp1], outputs=[outp]
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment