Skip to content
Snippets Groups Projects
Commit 0105584b authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Refactor] move StreamingDataflowPartition exec to own file

parent cebb4e2e
No related branches found
No related tags found
No related merge requests found
......@@ -31,9 +31,7 @@ import onnx.helper as helper
import onnxruntime as rt
import finn.core.execute_custom_node as ex_cu_node
from finn.core.modelwrapper import ModelWrapper
from finn.core.remote_exec import remote_exec
from finn.custom_op.registry import getCustomOp
def execute_node(node, context, graph):
......@@ -41,54 +39,42 @@ def execute_node(node, context, graph):
if dataflow partition by using remote execution or rtlsim.
Input/output provided via context."""
if node.op_type == "StreamingDataflowPartition":
sdp_node = getCustomOp(node)
model = ModelWrapper(sdp_node.get_nodeattr("model"))
execute_onnx(model, context)
if node.domain == "finn":
ex_cu_node.execute_custom_node(node, context, graph)
else:
if node.domain == "finn":
ex_cu_node.execute_custom_node(node, context, graph)
# onnxruntime unfortunately does not implement run_node as defined by ONNX,
# it can only execute entire models -- so we create a model which solely
# consists of our current node.
node_inputs = list(filter(lambda x: x.name in node.input, graph.input))
node_inputs += list(filter(lambda x: x.name in node.input, graph.value_info))
node_outputs = list(filter(lambda x: x.name in node.output, graph.output))
node_outputs += list(filter(lambda x: x.name in node.output, graph.value_info))
node_graph = helper.make_graph(
nodes=[node],
name="single-node-exec",
inputs=node_inputs,
outputs=node_outputs,
)
node_model = helper.make_model(node_graph)
input_dict = dict()
for inp in node.input:
input_dict[inp] = context[inp]
else:
sess = rt.InferenceSession(node_model.SerializeToString())
output_list = sess.run(None, input_dict)
# onnxruntime unfortunately does not implement run_node as defined by ONNX,
# it can only execute entire models -- so we create a model which solely
# consists of our current node.
node_inputs = list(filter(lambda x: x.name in node.input, graph.input))
node_inputs += list(
filter(lambda x: x.name in node.input, graph.value_info)
)
node_outputs = list(filter(lambda x: x.name in node.output, graph.output))
node_outputs += list(
filter(lambda x: x.name in node.output, graph.value_info)
)
node_graph = helper.make_graph(
nodes=[node],
name="single-node-exec",
inputs=node_inputs,
outputs=node_outputs,
)
node_model = helper.make_model(node_graph)
input_dict = dict()
for inp in node.input:
input_dict[inp] = context[inp]
sess = rt.InferenceSession(node_model.SerializeToString())
output_list = sess.run(None, input_dict)
for output_ind in range(len(node.output)):
outp = node.output[output_ind]
if output_list[output_ind].shape != context[outp].shape:
raise Exception(
"""Output shapes disagree after node execution:
found %s vs expected %s"""
% (
str(output_list[output_ind].shape.shape),
str(context[outp].shape),
)
for output_ind in range(len(node.output)):
outp = node.output[output_ind]
if output_list[output_ind].shape != context[outp].shape:
raise Exception(
"""Output shapes disagree after node execution:
found %s vs expected %s"""
% (
str(output_list[output_ind].shape.shape),
str(context[outp].shape),
)
context[outp] = output_list[output_ind]
)
context[outp] = output_list[output_ind]
def execute_onnx(model, input_dict, return_full_exec_context=False):
......
from finn.core.modelwrapper import ModelWrapper
from finn.core.onnx_exec import execute_onnx
from finn.custom_op import CustomOp
# note that the StreamingDataflowPartition node is only a meta/container node,
......@@ -19,9 +21,9 @@ class StreamingDataflowPartition(CustomOp):
pass
def execute_node(self, context, graph):
# TODO add RPC execution with synthesized bitfile?
# whole-design rtlsim with PyVerilator may also be an alternative
pass
# retrieve linked "child" dataflow model and execute
dataflow_model = ModelWrapper(self.get_nodeattr("model"))
execute_onnx(dataflow_model, context)
def verify_node(self):
info_messages = []
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment