Skip to content
Snippets Groups Projects
Commit 4c5fbc62 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Wrapper] fix param names for wrapper methods

parent ca0071be
No related branches found
No related tags found
No related merge requests found
......@@ -34,14 +34,14 @@ class ModelWrapper:
# TODO check that all constants are initializers
return True
def get_tensor_shape(self):
def get_tensor_shape(self, tensor_name):
"""Returns the shape of tensor with given name, if it has ValueInfoProto."""
graph = self._model_proto.graph
vi_names = [(x.name, x) for x in graph.input]
vi_names += [(x.name, x) for x in graph.output]
vi_names += [(x.name, x) for x in graph.value_info]
try:
vi_ind = [x[0] for x in vi_names].index(self)
vi_ind = [x[0] for x in vi_names].index(tensor_name)
vi = vi_names[vi_ind][1]
dims = [x.dim_value for x in vi.type.tensor_type.shape.dim]
return dims
......@@ -57,7 +57,7 @@ class ModelWrapper:
# first, remove if an initializer already exists
init_names = [x.name for x in graph.initializer]
try:
init_ind = init_names.index(self)
init_ind = init_names.index(tensor_name)
init_old = graph.initializer[init_ind]
graph.initializer.remove(init_old)
except ValueError:
......@@ -65,32 +65,32 @@ class ModelWrapper:
# create and insert new initializer
graph.initializer.append(tensor_init_proto)
def get_initializer(self):
def get_initializer(self, tensor_name):
"""Get the initializer value for tensor with given name, if any."""
graph = self._model_proto.graph
init_names = [x.name for x in graph.initializer]
try:
init_ind = init_names.index(self)
init_ind = init_names.index(tensor_name)
return np_helper.to_array(graph.initializer[init_ind])
except ValueError:
return None
def find_producer(self):
def find_producer(self, tensor_name):
"""Find and return the node that produces the tensor with given name.
Currently only works for linear graphs."""
all_outputs = [x.output[0] for x in self._model_proto.graph.node]
try:
producer_ind = all_outputs.index(self)
producer_ind = all_outputs.index(tensor_name)
return self._model_proto.graph.node[producer_ind]
except ValueError:
return None
def find_consumer(self):
def find_consumer(self, tensor_name):
"""Find and return the node that consumes the tensor with given name.
Currently only works for linear graphs."""
all_inputs = [x.input[0] for x in self._model_proto.graph.node]
try:
consumer_ind = all_inputs.index(self)
consumer_ind = all_inputs.index(tensor_name)
return self._model_proto.graph.node[consumer_ind]
except ValueError:
return None
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment