diff --git a/tests/core/test_modelwrapper.py b/tests/core/test_modelwrapper.py index 839710681640deca01aa40d3ab420016f0e48165..da2b403d977f178ef9b73c758ba93e1e22f40041 100644 --- a/tests/core/test_modelwrapper.py +++ b/tests/core/test_modelwrapper.py @@ -121,3 +121,45 @@ def test_modelwrapper_graph_order(): assert model.get_node_index(Round_node) == 1 assert model.get_node_index(Ceil_node) == 2 assert model.get_node_index(Add_node) == 3 + + +def test_modelwrapper_detect_forks_n_joins(): + # create small network with properties to be tested + Neg_node = onnx.helper.make_node("Neg", inputs=["in1"], outputs=["neg1"]) + Round_node = onnx.helper.make_node("Round", inputs=["neg1"], outputs=["round1"]) + + Ceil_node = onnx.helper.make_node("Ceil", inputs=["neg1"], outputs=["ceil1"]) + Add_node = onnx.helper.make_node( + "Add", inputs=["round1", "ceil1"], outputs=["out1"] + ) + + in1 = onnx.helper.make_tensor_value_info("in1", onnx.TensorProto.FLOAT, [4, 4]) + out1 = onnx.helper.make_tensor_value_info("out1", onnx.TensorProto.FLOAT, [4, 4]) + + graph = onnx.helper.make_graph( + nodes=[Neg_node, Round_node, Ceil_node, Add_node], + name="simple_graph", + inputs=[in1], + outputs=[out1], + value_info=[ + onnx.helper.make_tensor_value_info("neg1", onnx.TensorProto.FLOAT, [4, 4]), + onnx.helper.make_tensor_value_info( + "round1", onnx.TensorProto.FLOAT, [4, 4] + ), + onnx.helper.make_tensor_value_info("ceil1", onnx.TensorProto.FLOAT, [4, 4]), + ], + ) + + onnx_model = onnx.helper.make_model(graph, producer_name="simple-model") + model = ModelWrapper(onnx_model) + + # test + assert model.is_fork_node(Neg_node) + assert not model.is_fork_node(Round_node) + assert not model.is_fork_node(Ceil_node) + assert not model.is_fork_node(Add_node) + + assert not model.is_join_node(Neg_node) + assert not model.is_join_node(Round_node) + assert not model.is_join_node(Ceil_node) + assert model.is_join_node(Add_node)