Skip to content
Snippets Groups Projects
Commit dd11cf2d authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[ModelWrapper] add find_upstream helper function

parent 28474e96
No related branches found
No related tags found
No related merge requests found
......@@ -262,6 +262,24 @@ class ModelWrapper:
except ValueError:
return None
def find_upstream(self, tensor_name, finder_fxn):
"""Follow the producer chain upstream, calling finder_fxn on each upstream
node until it returns True or there are no nodes left. Returns the list
of nodes visited, or None if finder_fxn did not return True."""
visit_list = []
current_tensor = tensor_name
while True:
current_producer = self.find_producer(current_tensor)
if current_producer is None:
return []
else:
found = finder_fxn(current_producer)
visit_list.append(current_producer)
if found:
return visit_list
else:
current_tensor = current_producer.input[0]
def find_consumer(self, tensor_name):
"""Finds and returns the node that consumes the tensor with given name.
Currently only works for linear graphs."""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment