diff --git a/src/finn/analysis/topology.py b/src/finn/analysis/topology.py index c825a221ec178ee89b4e3747c982e59a3005cadd..acdb8ed7fcf41fd041c3601b2ee4fe67b6dc5f19 100644 --- a/src/finn/analysis/topology.py +++ b/src/finn/analysis/topology.py @@ -79,3 +79,26 @@ def node_inputs_in_expected_order(model): if n.op_type != "Add": all_OK = all_OK and (model.get_initializer(n.input[1]) is not None) return {"node_inputs_in_expected_order": all_OK} + + +def nodes_topologically_sorted(model): + """Verifies that graph.node is topologically sorted. This is required by the + ONNX specification. + + Returns {"nodes_topologically_sorted": Bool}.""" + + # get successors of every node and check that + # successor index > current node index + + all_OK = True + for n in model.graph.node: + successors = model.find_direct_successors(n) + if successors is not None: + for successor in successors: + # check the condition by checking the antithesis + index_n = model.get_node_index(n) + index_suc = model.get_node_index(successor) + if index_n > index_suc: + all_OK = False + + return {"nodes_topologically_sorted": all_OK} diff --git a/tests/analysis/test_topology_checks.py b/tests/analysis/test_topology_checks.py index 41fbdb6cac8e81d6b1e3eed54a71d0e1d43c3adc..7f7f800da05e38fefa9350928ab6ddc94acbe2b6 100644 --- a/tests/analysis/test_topology_checks.py +++ b/tests/analysis/test_topology_checks.py @@ -26,11 +26,13 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import os from pkgutil import get_data import onnx.helper as oh from onnx import TensorProto - +import brevitas.onnx as bo +from finn.util.test import get_test_model_trained import finn.analysis.topology as ta from finn.core.modelwrapper import ModelWrapper from finn.transformation.infer_shapes import InferShapes @@ -88,3 +90,116 @@ def test_node_inputs_in_expected_order(): # this model has an (unnecessary) dynamic reshape for its weight tensor # and so it fails the check assert ret["node_inputs_in_expected_order"] is False + + +def test_nodes_topologically_sorted(): + # test analysis pass (nodes_topologically_sorted) with different models + + # test with data/onnx/finn-hls-model/tfc_w1_a1_after_conv_to_hls.onnx + raw_m = get_data( + "finn", "data/onnx/finn-hls-model/tfc_w1_a1_after_conv_to_hls.onnx" + ) + model = ModelWrapper(raw_m) + ret = model.analysis(ta.nodes_topologically_sorted) + assert ret["nodes_topologically_sorted"] is True + + # remove first node and add it at the end + graph = model.graph + first_node = graph.node[0] + graph.node.remove(first_node) + graph.node.append(first_node) + ret = model.analysis(ta.nodes_topologically_sorted) + assert ret["nodes_topologically_sorted"] is False + + # test with data/onnx/mnist-conv/model.onnx + raw_m = get_data("finn", "data/onnx/mnist-conv/model.onnx") + model = ModelWrapper(raw_m) + ret = model.analysis(ta.nodes_topologically_sorted) + assert ret["nodes_topologically_sorted"] is True + + # remove first node and add it at the end + graph = model.graph + first_node = graph.node[0] + graph.node.remove(first_node) + graph.node.append(first_node) + ret = model.analysis(ta.nodes_topologically_sorted) + assert ret["nodes_topologically_sorted"] is False + + # test with manually created small network + Neg_node = oh.make_node("Neg", inputs=["in1"], outputs=["neg1"]) + Round_node = oh.make_node("Round", inputs=["neg1"], outputs=["round1"]) + + Ceil_node = oh.make_node("Ceil", inputs=["neg1"], outputs=["ceil1"]) + Add_node = oh.make_node("Add", inputs=["round1", "ceil1"], outputs=["out1"]) + + in1 = oh.make_tensor_value_info("in1", TensorProto.FLOAT, [4, 4]) + out1 = oh.make_tensor_value_info("out1", TensorProto.FLOAT, [4, 4]) + + graph = oh.make_graph( + nodes=[Neg_node, Round_node, Ceil_node, Add_node], + name="simple_graph", + inputs=[in1], + outputs=[out1], + value_info=[ + oh.make_tensor_value_info("neg1", TensorProto.FLOAT, [4, 4]), + oh.make_tensor_value_info("round1", TensorProto.FLOAT, [4, 4]), + oh.make_tensor_value_info("ceil1", TensorProto.FLOAT, [4, 4]), + ], + ) + + onnx_model = oh.make_model(graph, producer_name="simple-model") + model = ModelWrapper(onnx_model) + + ret = model.analysis(ta.nodes_topologically_sorted) + assert ret["nodes_topologically_sorted"] is True + + # create same graph but with "wrong" node order + graph = oh.make_graph( + nodes=[Round_node, Ceil_node, Neg_node, Add_node], + name="simple_graph", + inputs=[in1], + outputs=[out1], + value_info=[ + oh.make_tensor_value_info("neg1", TensorProto.FLOAT, [4, 4]), + oh.make_tensor_value_info("round1", TensorProto.FLOAT, [4, 4]), + oh.make_tensor_value_info("ceil1", TensorProto.FLOAT, [4, 4]), + ], + ) + + onnx_model = oh.make_model(graph, producer_name="simple-model") + model = ModelWrapper(onnx_model) + + ret = model.analysis(ta.nodes_topologically_sorted) + assert ret["nodes_topologically_sorted"] is False + + # test with data/onnx/finn-hls-model/finn-hls-onnx-model.onnx + raw_m = get_data("finn", "data/onnx/finn-hls-model/finn-hls-onnx-model.onnx") + model = ModelWrapper(raw_m) + ret = model.analysis(ta.nodes_topologically_sorted) + assert ret["nodes_topologically_sorted"] is True + + # remove first node and add it at the end + graph = model.graph + first_node = graph.node[0] + graph.node.remove(first_node) + graph.node.append(first_node) + ret = model.analysis(ta.nodes_topologically_sorted) + assert ret["nodes_topologically_sorted"] is False + + # test with cnv_w1a1 + build_dir = "/tmp/" + os.environ["FINN_INST_NAME"] + cnv = get_test_model_trained("CNV", 1, 1) + bo.export_finn_onnx( + cnv, (1, 3, 32, 32), build_dir + "/end2end_cnv_w1a1_export.onnx" + ) + model = ModelWrapper(build_dir + "/end2end_cnv_w1a1_export.onnx") + ret = model.analysis(ta.nodes_topologically_sorted) + assert ret["nodes_topologically_sorted"] is True + + # remove first node and add it at the end + graph = model.graph + first_node = graph.node[0] + graph.node.remove(first_node) + graph.node.append(first_node) + ret = model.analysis(ta.nodes_topologically_sorted) + assert ret["nodes_topologically_sorted"] is False