diff --git a/docker/finn_entrypoint.sh b/docker/finn_entrypoint.sh
index 035bba3b53d85a8457eff1e7c1a23e0efff60caa..7ba0aeabbd9bd83c7f33e54cfde626070dcc5ec5 100644
--- a/docker/finn_entrypoint.sh
+++ b/docker/finn_entrypoint.sh
@@ -17,7 +17,7 @@ BREVITAS_COMMIT=215cf44c76d562339fca368c8c3afee3110033e8
 BREVITAS_EXAMPLES_COMMIT=2059f96bd576bf71f32c757e7f92617a70190c90
 CNPY_COMMIT=4e8810b1a8637695171ed346ce68f6984e585ef4
 HLSLIB_COMMIT=6b88db826bb023937506913a23d964775a7606af
-PYVERILATOR_COMMIT=fb1afefa5b207acf6fec28f8abb72a862f2ca1d2
+PYVERILATOR_COMMIT=1d89cb0d4e0c97469cc6352c611f876ec13edfa6
 PYNQSHELL_COMMIT=0c82a61b0ec1a07fa275a14146233824ded7a13d
 
 
diff --git a/src/finn/analysis/topology.py b/src/finn/analysis/topology.py
index c825a221ec178ee89b4e3747c982e59a3005cadd..acdb8ed7fcf41fd041c3601b2ee4fe67b6dc5f19 100644
--- a/src/finn/analysis/topology.py
+++ b/src/finn/analysis/topology.py
@@ -79,3 +79,26 @@ def node_inputs_in_expected_order(model):
         if n.op_type != "Add":
             all_OK = all_OK and (model.get_initializer(n.input[1]) is not None)
     return {"node_inputs_in_expected_order": all_OK}
+
+
+def nodes_topologically_sorted(model):
+    """Verifies that graph.node is topologically sorted. This is required by the
+    ONNX specification.
+
+    Returns {"nodes_topologically_sorted": Bool}."""
+
+    # get successors of every node and check that
+    # successor index > current node index
+
+    all_OK = True
+    for n in model.graph.node:
+        successors = model.find_direct_successors(n)
+        if successors is not None:
+            for successor in successors:
+                # check the condition by checking the antithesis
+                index_n = model.get_node_index(n)
+                index_suc = model.get_node_index(successor)
+                if index_n > index_suc:
+                    all_OK = False
+
+    return {"nodes_topologically_sorted": all_OK}
diff --git a/src/finn/core/modelwrapper.py b/src/finn/core/modelwrapper.py
index e99a6ef4cd40d6323d77354d3c9b4be341d7649c..dc5b36920a5639933463d682dc66fb8bc15b35f2 100644
--- a/src/finn/core/modelwrapper.py
+++ b/src/finn/core/modelwrapper.py
@@ -288,6 +288,46 @@ class ModelWrapper:
         except ValueError:
             return None
 
+    def find_consumers(self, tensor_name):
+        """Finds and returns a list of the nodes that consume tensor with
+        given name."""
+        consumers = []
+        for n in self._model_proto.graph.node:
+            for inp_tensor in n.input:
+                if inp_tensor == tensor_name:
+                    consumers.append(n)
+        if consumers != []:
+            return consumers
+        else:
+            return None
+
+    def find_direct_successors(self, node):
+        """Finds and returns a list of the nodes that are successors of
+        given node."""
+        successors = []
+        for outp_tensor in node.output:
+            tensor_consumer_list = self.find_consumers(outp_tensor)
+            if tensor_consumer_list is not None:
+                for consumer in tensor_consumer_list:
+                    successors.append(consumer)
+        if successors != []:
+            return successors
+        else:
+            return None
+
+    def find_direct_predecessors(self, node):
+        """Finds and returns a list of the nodes that are predecessors of
+        given node."""
+        predecessors = []
+        for inp_tensor in node.input:
+            producer = self.find_producer(inp_tensor)
+            if producer is not None:
+                predecessors.append(producer)
+        if predecessors != []:
+            return predecessors
+        else:
+            return None
+
     def get_all_tensor_names(self):
         """Returns a list of all (input, output and value_info) tensor names
         in the graph."""
@@ -383,3 +423,14 @@ class ModelWrapper:
     def get_non_finn_nodes(self):
         """Returns a list of nodes where domain != 'finn'."""
         return list(filter(lambda x: x.domain != "finn", self.graph.node))
