From 26d64116f7bae64b682cf8ede04f58b276169545 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <yamanu@xilinx.com> Date: Fri, 23 Aug 2019 17:29:16 +0000 Subject: [PATCH] try out some basic node execution ONNX and Caffe2 seem to disagree on the output size? --- src/finn/onnx_exec.py | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/src/finn/onnx_exec.py b/src/finn/onnx_exec.py index 86b3caa6e..65a564482 100644 --- a/src/finn/onnx_exec.py +++ b/src/finn/onnx_exec.py @@ -27,8 +27,12 @@ import onnx import onnx.helper as helper import numpy as np +import caffe2.python.onnx.backend as be from functools import reduce from onnx import numpy_helper as np_helper + + + model = onnx.load_model("model.onnx") graph = model.graph @@ -38,6 +42,20 @@ def valueinfo_to_tensor(vi): dims = [x.dim_value for x in vi.type.tensor_type.shape.dim] return np.zeros(dims, dtype = onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[vi.type.tensor_type.elem_type]) +def execute_node(node, context): + """Call Caffe2 to execute a single node. Input/output provided via context.""" + + input_dict = dict() + for inp in node.input: + input_dict[inp] = context[inp] + print("Input shape for %s: %s" % (inp, context[inp].shape)) + output_dict = be.run_node(node, input_dict) + for outp in node.output: + if output_dict[outp].shape != context[outp].shape: + raise Exception("Output shapes disagree after node execution: found %s vs expected %s" % (str(output_dict[outp].shape), str(context[outp].shape))) + context[outp] = output_dict[outp] + print("Output shape for %s: %s" % (outp, context[outp].shape)) + # first, we need to make sure that every variable required by the graph has # some buffer associated with it. this includes graph inputs (which includes # the input data as well as the trained parameters) and the graph ValueInfo @@ -70,10 +88,8 @@ for node in graph.node: all_used_ops.add(node.op_type) print("Input(s): " + str(node.input)) print("Output(s): " + str(node.output)) - print("All inputs in context: ") - print(list(map(lambda x: x in execution_context.keys(), node.input))) - print("All outputs in context: ") - print(list(map(lambda x: x in execution_context.keys(), node.output))) + print("Attribute(s): " + str(node.attribute)) + execute_node(node, execution_context) -print("Operators used in this graph: ") -print(all_used_ops) +print("Final output(s): ") +print(execution_context[graph.output[0].name]) -- GitLab