From 4c5fbc626d81afc850469b98ed889dc350d254e6 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <yamanu@xilinx.com>
Date: Mon, 21 Oct 2019 14:35:06 +0100
Subject: [PATCH] [Wrapper] fix param names for wrapper methods

---
 src/finn/core/modelwrapper.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/src/finn/core/modelwrapper.py b/src/finn/core/modelwrapper.py
index 596c78e34..1ecd46da2 100644
--- a/src/finn/core/modelwrapper.py
+++ b/src/finn/core/modelwrapper.py
@@ -34,14 +34,14 @@ class ModelWrapper:
         # TODO check that all constants are initializers
         return True
 
-    def get_tensor_shape(self):
+    def get_tensor_shape(self, tensor_name):
         """Returns the shape of tensor with given name, if it has ValueInfoProto."""
         graph = self._model_proto.graph
         vi_names = [(x.name, x) for x in graph.input]
         vi_names += [(x.name, x) for x in graph.output]
         vi_names += [(x.name, x) for x in graph.value_info]
         try:
-            vi_ind = [x[0] for x in vi_names].index(self)
+            vi_ind = [x[0] for x in vi_names].index(tensor_name)
             vi = vi_names[vi_ind][1]
             dims = [x.dim_value for x in vi.type.tensor_type.shape.dim]
             return dims
@@ -57,7 +57,7 @@ class ModelWrapper:
         # first, remove if an initializer already exists
         init_names = [x.name for x in graph.initializer]
         try:
-            init_ind = init_names.index(self)
+            init_ind = init_names.index(tensor_name)
             init_old = graph.initializer[init_ind]
             graph.initializer.remove(init_old)
         except ValueError:
@@ -65,32 +65,32 @@ class ModelWrapper:
         # create and insert new initializer
         graph.initializer.append(tensor_init_proto)
 
-    def get_initializer(self):
+    def get_initializer(self, tensor_name):
         """Get the initializer value for tensor with given name, if any."""
         graph = self._model_proto.graph
         init_names = [x.name for x in graph.initializer]
         try:
-            init_ind = init_names.index(self)
+            init_ind = init_names.index(tensor_name)
             return np_helper.to_array(graph.initializer[init_ind])
         except ValueError:
             return None
 
-    def find_producer(self):
+    def find_producer(self, tensor_name):
         """Find and return the node that produces the tensor with given name.
         Currently only works for linear graphs."""
         all_outputs = [x.output[0] for x in self._model_proto.graph.node]
         try:
-            producer_ind = all_outputs.index(self)
+            producer_ind = all_outputs.index(tensor_name)
             return self._model_proto.graph.node[producer_ind]
         except ValueError:
             return None
 
-    def find_consumer(self):
+    def find_consumer(self, tensor_name):
         """Find and return the node that consumes the tensor with given name.
         Currently only works for linear graphs."""
         all_inputs = [x.input[0] for x in self._model_proto.graph.node]
         try:
-            consumer_ind = all_inputs.index(self)
+            consumer_ind = all_inputs.index(tensor_name)
             return self._model_proto.graph.node[consumer_ind]
         except ValueError:
             return None
-- 
GitLab