From 4c5fbc626d81afc850469b98ed889dc350d254e6 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <yamanu@xilinx.com> Date: Mon, 21 Oct 2019 14:35:06 +0100 Subject: [PATCH] [Wrapper] fix param names for wrapper methods --- src/finn/core/modelwrapper.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/finn/core/modelwrapper.py b/src/finn/core/modelwrapper.py index 596c78e34..1ecd46da2 100644 --- a/src/finn/core/modelwrapper.py +++ b/src/finn/core/modelwrapper.py @@ -34,14 +34,14 @@ class ModelWrapper: # TODO check that all constants are initializers return True - def get_tensor_shape(self): + def get_tensor_shape(self, tensor_name): """Returns the shape of tensor with given name, if it has ValueInfoProto.""" graph = self._model_proto.graph vi_names = [(x.name, x) for x in graph.input] vi_names += [(x.name, x) for x in graph.output] vi_names += [(x.name, x) for x in graph.value_info] try: - vi_ind = [x[0] for x in vi_names].index(self) + vi_ind = [x[0] for x in vi_names].index(tensor_name) vi = vi_names[vi_ind][1] dims = [x.dim_value for x in vi.type.tensor_type.shape.dim] return dims @@ -57,7 +57,7 @@ class ModelWrapper: # first, remove if an initializer already exists init_names = [x.name for x in graph.initializer] try: - init_ind = init_names.index(self) + init_ind = init_names.index(tensor_name) init_old = graph.initializer[init_ind] graph.initializer.remove(init_old) except ValueError: @@ -65,32 +65,32 @@ class ModelWrapper: # create and insert new initializer graph.initializer.append(tensor_init_proto) - def get_initializer(self): + def get_initializer(self, tensor_name): """Get the initializer value for tensor with given name, if any.""" graph = self._model_proto.graph init_names = [x.name for x in graph.initializer] try: - init_ind = init_names.index(self) + init_ind = init_names.index(tensor_name) return np_helper.to_array(graph.initializer[init_ind]) except ValueError: return None - def find_producer(self): + def find_producer(self, tensor_name): """Find and return the node that produces the tensor with given name. Currently only works for linear graphs.""" all_outputs = [x.output[0] for x in self._model_proto.graph.node] try: - producer_ind = all_outputs.index(self) + producer_ind = all_outputs.index(tensor_name) return self._model_proto.graph.node[producer_ind] except ValueError: return None - def find_consumer(self): + def find_consumer(self, tensor_name): """Find and return the node that consumes the tensor with given name. Currently only works for linear graphs.""" all_inputs = [x.input[0] for x in self._model_proto.graph.node] try: - consumer_ind = all_inputs.index(self) + consumer_ind = all_inputs.index(tensor_name) return self._model_proto.graph.node[consumer_ind] except ValueError: return None -- GitLab