Skip to content
Snippets Groups Projects
Unverified Commit 8ca3749f authored by Yaman Umuroglu's avatar Yaman Umuroglu Committed by GitHub
Browse files

Merge pull request #87 from Xilinx/feature/graph_order_util_fct

ModelWrapper helper functions to find consumers/producers/direct predecessors and sucessors
parents 5ab626ed 4864a63c
No related branches found
No related tags found
No related merge requests found
......@@ -288,6 +288,46 @@ class ModelWrapper:
except ValueError:
return None
def find_consumers(self, tensor_name):
"""Finds and returns a list of the nodes that consume tensor with
given name."""
consumers = []
for n in self._model_proto.graph.node:
for inp_tensor in n.input:
if inp_tensor == tensor_name:
consumers.append(n)
if consumers != []:
return consumers
else:
return None
def find_direct_successors(self, node):
"""Finds and returns a list of the nodes that are successors of
given node."""
successors = []
for outp_tensor in node.output:
tensor_consumer_list = self.find_consumers(outp_tensor)
if tensor_consumer_list is not None:
for consumer in tensor_consumer_list:
successors.append(consumer)
if successors != []:
return successors
else:
return None
def find_direct_predecessors(self, node):
"""Finds and returns a list of the nodes that are predecessors of
given node."""
predecessors = []
for inp_tensor in node.input:
producer = self.find_producer(inp_tensor)
if producer is not None:
predecessors.append(producer)
if predecessors != []:
return predecessors
else:
return None
def get_all_tensor_names(self):
"""Returns a list of all (input, output and value_info) tensor names
in the graph."""
......@@ -383,3 +423,14 @@ class ModelWrapper:
def get_non_finn_nodes(self):
"""Returns a list of nodes where domain != 'finn'."""
return list(filter(lambda x: x.domain != "finn", self.graph.node))
def get_node_index(self, node):
"""Returns current index of given node."""
n_ind = 0
try:
for n in self.graph.node:
if n == node:
return n_ind
n_ind += 1
except ValueError:
return None
......@@ -27,8 +27,8 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import os
import onnx
from collections import Counter
import brevitas.onnx as bo
import numpy as np
......@@ -68,3 +68,56 @@ def test_modelwrapper():
out_prod = model.find_producer(l0_inp_tensor_name)
assert out_prod.op_type == "Sign"
os.remove(export_onnx_path)
def test_modelwrapper_graph_order():
# create small network with properties to be tested
Neg_node = onnx.helper.make_node("Neg", inputs=["in1"], outputs=["neg1"],)
Round_node = onnx.helper.make_node("Round", inputs=["neg1"], outputs=["round1"],)
Ceil_node = onnx.helper.make_node("Ceil", inputs=["neg1"], outputs=["ceil1"],)
Add_node = onnx.helper.make_node(
"Add", inputs=["round1", "ceil1"], outputs=["out1"],
)
in1 = onnx.helper.make_tensor_value_info("in1", onnx.TensorProto.FLOAT, [4, 4])
out1 = onnx.helper.make_tensor_value_info("out1", onnx.TensorProto.FLOAT, [4, 4])
graph = onnx.helper.make_graph(
nodes=[Neg_node, Round_node, Ceil_node, Add_node],
name="simple_graph",
inputs=[in1],
outputs=[out1],
value_info=[
onnx.helper.make_tensor_value_info("neg1", onnx.TensorProto.FLOAT, [4, 4]),
onnx.helper.make_tensor_value_info(
"round1", onnx.TensorProto.FLOAT, [4, 4]
),
onnx.helper.make_tensor_value_info("ceil1", onnx.TensorProto.FLOAT, [4, 4]),
],
)
onnx_model = onnx.helper.make_model(graph, producer_name="simple-model")
model = ModelWrapper(onnx_model)
# test graph order functions
assert model.find_consumers("in1") == [Neg_node]
assert model.find_consumers("neg1") == [Round_node, Ceil_node]
assert model.find_consumers("round1") == [Add_node]
assert model.find_consumers("ceil1") == [Add_node]
assert model.find_consumers("out1") is None
assert model.find_direct_successors(Neg_node) == [Round_node, Ceil_node]
assert model.find_direct_successors(Round_node) == [Add_node]
assert model.find_direct_successors(Ceil_node) == [Add_node]
assert model.find_direct_successors(Add_node) is None
assert model.find_direct_predecessors(Neg_node) is None
assert model.find_direct_predecessors(Round_node) == [Neg_node]
assert model.find_direct_predecessors(Ceil_node) == [Neg_node]
assert model.find_direct_predecessors(Add_node) == [Round_node, Ceil_node]
assert model.get_node_index(Neg_node) == 0
assert model.get_node_index(Round_node) == 1
assert model.get_node_index(Ceil_node) == 2
assert model.get_node_index(Add_node) == 3
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment