From 4a7551d061c81b1afae54bbedb466f46d93b1ce5 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <yamanu@xilinx.com>
Date: Fri, 23 Aug 2019 16:16:17 +0000
Subject: [PATCH] traverse ONNX graph, create buffers, list op types

---
 src/finn/onnx_exec.py | 116 ++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 116 insertions(+)
 create mode 100644 src/finn/onnx_exec.py

diff --git a/src/finn/onnx_exec.py b/src/finn/onnx_exec.py
new file mode 100644
index 000000000..b7b98caea
--- /dev/null
+++ b/src/finn/onnx_exec.py
+@@ -0,0 +1,116 @@
+# Copyright (c) 2019, Xilinx
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+#    1. Redistributions of source code must retain the above copyright
+#       notice, this list of conditions and the following disclaimer.
+#    2. Redistributions in binary form must reproduce the above copyright
+#       notice, this list of conditions and the following disclaimer in the
+#       documentation and/or other materials provided with the distribution.
+#    3. Neither the name of Xilinx nor the
+#       names of its contributors may be used to endorse or promote products
+#       derived from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER BE LIABLE FOR ANY
+# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+from functools import reduce
+
+import numpy as np
+import onnx
+import onnx.helper as helper
+from onnx import numpy_helper as np_helper
+
+model = onnx.load_model("model.onnx")
+graph = model.graph
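+
+# exporters do not always populate graph.value_info for intermediate tensors,
+# so run ONNX shape inference up front to fill it in where shapes can be
+# deduced (a minimal addition, assuming the model passes shape inference;
+# infer_shapes is part of the onnx package and returns a new ModelProto)
+model = onnx.shape_inference.infer_shapes(model)
+graph = model.graph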
+
+def valueinfo_to_tensor(vi):
+  """Creates an all-zeroes TensorProto from a ValueInfoProto."""
+
+  dims = [x.dim_value for x in vi.type.tensor_type.shape.dim]
+  n_elems = reduce(lambda x, y: x * y, dims, 1)
+  return helper.make_tensor(
+    name=vi.name,
+    # TODO always float32 for now -- respect vi.type.tensor_type.elem_type?
+    data_type=onnx.TensorProto.FLOAT,
+    dims=dims,
+    vals=np.zeros((n_elems,), dtype=np.float32),
+    raw=False,
+  )
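+
+# quick illustration on a hand-made ValueInfoProto ("example_in" is just a
+# demo name): a float32 tensor of shape (2, 3) becomes a zero-filled
+# TensorProto with dims [2, 3] and 6 elements
+example_vi = helper.make_tensor_value_info("example_in", onnx.TensorProto.FLOAT, [2, 3])
+assert list(valueinfo_to_tensor(example_vi).dims) == [2, 3]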
+
+# First, make sure that every variable required by the graph has a buffer
+# associated with it. This covers the graph inputs (both the input data and
+# the trained parameters) and the graph ValueInfo entries (the intermediate
+# tensors between layers).
+# We keep all buffers in this dict:
+execution_context = dict()
+# make empty tensors for all graph inputs and outputs
+for vi in list(graph.input) + list(graph.output):
+  new_tensor = valueinfo_to_tensor(vi)
+  execution_context[new_tensor.name] = new_tensor
+# make empty tensors for all intermediate buffers; the shape inference pass
+# above ensures .value_info is filled in wherever ONNX can deduce a shape
+for vi in graph.value_info:
+  new_tensor = valueinfo_to_tensor(vi)
+  execution_context[new_tensor.name] = new_tensor
+# fill in the constants provided by the initializers, overwriting the empty
+# buffers made above for any name that also appears as a graph input
+for t in graph.initializer:
+  execution_context[t.name] = t
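+
+# for illustration, peek at one initializer as a numpy array (assumes the
+# graph has at least one initializer); to_array is part of onnx.numpy_helper
+if len(graph.initializer) > 0:
+  example_w = np_helper.to_array(graph.initializer[0])
+  print("First initializer shape: " + str(example_w.shape))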
+
+# now walk over each node in the graph's node list (actual execution of the
+# nodes comes later; for now we only inspect them and record their op types).
+# a single pass down the list suffices, since the ONNX spec guarantees that
+# the nodes are topologically sorted
+all_used_ops = set()
+for node in graph.node:
+  print("Node name: %s Type: %s" % (node.name, node.op_type))
+  all_used_ops.add(node.op_type)
+  print("Input(s): " + str(node.input))
+  print("Output(s): " + str(node.output))
+  print("All inputs in context: ")
+  print(list(map(lambda x: x in execution_context.keys(), node.input)))
+  print("All outputs in context: ")
+  print(list(map(lambda x: x in execution_context.keys(), node.output)))
+
+print("Operators used in this graph: ")
+print(all_used_ops)
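+
+# a possible next step, sketched as comments only since execute_node does not
+# exist yet (the name and interface are hypothetical): dispatch on op_type
+# and write each node's results back into execution_context
+#
+#   for node in graph.node:
+#     inputs = [execution_context[name] for name in node.input]
+#     outputs = execute_node(node, inputs)  # hypothetical per-op dispatcher
+#     for name, tensor in zip(node.output, outputs):
+#       execution_context[name] = tensor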
-- 
GitLab