+
+    def get_node_index(self, node):
+        """Returns current index of given node."""
+        n_ind = 0
+        try:
+            for n in self.graph.node:
+                if n == node:
+                    return n_ind
+                n_ind += 1
+        except ValueError:
+            return None
diff --git a/src/finn/core/onnx_exec.py b/src/finn/core/onnx_exec.py
index 172ba25b223fd087df134add460a42d0a9935e0e..44787e1d26049e6075e2222316b45ab3898acbc7 100644
--- a/src/finn/core/onnx_exec.py
+++ b/src/finn/core/onnx_exec.py
@@ -38,6 +38,7 @@ from finn.core.modelwrapper import ModelWrapper
 from finn.core.remote_exec import remote_exec
 from finn.core.rtlsim_exec import rtlsim_exec
 from finn.custom_op.registry import getCustomOp
+import finn.analysis.topology as ta
 
 
 def execute_node(node, context, graph):
@@ -121,6 +122,11 @@ def execute_onnx(model, input_dict, return_full_exec_context=False):
 
     if not model.check_all_tensor_shapes_specified():
         raise Exception("Found unspecified tensor shapes, try infer_shapes")
+    ret = model.analysis(ta.nodes_topologically_sorted)
+    assert (
+        ret["nodes_topologically_sorted"] is True
+    ), """Nodes must be
+    topologically sorted."""
 
     graph = model.graph
     # first, we need to make sure that every variable required by the graph has
diff --git a/tests/analysis/test_topology_checks.py b/tests/analysis/test_topology_checks.py
index 41fbdb6cac8e81d6b1e3eed54a71d0e1d43c3adc..7f7f800da05e38fefa9350928ab6ddc94acbe2b6 100644
--- a/tests/analysis/test_topology_checks.py
+++ b/tests/analysis/test_topology_checks.py
@@ -26,11 +26,13 @@
 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
+import os
 from pkgutil import get_data
 
 import onnx.helper as oh
 from onnx import TensorProto
-
+import brevitas.onnx as bo
+from finn.util.test import get_test_model_trained
 import finn.analysis.topology as ta
 from finn.core.modelwrapper import ModelWrapper
 from finn.transformation.infer_shapes import InferShapes
@@ -88,3 +90,116 @@ def test_node_inputs_in_expected_order():
     # this model has an (unnecessary) dynamic reshape for its weight tensor
     # and so it fails the check
     assert ret["node_inputs_in_expected_order"] is False
+
+
+def test_nodes_topologically_sorted():
+    # test analysis pass (nodes_topologically_sorted) with different models
+
+    # test with data/onnx/finn-hls-model/tfc_w1_a1_after_conv_to_hls.onnx
+    raw_m = get_data(
+        "finn", "data/onnx/finn-hls-model/tfc_w1_a1_after_conv_to_hls.onnx"
+    )
+    model = ModelWrapper(raw_m)
+    ret = model.analysis(ta.nodes_topologically_sorted)
+    assert ret["nodes_topologically_sorted"] is True
+
+    # remove first node and add it at the end
+    graph = model.graph
+    first_node = graph.node[0]
+    graph.node.remove(first_node)
+    graph.node.append(first_node)
+    ret = model.analysis(ta.nodes_topologically_sorted)
+    assert ret["nodes_topologically_sorted"] is False
+
+    # test with data/onnx/mnist-conv/model.onnx
+    raw_m = get_data("finn", "data/onnx/mnist-conv/model.onnx")
+    model = ModelWrapper(raw_m)
+    ret = model.analysis(ta.nodes_topologically_sorted)
+    assert ret["nodes_topologically_sorted"] is True
+
+    # remove first node and add it at the end
+    graph = model.graph
+    first_node = graph.node[0]
+    graph.node.remove(first_node)
+    graph.node.append(first_node)
+    ret = model.analysis(ta.nodes_topologically_sorted)
+    assert ret["nodes_topologically_sorted"] is False
+
+    # test with manually created small network
+    Neg_node = oh.make_node("Neg", inputs=["in1"], outputs=["neg1"])
+    Round_node = oh.make_node("Round", inputs=["neg1"], outputs=["round1"])
+
+    Ceil_node = oh.make_node("Ceil", inputs=["neg1"], outputs=["ceil1"])
+    Add_node = oh.make_node("Add", inputs=["round1", "ceil1"], outputs=["out1"])
+
+    in1 = oh.make_tensor_value_info("in1", TensorProto.FLOAT, [4, 4])
+    out1 = oh.make_tensor_value_info("out1", TensorProto.FLOAT, [4, 4])
+
+    graph = oh.make_graph(
+        nodes=[Neg_node, Round_node, Ceil_node, Add_node],
+        name="simple_graph",
+        inputs=[in1],
+        outputs=[out1],
+        value_info=[
+            oh.make_tensor_value_info("neg1", TensorProto.FLOAT, [4, 4]),
+            oh.make_tensor_value_info("round1", TensorProto.FLOAT, [4, 4]),
+            oh.make_tensor_value_info("ceil1", TensorProto.FLOAT, [4, 4]),
+        ],
+    )
+
+    onnx_model = oh.make_model(graph, producer_name="simple-model")
+    model = ModelWrapper(onnx_model)
+
+    ret = model.analysis(ta.nodes_topologically_sorted)
+    assert ret["nodes_topologically_sorted"] is True
+
+    # create same graph but with "wrong" node order
+    graph = oh.make_graph(
+        nodes=[Round_node, Ceil_node, Neg_node, Add_node],
+        name="simple_graph",
+        inputs=[in1],
+        outputs=[out1],
+        value_info=[
+            oh.make_tensor_value_info("neg1", TensorProto.FLOAT, [4, 4]),
+            oh.make_tensor_value_info("round1", TensorProto.FLOAT, [4, 4]),
+            oh.make_tensor_value_info("ceil1", TensorProto.FLOAT, [4, 4]),
+        ],
+    )
+
+    onnx_model = oh.make_model(graph, producer_name="simple-model")
+    model = ModelWrapper(onnx_model)
+
+    ret = model.analysis(ta.nodes_topologically_sorted)
+    assert ret["nodes_topologically_sorted"] is False
+
+    # test with data/onnx/finn-hls-model/finn-hls-onnx-model.onnx
+    raw_m = get_data("finn", "data/onnx/finn-hls-model/finn-hls-onnx-model.onnx")
+    model = ModelWrapper(raw_m)
+    ret = model.analysis(ta.nodes_topologically_sorted)
+    assert ret["nodes_topologically_sorted"] is True
+
+    # remove first node and add it at the end
+    graph = model.graph
+    first_node = graph.node[0]
+    graph.node.remove(first_node)
+    graph.node.append(first_node)
+    ret = model.analysis(ta.nodes_topologically_sorted)
+    assert ret["nodes_topologically_sorted"] is False
+
+    # test with cnv_w1a1
+    build_dir = "/tmp/" + os.environ["FINN_INST_NAME"]
+    cnv = get_test_model_trained("CNV", 1, 1)
+    bo.export_finn_onnx(
+        cnv, (1, 3, 32, 32), build_dir + "/end2end_cnv_w1a1_export.onnx"
+    )
+    model = ModelWrapper(build_dir + "/end2end_cnv_w1a1_export.onnx")
+    ret = model.analysis(ta.nodes_topologically_sorted)
+    assert ret["nodes_topologically_sorted"] is True
+
+    # remove first node and add it at the end
+    graph = model.graph
+    first_node = graph.node[0]
+    graph.node.remove(first_node)
+    graph.node.append(first_node)
+    ret = model.analysis(ta.nodes_topologically_sorted)
+    assert ret["nodes_topologically_sorted"] is False
diff --git a/tests/brevitas/test_brevitas_act_export.py b/tests/brevitas/test_brevitas_act_export.py
index 08c4a99151d1105ad4258a8d7d6c19cc72da7a99..0415d70bfe6a543b4547f07dd99ff2525bef5994 100644
--- a/tests/brevitas/test_brevitas_act_export.py
+++ b/tests/brevitas/test_brevitas_act_export.py
@@ -1,3 +1,4 @@
+import os
 import numpy as np
 import torch
 import brevitas.onnx as bo
@@ -41,3 +42,4 @@ def test_brevitas_act_export(abits, narrow_range, max_val):
     inp_tensor = torch.from_numpy(inp_tensor).float()
     expected = b_act.forward(inp_tensor).detach().numpy()
     assert np.isclose(produced, expected, atol=1e-3).all()
+    os.remove(export_onnx_path)
diff --git a/tests/core/test_modelwrapper.py b/tests/core/test_modelwrapper.py
index 942eda19ca4c2cdbded9f906a5e7772f50acbd6e..4d2029093f09a705eb562a7f706f21b64172e435 100644
--- a/tests/core/test_modelwrapper.py
+++ b/tests/core/test_modelwrapper.py
@@ -27,8 +27,8 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 import os
+import onnx
 from collections import Counter
-
 import brevitas.onnx as bo
 import numpy as np
 
@@ -68,3 +68,56 @@ def test_modelwrapper():
     out_prod = model.find_producer(l0_inp_tensor_name)
     assert out_prod.op_type == "Sign"
     os.remove(export_onnx_path)
+
+
+def test_modelwrapper_graph_order():
+    # create small network with properties to be tested
+    Neg_node = onnx.helper.make_node("Neg", inputs=["in1"], outputs=["neg1"],)
+    Round_node = onnx.helper.make_node("Round", inputs=["neg1"], outputs=["round1"],)
+
+    Ceil_node = onnx.helper.make_node("Ceil", inputs=["neg1"], outputs=["ceil1"],)
+    Add_node = onnx.helper.make_node(
+        "Add", inputs=["round1", "ceil1"], outputs=["out1"],
+    )
+
+    in1 = onnx.helper.make_tensor_value_info("in1", onnx.TensorProto.FLOAT, [4, 4])
+    out1 = onnx.helper.make_tensor_value_info("out1", onnx.TensorProto.FLOAT, [4, 4])
+
+    graph = onnx.helper.make_graph(
+        nodes=[Neg_node, Round_node, Ceil_node, Add_node],
+        name="simple_graph",
+        inputs=[in1],
+        outputs=[out1],
+        value_info=[
+            onnx.helper.make_tensor_value_info("neg1", onnx.TensorProto.FLOAT, [4, 4]),
+            onnx.helper.make_tensor_value_info(
+                "round1", onnx.TensorProto.FLOAT, [4, 4]
+            ),
+            onnx.helper.make_tensor_value_info("ceil1", onnx.TensorProto.FLOAT, [4, 4]),
+        ],
+    )
+
+    onnx_model = onnx.helper.make_model(graph, producer_name="simple-model")
+    model = ModelWrapper(onnx_model)
+
+    # test graph order functions
+    assert model.find_consumers("in1") == [Neg_node]
+    assert model.find_consumers("neg1") == [Round_node, Ceil_node]
+    assert model.find_consumers("round1") == [Add_node]
+    assert model.find_consumers("ceil1") == [Add_node]
+    assert model.find_consumers("out1") is None
+
+    assert model.find_direct_successors(Neg_node) == [Round_node, Ceil_node]
+    assert model.find_direct_successors(Round_node) == [Add_node]
+    assert model.find_direct_successors(Ceil_node) == [Add_node]
+    assert model.find_direct_successors(Add_node) is None
+
+    assert model.find_direct_predecessors(Neg_node) is None
+    assert model.find_direct_predecessors(Round_node) == [Neg_node]
+    assert model.find_direct_predecessors(Ceil_node) == [Neg_node]
+    assert model.find_direct_predecessors(Add_node) == [Round_node, Ceil_node]
+
+    assert model.get_node_index(Neg_node) == 0
+    assert model.get_node_index(Round_node) == 1
+    assert model.get_node_index(Ceil_node) == 2
+    assert model.get_node_index(Add_node) == 3
diff --git a/tests/transformation/test_topk_insert.py b/tests/transformation/test_topk_insert.py
index ac32c30edbbf466b2b441bcc92975a7d50f42bda..1af0f255d8fb1af8a6e571518f18d831aa71298b 100644
--- a/tests/transformation/test_topk_insert.py
+++ b/tests/transformation/test_topk_insert.py
@@ -1,3 +1,4 @@
+import os
 import onnx
 from finn.util.test import get_test_model_trained
 import brevitas.onnx as bo
@@ -56,3 +57,4 @@ def test_topk_insert(k):
     output_pysim_topk = output_pysim_topk.astype(np.int).flatten()
 
     assert np.array_equal(output_golden_topk, output_pysim_topk)
+    os.remove(export_onnx_path)