diff --git a/src/finn/core/modelwrapper.py b/src/finn/core/modelwrapper.py index e99a6ef4cd40d6323d77354d3c9b4be341d7649c..dc5b36920a5639933463d682dc66fb8bc15b35f2 100644 --- a/src/finn/core/modelwrapper.py +++ b/src/finn/core/modelwrapper.py @@ -288,6 +288,46 @@ class ModelWrapper: except ValueError: return None + def find_consumers(self, tensor_name): + """Finds and returns a list of the nodes that consume tensor with + given name.""" + consumers = [] + for n in self._model_proto.graph.node: + for inp_tensor in n.input: + if inp_tensor == tensor_name: + consumers.append(n) + if consumers != []: + return consumers + else: + return None + + def find_direct_successors(self, node): + """Finds and returns a list of the nodes that are successors of + given node.""" + successors = [] + for outp_tensor in node.output: + tensor_consumer_list = self.find_consumers(outp_tensor) + if tensor_consumer_list is not None: + for consumer in tensor_consumer_list: + successors.append(consumer) + if successors != []: + return successors + else: + return None + + def find_direct_predecessors(self, node): + """Finds and returns a list of the nodes that are predecessors of + given node.""" + predecessors = [] + for inp_tensor in node.input: + producer = self.find_producer(inp_tensor) + if producer is not None: + predecessors.append(producer) + if predecessors != []: + return predecessors + else: + return None + def get_all_tensor_names(self): """Returns a list of all (input, output and value_info) tensor names in the graph.""" @@ -383,3 +423,14 @@ class ModelWrapper: def get_non_finn_nodes(self): """Returns a list of nodes where domain != 'finn'.""" return list(filter(lambda x: x.domain != "finn", self.graph.node)) + + def get_node_index(self, node): + """Returns current index of given node.""" + n_ind = 0 + try: + for n in self.graph.node: + if n == node: + return n_ind + n_ind += 1 + except ValueError: + return None diff --git a/tests/core/test_modelwrapper.py b/tests/core/test_modelwrapper.py index 942eda19ca4c2cdbded9f906a5e7772f50acbd6e..4d2029093f09a705eb562a7f706f21b64172e435 100644 --- a/tests/core/test_modelwrapper.py +++ b/tests/core/test_modelwrapper.py @@ -27,8 +27,8 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import os +import onnx from collections import Counter - import brevitas.onnx as bo import numpy as np @@ -68,3 +68,56 @@ def test_modelwrapper(): out_prod = model.find_producer(l0_inp_tensor_name) assert out_prod.op_type == "Sign" os.remove(export_onnx_path) + + +def test_modelwrapper_graph_order(): + # create small network with properties to be tested + Neg_node = onnx.helper.make_node("Neg", inputs=["in1"], outputs=["neg1"],) + Round_node = onnx.helper.make_node("Round", inputs=["neg1"], outputs=["round1"],) + + Ceil_node = onnx.helper.make_node("Ceil", inputs=["neg1"], outputs=["ceil1"],) + Add_node = onnx.helper.make_node( + "Add", inputs=["round1", "ceil1"], outputs=["out1"], + ) + + in1 = onnx.helper.make_tensor_value_info("in1", onnx.TensorProto.FLOAT, [4, 4]) + out1 = onnx.helper.make_tensor_value_info("out1", onnx.TensorProto.FLOAT, [4, 4]) + + graph = onnx.helper.make_graph( + nodes=[Neg_node, Round_node, Ceil_node, Add_node], + name="simple_graph", + inputs=[in1], + outputs=[out1], + value_info=[ + onnx.helper.make_tensor_value_info("neg1", onnx.TensorProto.FLOAT, [4, 4]), + onnx.helper.make_tensor_value_info( + "round1", onnx.TensorProto.FLOAT, [4, 4] + ), + onnx.helper.make_tensor_value_info("ceil1", onnx.TensorProto.FLOAT, [4, 4]), + ], + ) + + onnx_model = onnx.helper.make_model(graph, producer_name="simple-model") + model = ModelWrapper(onnx_model) + + # test graph order functions + assert model.find_consumers("in1") == [Neg_node] + assert model.find_consumers("neg1") == [Round_node, Ceil_node] + assert model.find_consumers("round1") == [Add_node] + assert model.find_consumers("ceil1") == [Add_node] + assert model.find_consumers("out1") is None + + assert model.find_direct_successors(Neg_node) == [Round_node, Ceil_node] + assert model.find_direct_successors(Round_node) == [Add_node] + assert model.find_direct_successors(Ceil_node) == [Add_node] + assert model.find_direct_successors(Add_node) is None + + assert model.find_direct_predecessors(Neg_node) is None + assert model.find_direct_predecessors(Round_node) == [Neg_node] + assert model.find_direct_predecessors(Ceil_node) == [Neg_node] + assert model.find_direct_predecessors(Add_node) == [Round_node, Ceil_node] + + assert model.get_node_index(Neg_node) == 0 + assert model.get_node_index(Round_node) == 1 + assert model.get_node_index(Ceil_node) == 2 + assert model.get_node_index(Add_node) == 3