From 0105584ba3fc7275f8a95a0c5c504ec06b65978e Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Mon, 17 Feb 2020 19:35:34 +0100
Subject: [PATCH] [Refactor] move StreamingDataflowPartition exec to own file

---
 src/finn/core/onnx_exec.py                    | 78 ++++++++-----------
 .../custom_op/streamingdataflowpartition.py   |  8 +-
 2 files changed, 37 insertions(+), 49 deletions(-)

diff --git a/src/finn/core/onnx_exec.py b/src/finn/core/onnx_exec.py
index 9786c55c0..57c229663 100644
--- a/src/finn/core/onnx_exec.py
+++ b/src/finn/core/onnx_exec.py
@@ -31,9 +31,7 @@ import onnx.helper as helper
 import onnxruntime as rt
 
 import finn.core.execute_custom_node as ex_cu_node
-from finn.core.modelwrapper import ModelWrapper
 from finn.core.remote_exec import remote_exec
-from finn.custom_op.registry import getCustomOp
 
 
 def execute_node(node, context, graph):
@@ -41,54 +39,42 @@ def execute_node(node, context, graph):
     if dataflow partition by using remote execution or rtlsim.
     Input/output provided via context."""
 
-    if node.op_type == "StreamingDataflowPartition":
-        sdp_node = getCustomOp(node)
-        model = ModelWrapper(sdp_node.get_nodeattr("model"))
-        execute_onnx(model, context)
+    if node.domain == "finn":
+        ex_cu_node.execute_custom_node(node, context, graph)
     else:
-        if node.domain == "finn":
-
-            ex_cu_node.execute_custom_node(node, context, graph)
+        # onnxruntime unfortunately does not implement run_node as defined by
+        # ONNX; it can only execute entire models. So we create a model that
+        # consists solely of the current node.
+        node_inputs = list(filter(lambda x: x.name in node.input, graph.input))
+        node_inputs += list(filter(lambda x: x.name in node.input, graph.value_info))
+        node_outputs = list(filter(lambda x: x.name in node.output, graph.output))
+        node_outputs += list(filter(lambda x: x.name in node.output, graph.value_info))
+        node_graph = helper.make_graph(
+            nodes=[node],
+            name="single-node-exec",
+            inputs=node_inputs,
+            outputs=node_outputs,
+        )
+        node_model = helper.make_model(node_graph)
+        input_dict = dict()
+        for inp in node.input:
+            input_dict[inp] = context[inp]
 
-        else:
+        sess = rt.InferenceSession(node_model.SerializeToString())
+        output_list = sess.run(None, input_dict)
 
-            # onnxruntime unfortunately does not implement run_node as defined by ONNX,
-            # it can only execute entire models -- so we create a model which solely
-            # consists of our current node.
-            node_inputs = list(filter(lambda x: x.name in node.input, graph.input))
-            node_inputs += list(
-                filter(lambda x: x.name in node.input, graph.value_info)
-            )
-            node_outputs = list(filter(lambda x: x.name in node.output, graph.output))
-            node_outputs += list(
-                filter(lambda x: x.name in node.output, graph.value_info)
-            )
-            node_graph = helper.make_graph(
-                nodes=[node],
-                name="single-node-exec",
-                inputs=node_inputs,
-                outputs=node_outputs,
-            )
-            node_model = helper.make_model(node_graph)
-            input_dict = dict()
-            for inp in node.input:
-                input_dict[inp] = context[inp]
-
-            sess = rt.InferenceSession(node_model.SerializeToString())
-            output_list = sess.run(None, input_dict)
-
-            for output_ind in range(len(node.output)):
-                outp = node.output[output_ind]
-                if output_list[output_ind].shape != context[outp].shape:
-                    raise Exception(
-                        """Output shapes disagree after node execution:
-                        found %s vs expected %s"""
-                        % (
-                            str(output_list[output_ind].shape.shape),
-                            str(context[outp].shape),
-                        )
+        for output_ind in range(len(node.output)):
+            outp = node.output[output_ind]
+            if output_list[output_ind].shape != context[outp].shape:
+                raise Exception(
+                    """Output shapes disagree after node execution:
+                    found %s vs expected %s"""
+                    % (
+                        str(output_list[output_ind].shape),
+                        str(context[outp].shape),
                     )
-                context[outp] = output_list[output_ind]
+                )
+            context[outp] = output_list[output_ind]
 
 
 def execute_onnx(model, input_dict, return_full_exec_context=False):
diff --git a/src/finn/custom_op/streamingdataflowpartition.py b/src/finn/custom_op/streamingdataflowpartition.py
index 856ba6d56..b9bb0a02f 100644
--- a/src/finn/custom_op/streamingdataflowpartition.py
+++ b/src/finn/custom_op/streamingdataflowpartition.py
@@ -1,3 +1,5 @@
+from finn.core.modelwrapper import ModelWrapper
+from finn.core.onnx_exec import execute_onnx
 from finn.custom_op import CustomOp
 
 # note that the StreamingDataflowPartition node is only a meta/container node,
@@ -19,9 +21,9 @@ class StreamingDataflowPartition(CustomOp):
         pass
 
     def execute_node(self, context, graph):
-        # TODO add RPC execution with synthesized bitfile?
-        # whole-design rtlsim with PyVerilator may also be an alternative
-        pass
+        # retrieve linked "child" dataflow model and execute
+        dataflow_model = ModelWrapper(self.get_nodeattr("model"))
+        execute_onnx(dataflow_model, context)
 
     def verify_node(self):
         info_messages = []
-- 
GitLab
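
For reference, the single-node execution pattern in the onnx_exec.py hunk above (wrapping one node in a throwaway model, since onnxruntime has no run_node) can be reproduced standalone. A minimal sketch using a toy Relu node; the tensor names, shapes, and the CPU provider choice are illustrative assumptions, not part of the patch:

    import numpy as np
    import onnx.helper as helper
    import onnxruntime as rt
    from onnx import TensorProto

    # declare the single node's inputs/outputs as value infos, as
    # execute_node does by filtering graph.input/output/value_info
    inp = helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 2])
    outp = helper.make_tensor_value_info("y", TensorProto.FLOAT, [2, 2])
    node = helper.make_node("Relu", ["x"], ["y"])
    # a model that consists solely of the current node
    node_graph = helper.make_graph([node], "single-node-exec", [inp], [outp])
    node_model = helper.make_model(node_graph)

    # execute the one-node model with onnxruntime, mirroring execute_node
    sess = rt.InferenceSession(
        node_model.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    x = np.asarray([[-1.0, 2.0], [3.0, -4.0]], dtype=np.float32)
    output_list = sess.run(None, {"x": x})
    print(output_list[0])  # negative entries clamped to zero by Relu

Building the wrapper graph from the node's own input/output names is what lets execute_node feed sess.run straight from its context dict.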
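
After the refactor, recursion into a dataflow partition lives in StreamingDataflowPartition.execute_node, so callers only ever invoke execute_onnx on the parent model. A hedged usage sketch, assuming FINN is installed and parent.onnx contains a StreamingDataflowPartition node whose "model" attribute points at the child model on disk; the file path and the all-zeros input are placeholders:

    import numpy as np
    import finn.core.onnx_exec as oxe
    from finn.core.modelwrapper import ModelWrapper

    parent_model = ModelWrapper("parent.onnx")  # placeholder path
    iname = parent_model.graph.input[0].name
    ishape = parent_model.get_tensor_shape(iname)
    # dummy all-zeros input, purely for illustration
    input_dict = {iname: np.zeros(ishape, dtype=np.float32)}
    # execute_onnx walks the graph node by node; on reaching the partition
    # node it loads the child model and recurses via execute_onnx
    output_dict = oxe.execute_onnx(parent_model, input_dict)

Note that execute_onnx returns only the graph outputs by default; pass return_full_exec_context=True (visible in the signature above) to get the full execution context instead.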