diff --git a/src/finn/core/onnx_exec.py b/src/finn/core/onnx_exec.py index 278b5ee168afa07b0dcbe053744205b760b92b91..172ba25b223fd087df134add460a42d0a9935e0e 100644 --- a/src/finn/core/onnx_exec.py +++ b/src/finn/core/onnx_exec.py @@ -61,6 +61,10 @@ def execute_node(node, context, graph): # onnxruntime unfortunately does not implement run_node as defined by ONNX, # it can only execute entire models -- so we create a model which solely # consists of our current node. + # note: ensure that the same ValueInfo does not appear both in + # graph.value_info as well as graph.output or graph.input + # nodes with multiple outputs that are a mix of value_info and + # input/outputs may get them reordered below node_inputs = list(filter(lambda x: x.name in node.input, graph.input)) node_inputs += list( filter(lambda x: x.name in node.input, graph.value_info) @@ -84,15 +88,15 @@ def execute_node(node, context, graph): output_list = sess.run(None, input_dict) for output_ind in range(len(node.output)): - #get the name of the target buffer from node.output + # get the name of the target buffer from node.output outp = node.output[output_ind] - #retrieve the index of that name in node_outputs + # retrieve the index of that name in node_outputs for i in range(len(node_outputs)): if outp == node_outputs[i].name: list_ind = i - #use that index to index output_list + # use that index to index output_list if output_list[list_ind].shape != context[outp].shape: raise Exception( """Output shapes disagree after node execution: