Skip to content
Snippets Groups Projects
Unverified Commit a3553b57 authored by Tobi-Alonso's avatar Tobi-Alonso Committed by GitHub
Browse files

[fpgadataflow] Add support for non-linear graphs in InsertDWC transf (#227)

parent 2fcb3f1b
No related branches found
No related tags found
No related merge requests found
......@@ -4,6 +4,7 @@ from onnx import helper as oh
from finn.custom_op.registry import getCustomOp
from finn.transformation.base import Transformation
from finn.util.fpgadataflow import is_fpgadataflow_node
import warnings
def _is_dwc_node(node):
......@@ -40,48 +41,59 @@ class InsertDWC(Transformation):
for n in graph.node:
node_ind += 1
if _suitable_node(n):
n_output = n.output[0]
consumer = model.find_consumer(n_output)
if _suitable_node(consumer) is True:
n0 = getCustomOp(n)
n1 = getCustomOp(consumer)
n0_out_shape = n0.get_folded_output_shape()
n1_in_shape = n1.get_folded_input_shape()
if n0_out_shape[-1] != n1_in_shape[-1]:
graph_modified = True
# determine dwc inwidth
dwc_in_width = n0.get_outstream_width()
# determine dwc outwidth
dwc_out_width = n1.get_instream_width()
# determine shape for dwc
dwc_shape = n0.get_normal_output_shape()
# determine dtype for dwc
dtype = n0.get_output_datatype()
dwc_output_tensor = oh.make_tensor_value_info(
model.make_new_valueinfo_name(),
TensorProto.FLOAT,
dwc_shape,
for n_output in n.output:
consumers = model.find_consumers(n_output)
if consumers is None:
continue
if len(consumers) > 1:
warnings.warn(
n.name
+ ": HLS node with fan-out higher than 1 cannot be stitched"
)
graph.value_info.append(dwc_output_tensor)
dwc_node = oh.make_node(
"StreamingDataWidthConverter_Batch",
[n_output],
[dwc_output_tensor.name],
domain="finn",
backend="fpgadataflow",
shape=dwc_shape,
inWidth=dwc_in_width,
outWidth=dwc_out_width,
dataType=str(dtype.name),
)
# insert dwc
graph.node.insert(node_ind + 1, dwc_node)
# set dwc output tensor as new input tensor of second node
consumer.input[0] = dwc_output_tensor.name
consumer = consumers[0]
if _suitable_node(consumer) is True:
n0 = getCustomOp(n)
n1 = getCustomOp(consumer)
n0_out_shape = n0.get_folded_output_shape()
n1_in_shape = n1.get_folded_input_shape()
if n0_out_shape[-1] != n1_in_shape[-1]:
graph_modified = True
# determine dwc inwidth
dwc_in_width = n0.get_outstream_width()
# determine dwc outwidth
dwc_out_width = n1.get_instream_width()
# determine shape for dwc
dwc_shape = n0.get_normal_output_shape()
# determine dtype for dwc
dtype = n0.get_output_datatype()
dwc_output_tensor = oh.make_tensor_value_info(
model.make_new_valueinfo_name(),
TensorProto.FLOAT,
dwc_shape,
)
graph.value_info.append(dwc_output_tensor)
dwc_node = oh.make_node(
"StreamingDataWidthConverter_Batch",
[n_output],
[dwc_output_tensor.name],
domain="finn",
backend="fpgadataflow",
shape=dwc_shape,
inWidth=dwc_in_width,
outWidth=dwc_out_width,
dataType=str(dtype.name),
)
# insert dwc
graph.node.insert(node_ind + 1, dwc_node)
# set dwc output tensor as new input tensor of second node
for idx, inp in enumerate(consumer.input):
if inp == n_output:
consumer.input[idx] = dwc_output_tensor.name
return (model, graph_modified)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment