From 26d64116f7bae64b682cf8ede04f58b276169545 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <yamanu@xilinx.com>
Date: Fri, 23 Aug 2019 17:29:16 +0000
Subject: [PATCH] try out some basic node execution; ONNX and Caffe2 seem to
 disagree on the output size?

---
 src/finn/onnx_exec.py | 28 ++++++++++++++++++++++------
 1 file changed, 22 insertions(+), 6 deletions(-)
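
Patch note (sits between --- and the diff, so git am ignores it): a minimal
standalone sketch of the check this patch automates, handy for pinning down
where ONNX and Caffe2 disagree on an output shape. The Add node, the tensor
names and the (2, 3) shapes below are made up for illustration; only run_node
and the name-indexed output access are the same calls used in execute_node.

    import numpy as np
    import onnx.helper as helper
    import caffe2.python.onnx.backend as be

    # build a single Add node by hand (op choice, names and shapes arbitrary)
    node = helper.make_node("Add", inputs=["in0", "in1"], outputs=["out0"])
    inputs = {
        "in0": np.random.rand(2, 3).astype(np.float32),
        "in1": np.random.rand(2, 3).astype(np.float32),
    }
    # run_node returns a namedtuple-like result that can be indexed by output
    # name, just like output_dict in execute_node below
    outputs = be.run_node(node, inputs)
    print("Caffe2 output shape:", outputs["out0"].shape)  # expected: (2, 3)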

diff --git a/src/finn/onnx_exec.py b/src/finn/onnx_exec.py
index 86b3caa6e..65a564482 100644
--- a/src/finn/onnx_exec.py
+++ b/src/finn/onnx_exec.py
@@ -27,8 +27,12 @@
 import onnx
 import onnx.helper as helper
 import numpy as np
+import caffe2.python.onnx.backend as be
 from functools import reduce
 from onnx import numpy_helper as np_helper
+
+
+
 model = onnx.load_model("model.onnx")
 graph = model.graph
 
@@ -38,6 +42,20 @@ def valueinfo_to_tensor(vi):
   dims = [x.dim_value for x in vi.type.tensor_type.shape.dim]
   return np.zeros(dims, dtype = onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[vi.type.tensor_type.elem_type])
 
+def execute_node(node, context):
+  """Call Caffe2 to execute a single node. Input/output provided via context."""
+
+  input_dict = dict()
+  for inp in node.input:
+    input_dict[inp] = context[inp]
+    print("Input shape for %s: %s" % (inp, context[inp].shape))
+  output_dict = be.run_node(node, input_dict)
+  for outp in node.output:
+    if output_dict[outp].shape != context[outp].shape:
+      raise Exception("Output shapes disagree after node execution: found %s vs expected %s" % (str(output_dict[outp].shape), str(context[outp].shape)))
+    context[outp] = output_dict[outp]
+    print("Output shape for %s: %s" % (outp, context[outp].shape))
+
 # first, we need to make sure that every variable required by the graph has
 # some buffer associated with it. this includes graph inputs (which includes
 # the input data as well as the trained parameters) and the graph ValueInfo
@@ -70,10 +88,8 @@ for node in graph.node:
   all_used_ops.add(node.op_type)
   print("Input(s): " + str(node.input))
   print("Output(s): " + str(node.output))
-  print("All inputs in context: ")
-  print(list(map(lambda x: x in execution_context.keys(), node.input)))
-  print("All outputs in context: ")
-  print(list(map(lambda x: x in execution_context.keys(), node.output)))
+  print("Attribute(s): " + str(node.attribute))
+  execute_node(node, execution_context)
 
-print("Operators used in this graph: ")
-print(all_used_ops)
+print("Final output(s): ")
+print(execution_context[graph.output[0].name])
-- 
GitLab