From be7a79a7fb7425cbd8b38c406d2b462a07ddcd1a Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Wed, 16 Sep 2020 22:00:28 +0200
Subject: [PATCH] [Core] pull in exec context from SDP execution if desired

---
 src/finn/core/onnx_exec.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/src/finn/core/onnx_exec.py b/src/finn/core/onnx_exec.py
index 15e2a69cd..85b52c0f3 100644
--- a/src/finn/core/onnx_exec.py
+++ b/src/finn/core/onnx_exec.py
@@ -42,7 +42,7 @@ import finn.analysis.topology as ta
 from finn.util.basic import sanitize_quant_values, get_sanitize_quant_tensors
 
 
-def execute_node(node, context, graph):
+def execute_node(node, context, graph, return_full_exec_context=False):
     """Executes a single node by using onnxruntime, with custom function or
     if dataflow partition by using remote execution or rtlsim.
 
@@ -59,16 +59,21 @@ def execute_node(node, context, graph):
         if old_iname != new_iname:
             inp_ctx[new_iname] = inp_ctx[old_iname]
             del inp_ctx[old_iname]
-        ret = execute_onnx(model, inp_ctx, False)
+        ret = execute_onnx(model, inp_ctx, return_full_exec_context)
         # if the model was in ip-stitched rtlsim mode, may get annotation
         # for numbet of elapsed cycles, save again
         if model.get_metadata_prop("exec_mode") == "rtlsim":
             model.save(sdp_node.get_nodeattr("model"))
         # output may have been renamed in partition
-        assert len(ret) == 1
+        assert len(model.graph.output) == 1
         node_oname = node.output[0]
         model_oname = model.graph.output[0].name
         context[node_oname] = ret[model_oname]
+        # prefix and insert exec context entries
+        if return_full_exec_context:
+            for tname in ret.keys():
+                if tname != model_oname:
+                    context[node.name + "_" + tname] = ret[tname]
     else:
         if node.domain == "finn":
 
@@ -198,7 +203,7 @@ def execute_onnx(
                 execution_context = sanitize_quant_values(
                     model, node.input, execution_context
                 )
-            execute_node(node, execution_context, graph)
+            execute_node(node, execution_context, graph, return_full_exec_context)
             if get_sanitize_quant_tensors() != 0:
                 # round output values to quantization annotation
                 execution_context = sanitize_quant_values(
-- 
GitLab