Skip to content
Snippets Groups Projects
Commit d6fd1dee authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[InsertFIFO] handle multiple graph inputs correctly

parent ea3522ba
No related branches found
No related tags found
No related merge requests found
......@@ -57,21 +57,21 @@ class InsertFIFO(Transformation):
graph = model.graph
node_ind = -1
graph_modified = False
for n in graph.node:
for first_node in graph.node:
node_ind += 1
if _suitable_node(n):
for n_output in n.output:
if _suitable_node(first_node):
for n_output in first_node.output:
consumers = model.find_consumers(n_output)
if consumers is None:
continue
if len(consumers) > 1:
warnings.warn(
n.name
first_node.name
+ ": HLS node with fan-out higher than 1 cannot be stitched"
)
consumer = consumers[0]
if _suitable_node(consumer) is True:
n0 = getCustomOp(n)
n0 = getCustomOp(first_node)
# determine fifo node attributes
fld_shape = n0.get_folded_output_shape()
dtype = n0.get_output_datatype()
......@@ -137,47 +137,49 @@ class InsertFIFO(Transformation):
graph_modified = True
if graph_modified is False:
# insert FIFO as first node, except when first node is DMA
if (
graph.node[0].op_type != "StreamingFIFO"
and graph.node[0].op_type != "IODMA"
):
n = graph.node[0]
n_input = n.input[0]
n0 = getCustomOp(n)
# determine fifo node attributes
fld_shape = n0.get_folded_input_shape()
dtype = n0.get_input_datatype()
fifo_depth = n0.get_nodeattr("inFIFODepth")
if fifo_depth <= 2:
warnings.warn("Overriding input FIFO depth to 32")
fifo_depth = 32
# create fifo node
fifo_output_tensor = oh.make_tensor_value_info(
model.make_new_valueinfo_name(),
TensorProto.FLOAT,
n0.get_normal_input_shape(),
)
graph.value_info.append(fifo_output_tensor)
model.set_tensor_datatype(fifo_output_tensor.name, dtype)
fifo_node = oh.make_node(
"StreamingFIFO",
[n_input],
[fifo_output_tensor.name],
domain="finn.custom_op.fpgadataflow",
backend="fpgadataflow",
depth=fifo_depth,
folded_shape=fld_shape,
dataType=str(dtype.name),
)
# insert fifo
graph.node.insert(0, fifo_node)
# set fifo output tensor as new input tensor of second node
n.input[0] = fifo_output_tensor.name
graph_in_names = [x.name for x in model.graph.input]
for graph_in_name in graph_in_names:
first_node = model.find_consumer(graph_in_name)
# insert FIFO as first node, except when first node is DMA
if (
first_node.op_type != "StreamingFIFO"
and first_node.op_type != "IODMA"
):
n_input = first_node.input[0]
n0 = getCustomOp(first_node)
# determine fifo node attributes
fld_shape = n0.get_folded_input_shape()
dtype = n0.get_input_datatype()
fifo_depth = n0.get_nodeattr("inFIFODepth")
if fifo_depth <= 2:
warnings.warn("Overriding input FIFO depth to 32")
fifo_depth = 32
# create fifo node
fifo_output_tensor = oh.make_tensor_value_info(
model.make_new_valueinfo_name(),
TensorProto.FLOAT,
n0.get_normal_input_shape(),
)
graph.value_info.append(fifo_output_tensor)
model.set_tensor_datatype(fifo_output_tensor.name, dtype)
fifo_node = oh.make_node(
"StreamingFIFO",
[n_input],
[fifo_output_tensor.name],
domain="finn.custom_op.fpgadataflow",
backend="fpgadataflow",
depth=fifo_depth,
folded_shape=fld_shape,
dataType=str(dtype.name),
)
# insert fifo
graph.node.insert(0, fifo_node)
# set fifo output tensor as new input tensor of second node
first_node.input[0] = fifo_output_tensor.name
# insert FIFO as last node, except when last node is DMA
graph_out_names = [x.name for x in model.graph.output]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment