From 39283948bea68a6abfde166e957d31267ae086c2 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Mon, 17 Feb 2020 20:30:35 +0100
Subject: [PATCH] Revert "[Refactor] move StreamingDataflowPartition exec to
 own file"

This reverts commit 0105584ba3fc7275f8a95a0c5c504ec06b65978e.
---
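Net effect of this revert: StreamingDataflowPartition execution moves back
inline into execute_node in src/finn/core/onnx_exec.py, and the custom op's
own execute_node becomes a no-op again. A condensed sketch of the restored
dispatch (names taken from the hunk below; the nested if/else is flattened
here and the onnxruntime fallback elided for brevity):

    def execute_node(node, context, graph):
        if node.op_type == "StreamingDataflowPartition":
            # the partition node carries a pointer to a child dataflow
            # model; execute it recursively with the shared tensor context
            sdp_node = getCustomOp(node)
            model = ModelWrapper(sdp_node.get_nodeattr("model"))
            execute_onnx(model, context)
        elif node.domain == "finn":
            # any other FINN custom op knows how to execute itself
            ex_cu_node.execute_custom_node(node, context, graph)
        else:
            # plain ONNX node: wrap it in a single-node model and run
            # it with onnxruntime (see the full hunk below)
            ...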
 src/finn/core/onnx_exec.py                    | 78 +++++++++++--------
 .../custom_op/streamingdataflowpartition.py   |  8 +-
 2 files changed, 49 insertions(+), 37 deletions(-)
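The onnxruntime fallback relies on a workaround: onnxruntime has no
run_node, so execute_node wraps the lone node in a throwaway single-node
model and executes that. A minimal standalone sketch of the same trick (the
Add node, the opset pin, and the provider choice are illustrative
assumptions, not part of this patch):

    import numpy as np
    import onnx.helper as helper
    import onnxruntime as rt
    from onnx import TensorProto

    # declare the node's inputs/outputs, then wrap it in its own model
    a = helper.make_tensor_value_info("a", TensorProto.FLOAT, [2, 2])
    b = helper.make_tensor_value_info("b", TensorProto.FLOAT, [2, 2])
    c = helper.make_tensor_value_info("c", TensorProto.FLOAT, [2, 2])
    node = helper.make_node("Add", ["a", "b"], ["c"])
    graph = helper.make_graph([node], "single-node-exec", [a, b], [c])
    # pin an opset so the generated model matches the installed runtime
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])

    # newer onnxruntime versions require providers to be given explicitly
    sess = rt.InferenceSession(
        model.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    (res,) = sess.run(None, {
        "a": np.ones((2, 2), dtype=np.float32),
        "b": np.full((2, 2), 2.0, dtype=np.float32),
    })
    print(res)  # [[3. 3.] [3. 3.]]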

diff --git a/src/finn/core/onnx_exec.py b/src/finn/core/onnx_exec.py
index 57c229663..9786c55c0 100644
--- a/src/finn/core/onnx_exec.py
+++ b/src/finn/core/onnx_exec.py
@@ -31,7 +31,9 @@ import onnx.helper as helper
 import onnxruntime as rt
 
 import finn.core.execute_custom_node as ex_cu_node
+from finn.core.modelwrapper import ModelWrapper
 from finn.core.remote_exec import remote_exec
+from finn.custom_op.registry import getCustomOp
 
 
 def execute_node(node, context, graph):
@@ -39,42 +41,54 @@ def execute_node(node, context, graph):
     if dataflow partition by using remote execution or rtlsim.
     Input/output provided via context."""
 
-    if node.domain == "finn":
-        ex_cu_node.execute_custom_node(node, context, graph)
+    if node.op_type == "StreamingDataflowPartition":
+        sdp_node = getCustomOp(node)
+        model = ModelWrapper(sdp_node.get_nodeattr("model"))
+        execute_onnx(model, context)
     else:
-        # onnxruntime unfortunately does not implement run_node as defined by ONNX,
-        # it can only execute entire models -- so we create a model which solely
-        # consists of our current node.
-        node_inputs = list(filter(lambda x: x.name in node.input, graph.input))
-        node_inputs += list(filter(lambda x: x.name in node.input, graph.value_info))
-        node_outputs = list(filter(lambda x: x.name in node.output, graph.output))
-        node_outputs += list(filter(lambda x: x.name in node.output, graph.value_info))
-        node_graph = helper.make_graph(
-            nodes=[node],
-            name="single-node-exec",
-            inputs=node_inputs,
-            outputs=node_outputs,
-        )
-        node_model = helper.make_model(node_graph)
-        input_dict = dict()
-        for inp in node.input:
-            input_dict[inp] = context[inp]
+        if node.domain == "finn":
 
-        sess = rt.InferenceSession(node_model.SerializeToString())
-        output_list = sess.run(None, input_dict)
+            ex_cu_node.execute_custom_node(node, context, graph)
 
-        for output_ind in range(len(node.output)):
-            outp = node.output[output_ind]
-            if output_list[output_ind].shape != context[outp].shape:
-                raise Exception(
-                    """Output shapes disagree after node execution:
-                    found %s vs expected %s"""
-                    % (
-                        str(output_list[output_ind].shape.shape),
-                        str(context[outp].shape),
+        else:
+
+            # onnxruntime unfortunately does not implement run_node as
+            # defined by ONNX; it can only execute entire models, so we
+            # create a model which solely consists of our current node.
+            node_inputs = list(filter(lambda x: x.name in node.input, graph.input))
+            node_inputs += list(
+                filter(lambda x: x.name in node.input, graph.value_info)
+            )
+            node_outputs = list(filter(lambda x: x.name in node.output, graph.output))
+            node_outputs += list(
+                filter(lambda x: x.name in node.output, graph.value_info)
+            )
+            node_graph = helper.make_graph(
+                nodes=[node],
+                name="single-node-exec",
+                inputs=node_inputs,
+                outputs=node_outputs,
+            )
+            node_model = helper.make_model(node_graph)
+            input_dict = dict()
+            for inp in node.input:
+                input_dict[inp] = context[inp]
+
+            sess = rt.InferenceSession(node_model.SerializeToString())
+            output_list = sess.run(None, input_dict)
+
+            for output_ind in range(len(node.output)):
+                outp = node.output[output_ind]
+                if output_list[output_ind].shape != context[outp].shape:
+                    raise Exception(
+                        """Output shapes disagree after node execution:
+                        found %s vs expected %s"""
+                        % (
+                            str(output_list[output_ind].shape),
+                            str(context[outp].shape),
+                        )
                     )
-                )
-            context[outp] = output_list[output_ind]
+                context[outp] = output_list[output_ind]
 
 
 def execute_onnx(model, input_dict, return_full_exec_context=False):
diff --git a/src/finn/custom_op/streamingdataflowpartition.py b/src/finn/custom_op/streamingdataflowpartition.py
index b9bb0a02f..856ba6d56 100644
--- a/src/finn/custom_op/streamingdataflowpartition.py
+++ b/src/finn/custom_op/streamingdataflowpartition.py
@@ -1,5 +1,3 @@
-from finn.core.modelwrapper import ModelWrapper
-from finn.core.onnx_exec import execute_onnx
 from finn.custom_op import CustomOp
 
 # note that the StreamingDataflowPartition node is only a meta/container node,
@@ -21,9 +19,9 @@ class StreamingDataflowPartition(CustomOp):
         pass
 
     def execute_node(self, context, graph):
-        # retrieve linked "child" dataflow model and execute
-        dataflow_model = ModelWrapper(self.get_nodeattr("model"))
-        execute_onnx(dataflow_model, context)
+        # TODO add RPC execution with synthesized bitfile?
+        # whole-design rtlsim with PyVerilator may also be an alternative
+        pass
 
     def verify_node(self):
         info_messages = []
-- 
GitLab