Skip to content
Snippets Groups Projects
Unverified Commit a3553b57 authored by Tobi-Alonso's avatar Tobi-Alonso Committed by GitHub
Browse files

[fpgadataflow] Add suport for non linear graphs in InsertDWC transf (#227)

parent 2fcb3f1b
No related branches found
No related tags found
No related merge requests found
...@@ -4,6 +4,7 @@ from onnx import helper as oh ...@@ -4,6 +4,7 @@ from onnx import helper as oh
from finn.custom_op.registry import getCustomOp from finn.custom_op.registry import getCustomOp
from finn.transformation.base import Transformation from finn.transformation.base import Transformation
from finn.util.fpgadataflow import is_fpgadataflow_node from finn.util.fpgadataflow import is_fpgadataflow_node
import warnings
def _is_dwc_node(node): def _is_dwc_node(node):
...@@ -40,48 +41,59 @@ class InsertDWC(Transformation): ...@@ -40,48 +41,59 @@ class InsertDWC(Transformation):
for n in graph.node: for n in graph.node:
node_ind += 1 node_ind += 1
if _suitable_node(n): if _suitable_node(n):
n_output = n.output[0] for n_output in n.output:
consumer = model.find_consumer(n_output) consumers = model.find_consumers(n_output)
if _suitable_node(consumer) is True: if consumers is None:
n0 = getCustomOp(n) continue
n1 = getCustomOp(consumer) if len(consumers) > 1:
n0_out_shape = n0.get_folded_output_shape() warnings.warn(
n1_in_shape = n1.get_folded_input_shape() n.name
if n0_out_shape[-1] != n1_in_shape[-1]: + ": HLS node with fan-out higher than 1 cannot be stitched"
graph_modified = True
# determine dwc inwidth
dwc_in_width = n0.get_outstream_width()
# determine dwc outwidth
dwc_out_width = n1.get_instream_width()
# determine shape for dwc
dwc_shape = n0.get_normal_output_shape()
# determine dtype for dwc
dtype = n0.get_output_datatype()
dwc_output_tensor = oh.make_tensor_value_info(
model.make_new_valueinfo_name(),
TensorProto.FLOAT,
dwc_shape,
) )
graph.value_info.append(dwc_output_tensor)
dwc_node = oh.make_node(
"StreamingDataWidthConverter_Batch",
[n_output],
[dwc_output_tensor.name],
domain="finn",
backend="fpgadataflow",
shape=dwc_shape,
inWidth=dwc_in_width,
outWidth=dwc_out_width,
dataType=str(dtype.name),
)
# insert dwc
graph.node.insert(node_ind + 1, dwc_node)
# set dwc output tensor as new input tensor of second node consumer = consumers[0]
consumer.input[0] = dwc_output_tensor.name if _suitable_node(consumer) is True:
n0 = getCustomOp(n)
n1 = getCustomOp(consumer)
n0_out_shape = n0.get_folded_output_shape()
n1_in_shape = n1.get_folded_input_shape()
if n0_out_shape[-1] != n1_in_shape[-1]:
graph_modified = True
# determine dwc inwidth
dwc_in_width = n0.get_outstream_width()
# determine dwc outwidth
dwc_out_width = n1.get_instream_width()
# determine shape for dwc
dwc_shape = n0.get_normal_output_shape()
# determine dtype for dwc
dtype = n0.get_output_datatype()
dwc_output_tensor = oh.make_tensor_value_info(
model.make_new_valueinfo_name(),
TensorProto.FLOAT,
dwc_shape,
)
graph.value_info.append(dwc_output_tensor)
dwc_node = oh.make_node(
"StreamingDataWidthConverter_Batch",
[n_output],
[dwc_output_tensor.name],
domain="finn",
backend="fpgadataflow",
shape=dwc_shape,
inWidth=dwc_in_width,
outWidth=dwc_out_width,
dataType=str(dtype.name),
)
# insert dwc
graph.node.insert(node_ind + 1, dwc_node)
# set dwc output tensor as new input tensor of second node
for idx, inp in enumerate(consumer.input):
if inp == n_output:
consumer.input[idx] = dwc_output_tensor.name
return (model, graph_modified) return (model, graph_modified)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment