Commit 4a8a3d74 authored by auphelia

[Transformation] Fix bugs in insert_hook transformation

parent 318d2b3d
@@ -4,6 +4,7 @@ from onnx import helper as oh
 from finn.custom_op.registry import getCustomOp
 from finn.transformation.base import Transformation
+from finn.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames
 from finn.util.fpgadataflow import is_fpgadataflow_node
@@ -44,50 +45,55 @@ class InsertHook(Transformation):
             if _suitable_node(n):
                 for output_name in n.output:
                     consumers = model.find_consumers(output_name)
-                    if consumers == []:
-                        continue
-                    assert len(consumers) == 1, (
+                    assert len(consumers) <= 1, (
                         n.name
                         + ": HLS node with fan-out higher than 1 cannot be stitched"
                     )
-                    consumer = consumers[0]
-                    if _suitable_node(consumer) is True:
-                        n0 = getCustomOp(n)
-                        n0_hook = n0.get_nodeattr("output_hook")
-                        if n0_hook in list_supported_hooks:
-                            if n0_hook == "checksum":
-                                n0_normal_oshape = n0.get_normal_output_shape()
-                                n0_folded_oshape = n0.get_folded_output_shape()
-                                n0_odt = n0.get_output_datatype()
-                                items_per_word = n0.get_nodeattr("PE")
-                                words_per_frame = np.prod(n0_folded_oshape[:-1])
-                                chk_otensor = oh.make_tensor_value_info(
-                                    model.make_new_valueinfo_name(),
-                                    TensorProto.FLOAT,
-                                    n0_normal_oshape,
-                                )
-                                chk_result = oh.make_tensor_value_info(
-                                    model.make_new_valueinfo_name(),
-                                    TensorProto.FLOAT,
-                                    [1],
-                                )
-                                chk_node = oh.make_node(
-                                    "checksum",
-                                    [output_name],
-                                    outputs=[chk_otensor.name, chk_result.name],
-                                    domain="finn.custom_op.fpgadataflow",
-                                    backend="fpgadataflow",
-                                    words_per_frame=words_per_frame,
-                                    items_per_word=items_per_word,
-                                    inputDataType=str(n0_odt.name),
-                                    folded_shape=n0_folded_oshape,
-                                )
-                                # insert dwc
+                    n0 = getCustomOp(n)
+                    n0_hook = n0.get_nodeattr("output_hook")
+                    if n0_hook in list_supported_hooks:
+                        if n0_hook == "checksum":
+                            if len(consumers) == 1:
+                                if consumers[0].op_type == "checksum":
+                                    continue
+                            n0_normal_oshape = n0.get_normal_output_shape()
+                            n0_folded_oshape = n0.get_folded_output_shape()
+                            n0_odt = n0.get_output_datatype()
+                            items_per_word = n0.get_nodeattr("PE")
+                            words_per_frame = np.prod(n0_folded_oshape[:-1])
+                            chk_otensor = oh.make_tensor_value_info(
+                                model.make_new_valueinfo_name(),
+                                TensorProto.FLOAT,
+                                n0_normal_oshape,
+                            )
+                            chk_result = oh.make_tensor_value_info(
+                                model.make_new_valueinfo_name(),
+                                TensorProto.FLOAT,
+                                [1],
+                            )
+                            chk_node = oh.make_node(
+                                "checksum",
+                                [output_name],
+                                outputs=[chk_otensor.name, chk_result.name],
+                                domain="finn.custom_op.fpgadataflow",
+                                backend="fpgadataflow",
+                                words_per_frame=words_per_frame,
+                                items_per_word=items_per_word,
+                                inputDataType=str(n0_odt.name),
+                                folded_shape=n0_folded_oshape,
+                            )
+                            # insert checksum node
                             graph.node.insert(node_ind + 1, chk_node)
-                                # set dwc output tensor as new input tensor of second node
-                                for idx, inp in enumerate(consumer.input):
-                                    if inp == output_name:
-                                        consumer.input[idx] = chk_otensor.name
+                            # set chk output tensor as new input tensor of second node
+                            if len(consumers) == 1:
+                                consumers[0].input[0] = chk_otensor.name
+                            else:
+                                model.graph.output.pop()
+                                model.graph.output.append(chk_otensor)
+                                model = model.transform(GiveUniqueNodeNames())
+                                model = model.transform(GiveReadableTensorNames())
                             graph_modified = True
+                            return (model, graph_modified)
         return (model, graph_modified)
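
For orientation, below is a minimal usage sketch of the transformation touched by this commit. It is not taken from the commit itself: the module path, the file names, and the loop that sets the output_hook attribute are illustrative assumptions, with only the attribute name and hook value ("output_hook", "checksum") coming from the diff above.

# Minimal usage sketch (assumed paths and file names, not part of this commit).
# InsertHook appends a "checksum" custom node behind each fpgadataflow node whose
# "output_hook" node attribute is set to a supported hook.
from finn.core.modelwrapper import ModelWrapper
from finn.custom_op.registry import getCustomOp
from finn.transformation.fpgadataflow.insert_hook import InsertHook  # assumed module path

model = ModelWrapper("dataflow_model.onnx")  # hypothetical stitched-dataflow model
for node in model.graph.node:
    # request a checksum hook on every fpgadataflow node
    # (attribute name and value as used in the transformation above)
    if node.domain == "finn.custom_op.fpgadataflow":
        getCustomOp(node).set_nodeattr("output_hook", "checksum")
model = model.transform(InsertHook())
model.save("dataflow_model_with_checksums.onnx")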