Skip to content
Snippets Groups Projects
Commit 0105584b authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Refactor] move StreamingDataflowPartition exec to own file

parent cebb4e2e
No related branches found
No related tags found
No related merge requests found
...@@ -31,9 +31,7 @@ import onnx.helper as helper ...@@ -31,9 +31,7 @@ import onnx.helper as helper
import onnxruntime as rt import onnxruntime as rt
import finn.core.execute_custom_node as ex_cu_node import finn.core.execute_custom_node as ex_cu_node
from finn.core.modelwrapper import ModelWrapper
from finn.core.remote_exec import remote_exec from finn.core.remote_exec import remote_exec
from finn.custom_op.registry import getCustomOp
def execute_node(node, context, graph):
    """Execute a single node from the graph and store its results in context.

    Nodes in the "finn" domain are dispatched to the custom-op executor;
    all other nodes are run through onnxruntime. Inputs are read from and
    outputs are written back into the ``context`` dict (tensor name -> value).

    Parameters:
        node: the onnx NodeProto to execute.
        context: dict mapping tensor names to numpy arrays; mutated in place.
        graph: the enclosing onnx GraphProto (used to look up value infos).

    Raises:
        Exception: if an output produced by onnxruntime does not match the
            shape that the context expects for that tensor.
    """
    if node.domain == "finn":
        # FINN custom ops carry their own execution semantics.
        ex_cu_node.execute_custom_node(node, context, graph)
    else:
        # onnxruntime unfortunately does not implement run_node as defined
        # by ONNX, it can only execute entire models -- so we create a model
        # which solely consists of our current node.
        node_inputs = list(filter(lambda x: x.name in node.input, graph.input))
        node_inputs += list(
            filter(lambda x: x.name in node.input, graph.value_info)
        )
        node_outputs = list(filter(lambda x: x.name in node.output, graph.output))
        node_outputs += list(
            filter(lambda x: x.name in node.output, graph.value_info)
        )
        node_graph = helper.make_graph(
            nodes=[node],
            name="single-node-exec",
            inputs=node_inputs,
            outputs=node_outputs,
        )
        node_model = helper.make_model(node_graph)
        # feed only the tensors this node consumes
        input_dict = {inp: context[inp] for inp in node.input}
        sess = rt.InferenceSession(node_model.SerializeToString())
        output_list = sess.run(None, input_dict)
        for output_ind, outp in enumerate(node.output):
            if output_list[output_ind].shape != context[outp].shape:
                raise Exception(
                    """Output shapes disagree after node execution:
                    found %s vs expected %s"""
                    % (
                        # bugfix: was output_list[output_ind].shape.shape,
                        # which raises AttributeError (tuple has no .shape)
                        # instead of reporting the mismatch
                        str(output_list[output_ind].shape),
                        str(context[outp].shape),
                    )
                )
            context[outp] = output_list[output_ind]
def execute_onnx(model, input_dict, return_full_exec_context=False): def execute_onnx(model, input_dict, return_full_exec_context=False):
......
from finn.core.modelwrapper import ModelWrapper
from finn.core.onnx_exec import execute_onnx
from finn.custom_op import CustomOp from finn.custom_op import CustomOp
# note that the StreamingDataflowPartition node is only a meta/container node, # note that the StreamingDataflowPartition node is only a meta/container node,
...@@ -19,9 +21,9 @@ class StreamingDataflowPartition(CustomOp): ...@@ -19,9 +21,9 @@ class StreamingDataflowPartition(CustomOp):
pass pass
def execute_node(self, context, graph): def execute_node(self, context, graph):
# TODO add RPC execution with synthesized bitfile? # retrieve linked "child" dataflow model and execute
# whole-design rtlsim with PyVerilator may also be an alternative dataflow_model = ModelWrapper(self.get_nodeattr("model"))
pass execute_onnx(dataflow_model, context)
def verify_node(self): def verify_node(self):
info_messages = [] info_messages = []
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment