Commit 4a8a3d74 authored by auphelia

[Transformation] Fix bugs in insert_hook transformation

parent 318d2b3d
@@ -4,6 +4,7 @@ from onnx import helper as oh
 from finn.custom_op.registry import getCustomOp
 from finn.transformation.base import Transformation
+from finn.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames
 from finn.util.fpgadataflow import is_fpgadataflow_node
@@ -44,50 +45,55 @@ class InsertHook(Transformation):
             if _suitable_node(n):
                 for output_name in n.output:
                     consumers = model.find_consumers(output_name)
-                    if consumers == []:
-                        continue
-                    assert len(consumers) == 1, (
+                    assert len(consumers) <= 1, (
                         n.name
                         + ": HLS node with fan-out higher than 1 cannot be stitched"
                     )
-                    consumer = consumers[0]
-                    if _suitable_node(consumer) is True:
-                        n0 = getCustomOp(n)
-                        n0_hook = n0.get_nodeattr("output_hook")
-                        if n0_hook in list_supported_hooks:
-                            if n0_hook == "checksum":
-                                n0_normal_oshape = n0.get_normal_output_shape()
-                                n0_folded_oshape = n0.get_folded_output_shape()
-                                n0_odt = n0.get_output_datatype()
-                                items_per_word = n0.get_nodeattr("PE")
-                                words_per_frame = np.prod(n0_folded_oshape[:-1])
-                                chk_otensor = oh.make_tensor_value_info(
-                                    model.make_new_valueinfo_name(),
-                                    TensorProto.FLOAT,
-                                    n0_normal_oshape,
-                                )
-                                chk_result = oh.make_tensor_value_info(
-                                    model.make_new_valueinfo_name(),
-                                    TensorProto.FLOAT,
-                                    [1],
-                                )
-                                chk_node = oh.make_node(
-                                    "checksum",
-                                    [output_name],
-                                    outputs=[chk_otensor.name, chk_result.name],
-                                    domain="finn.custom_op.fpgadataflow",
-                                    backend="fpgadataflow",
-                                    words_per_frame=words_per_frame,
-                                    items_per_word=items_per_word,
-                                    inputDataType=str(n0_odt.name),
-                                    folded_shape=n0_folded_oshape,
-                                )
-                                # insert dwc
-                                # insert checksum node
-                                graph.node.insert(node_ind + 1, chk_node)
-                                # set dwc output tensor as new input tensor of second node
-                                for idx, inp in enumerate(consumer.input):
-                                    if inp == output_name:
-                                        consumer.input[idx] = chk_otensor.name
+                    n0 = getCustomOp(n)
+                    n0_hook = n0.get_nodeattr("output_hook")
+                    if n0_hook in list_supported_hooks:
+                        if n0_hook == "checksum":
+                            if len(consumers) == 1:
+                                if consumers[0].op_type == "checksum":
+                                    continue
+                            n0_normal_oshape = n0.get_normal_output_shape()
+                            n0_folded_oshape = n0.get_folded_output_shape()
+                            n0_odt = n0.get_output_datatype()
+                            items_per_word = n0.get_nodeattr("PE")
+                            words_per_frame = np.prod(n0_folded_oshape[:-1])
+                            chk_otensor = oh.make_tensor_value_info(
+                                model.make_new_valueinfo_name(),
+                                TensorProto.FLOAT,
+                                n0_normal_oshape,
+                            )
+                            chk_result = oh.make_tensor_value_info(
+                                model.make_new_valueinfo_name(),
+                                TensorProto.FLOAT,
+                                [1],
+                            )
+                            chk_node = oh.make_node(
+                                "checksum",
+                                [output_name],
+                                outputs=[chk_otensor.name, chk_result.name],
+                                domain="finn.custom_op.fpgadataflow",
+                                backend="fpgadataflow",
+                                words_per_frame=words_per_frame,
+                                items_per_word=items_per_word,
+                                inputDataType=str(n0_odt.name),
+                                folded_shape=n0_folded_oshape,
+                            )
+                            # insert checksum node
+                            graph.node.insert(node_ind + 1, chk_node)
+                            # set chk output tensor as new input tensor of second node
+                            if len(consumers) == 1:
+                                consumers[0].input[0] = chk_otensor.name
+                            else:
+                                model.graph.output.pop()
+                                model.graph.output.append(chk_otensor)
+                            model = model.transform(GiveUniqueNodeNames())
+                            model = model.transform(GiveReadableTensorNames())
+                            graph_modified = True
+                            return (model, graph_modified)
 
         return (model, graph_modified)
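
For reference, here is a small worked example of the checksum node's framing attributes as they are computed in the hunk above (items_per_word from the producer's PE attribute, words_per_frame from its folded output shape). The shape and PE values below are hypothetical and only illustrate the arithmetic:

import numpy as np

# Hypothetical folded output shape and PE value of the producing HLS node;
# the last dimension is assumed to hold the PE-parallel items.
n0_folded_oshape = (1, 32, 4)
pe = 4

items_per_word = pe                                     # items carried by one stream word
words_per_frame = int(np.prod(n0_folded_oshape[:-1]))   # 1 * 32 = 32 words per frame

print(items_per_word, words_per_frame)  # -> 4 32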
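
A minimal usage sketch of the fixed transformation is given below. The import path for InsertHook and the model file names are assumptions made for illustration (the commit does not show the file's location); ModelWrapper, transform(), and getCustomOp/set_nodeattr are standard FINN APIs, and "output_hook" is the node attribute the transformation reads in the diff above:

from finn.core.modelwrapper import ModelWrapper
from finn.custom_op.registry import getCustomOp
# assumed module path for the transformation changed in this commit
from finn.transformation.fpgadataflow.insert_hook import InsertHook

# Hypothetical dataflow model; mark one HLS node so that a checksum hook
# gets inserted after it.
model = ModelWrapper("dataflow_model.onnx")
getCustomOp(model.graph.node[0]).set_nodeattr("output_hook", "checksum")

# InsertHook appends a checksum node behind the marked layer and rewires
# either the consumer's input or the graph output to the checksum's output.
model = model.transform(InsertHook())
model.save("dataflow_model_with_hooks.onnx")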