Commit 4c5fbc62 authored by Yaman Umuroglu

[Wrapper] fix param names for wrapper methods

parent ca0071be
@@ -34,14 +34,14 @@ class ModelWrapper:
         # TODO check that all constants are initializers
         return True
 
-    def get_tensor_shape(self):
+    def get_tensor_shape(self, tensor_name):
         """Returns the shape of tensor with given name, if it has ValueInfoProto."""
         graph = self._model_proto.graph
         vi_names = [(x.name, x) for x in graph.input]
         vi_names += [(x.name, x) for x in graph.output]
         vi_names += [(x.name, x) for x in graph.value_info]
         try:
-            vi_ind = [x[0] for x in vi_names].index(self)
+            vi_ind = [x[0] for x in vi_names].index(tensor_name)
             vi = vi_names[vi_ind][1]
             dims = [x.dim_value for x in vi.type.tensor_type.shape.dim]
             return dims
@@ -57,7 +57,7 @@ class ModelWrapper:
         # first, remove if an initializer already exists
         init_names = [x.name for x in graph.initializer]
         try:
-            init_ind = init_names.index(self)
+            init_ind = init_names.index(tensor_name)
             init_old = graph.initializer[init_ind]
             graph.initializer.remove(init_old)
         except ValueError:
@@ -65,32 +65,32 @@ class ModelWrapper:
         # create and insert new initializer
         graph.initializer.append(tensor_init_proto)
 
-    def get_initializer(self):
+    def get_initializer(self, tensor_name):
         """Get the initializer value for tensor with given name, if any."""
         graph = self._model_proto.graph
         init_names = [x.name for x in graph.initializer]
         try:
-            init_ind = init_names.index(self)
+            init_ind = init_names.index(tensor_name)
             return np_helper.to_array(graph.initializer[init_ind])
         except ValueError:
             return None
 
-    def find_producer(self):
+    def find_producer(self, tensor_name):
         """Find and return the node that produces the tensor with given name.
         Currently only works for linear graphs."""
         all_outputs = [x.output[0] for x in self._model_proto.graph.node]
         try:
-            producer_ind = all_outputs.index(self)
+            producer_ind = all_outputs.index(tensor_name)
             return self._model_proto.graph.node[producer_ind]
         except ValueError:
             return None
 
-    def find_consumer(self):
+    def find_consumer(self, tensor_name):
         """Find and return the node that consumes the tensor with given name.
         Currently only works for linear graphs."""
         all_inputs = [x.input[0] for x in self._model_proto.graph.node]
         try:
-            consumer_ind = all_inputs.index(self)
+            consumer_ind = all_inputs.index(tensor_name)
             return self._model_proto.graph.node[consumer_ind]
         except ValueError:
             return None
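Before this commit, each of the four methods above took no tensor_name parameter yet called .index(self) on a list of name strings; the wrapper object can never equal a string, so every lookup raised ValueError and fell through to the fallback branch. With the parameter added, the tensor name is passed explicitly. Below is a minimal usage sketch of the fixed API, not taken from the repo: the import path, model file name, and tensor names are hypothetical, and the constructor is assumed to accept an onnx ModelProto; only ModelWrapper and the four fixed methods come from this diff.

import onnx
from finn.core.modelwrapper import ModelWrapper  # hypothetical import path

# Assumes the constructor wraps an onnx.ModelProto; file name is hypothetical.
model = ModelWrapper(onnx.load("model.onnx"))

# Shape lookup now takes the tensor name as an argument.
print(model.get_tensor_shape("inp"))  # e.g. [1, 64] if "inp" has a ValueInfoProto

# Returns a numpy array if "weights" has an initializer, else None.
w = model.get_initializer("weights")

# Producer/consumer search checks only each node's first output/input,
# hence "currently only works for linear graphs".
producer = model.find_producer("act0")  # NodeProto or None
consumer = model.find_consumer("act0")  # NodeProto or None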