Skip to content
Snippets Groups Projects
Commit 2a5793b8 authored by Tobi-Alonso's avatar Tobi-Alonso
Browse files

[TEST] Add test for Modelwrapper fork and join detection helpers

parent 1147ae27
No related branches found
No related tags found
No related merge requests found
......@@ -121,3 +121,45 @@ def test_modelwrapper_graph_order():
assert model.get_node_index(Round_node) == 1
assert model.get_node_index(Ceil_node) == 2
assert model.get_node_index(Add_node) == 3
def test_modelwrapper_detect_forks_n_joins():
# create small network with properties to be tested
Neg_node = onnx.helper.make_node("Neg", inputs=["in1"], outputs=["neg1"])
Round_node = onnx.helper.make_node("Round", inputs=["neg1"], outputs=["round1"])
Ceil_node = onnx.helper.make_node("Ceil", inputs=["neg1"], outputs=["ceil1"])
Add_node = onnx.helper.make_node(
"Add", inputs=["round1", "ceil1"], outputs=["out1"]
)
in1 = onnx.helper.make_tensor_value_info("in1", onnx.TensorProto.FLOAT, [4, 4])
out1 = onnx.helper.make_tensor_value_info("out1", onnx.TensorProto.FLOAT, [4, 4])
graph = onnx.helper.make_graph(
nodes=[Neg_node, Round_node, Ceil_node, Add_node],
name="simple_graph",
inputs=[in1],
outputs=[out1],
value_info=[
onnx.helper.make_tensor_value_info("neg1", onnx.TensorProto.FLOAT, [4, 4]),
onnx.helper.make_tensor_value_info(
"round1", onnx.TensorProto.FLOAT, [4, 4]
),
onnx.helper.make_tensor_value_info("ceil1", onnx.TensorProto.FLOAT, [4, 4]),
],
)
onnx_model = onnx.helper.make_model(graph, producer_name="simple-model")
model = ModelWrapper(onnx_model)
# test
assert model.is_fork_node(Neg_node)
assert not model.is_fork_node(Round_node)
assert not model.is_fork_node(Ceil_node)
assert not model.is_fork_node(Add_node)
assert not model.is_join_node(Neg_node)
assert not model.is_join_node(Round_node)
assert not model.is_join_node(Ceil_node)
assert model.is_join_node(Add_node)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment