Skip to content
Snippets Groups Projects
Commit 4d820d0d authored by Lucian Petrică's avatar Lucian Petrică
Browse files

Various fixes. Verified for one scenario

parent 73b4930c
No related branches found
No related tags found
No related merge requests found
......@@ -253,14 +253,12 @@ class ModelWrapper:
return None
def find_producer(self, tensor_name):
"""Finds and returns the node that produces the tensor with given name.
Currently only works for linear graphs."""
all_outputs = [x.output[0] for x in self._model_proto.graph.node]
try:
producer_ind = all_outputs.index(tensor_name)
return self._model_proto.graph.node[producer_ind]
except ValueError:
return None
"""Finds and returns the node that produces the tensor with given name."""
ret = None
for x in self._model_proto.graph.node:
if tensor_name in x.output:
ret = x
return ret
def find_upstream(self, tensor_name, finder_fxn):
"""Follow the producer chain upstream, calling finder_fxn on each upstream
......
......@@ -84,17 +84,25 @@ def execute_node(node, context, graph):
output_list = sess.run(None, input_dict)
for output_ind in range(len(node.output)):
#get the name of the target buffer from node.output
outp = node.output[output_ind]
if output_list[output_ind].shape != context[outp].shape:
#retrieve the index of that name in node_outputs
for i in range(len(node_outputs)):
if outp == node_outputs[i].name:
list_ind = i
#use that index to index output_list
if output_list[list_ind].shape != context[outp].shape:
raise Exception(
"""Output shapes disagree after node execution:
found %s vs expected %s"""
% (
str(output_list[output_ind].shape.shape),
str(output_list[list_ind].shape.shape),
str(context[outp].shape),
)
)
context[outp] = output_list[output_ind]
context[outp] = output_list[list_ind]
def execute_onnx(model, input_dict, return_full_exec_context=False):
......
......@@ -58,7 +58,6 @@ class InsertTopK(Transformation):
out_dtype = model.get_tensor_datatype(graph_out_name)
#adjust shape
out_shape[self.axis] = self.k
import pdb; pdb.set_trace()
# make new buffer
k_tensor = oh.make_tensor(name='k_tensor',
data_type=TensorProto.INT64,
......@@ -77,8 +76,6 @@ class InsertTopK(Transformation):
model.set_tensor_datatype(k_value.name, out_dtype)#TODO set to int64
model.graph.value_info.append(topk_values)
model.set_tensor_datatype(topk_values.name, out_dtype)
model.graph.value_info.append(topk_indices)
model.set_tensor_datatype(topk_indices.name, out_dtype)
#create and append topk node
k_node = oh.make_node(
'Constant',
......@@ -91,11 +88,13 @@ class InsertTopK(Transformation):
inputs=[graph_out_name, k_value.name],
outputs=[topk_values.name, topk_indices.name],
axis=self.axis,
largest=self.largest,
sorted=self.sorted
)
model.graph.node.append(k_node)
model.graph.node.append(topk_node)
model.graph.output[0].name = topk_values.name
print(topk_indices.name,topk_values.name)
#replace the existing output definition with topk indices
model.graph.output.insert(0,topk_indices)
model.graph.output.pop(1)
return (model, True)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment