diff --git a/src/finn/transformation/fpgadataflow/insert_hook.py b/src/finn/transformation/fpgadataflow/insert_hook.py
index 35a1fb81996efae34118e005543ce2d8d8514166..22050c008b2991671b8483404e1a6e8772de691e 100644
--- a/src/finn/transformation/fpgadataflow/insert_hook.py
+++ b/src/finn/transformation/fpgadataflow/insert_hook.py
@@ -4,6 +4,7 @@ from onnx import helper as oh
 
 from finn.custom_op.registry import getCustomOp
 from finn.transformation.base import Transformation
+from finn.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames
 from finn.util.fpgadataflow import is_fpgadataflow_node
 
 
@@ -44,50 +45,55 @@ class InsertHook(Transformation):
             if _suitable_node(n):
                 for output_name in n.output:
                     consumers = model.find_consumers(output_name)
-                    if consumers == []:
-                        continue
-                    assert len(consumers) == 1, (
+                    assert len(consumers) <= 1, (
                         n.name
                         + ": HLS node with fan-out higher than 1 cannot be stitched"
                     )
-                    consumer = consumers[0]
-                    if _suitable_node(consumer) is True:
-                        n0 = getCustomOp(n)
-                        n0_hook = n0.get_nodeattr("output_hook")
-                        if n0_hook in list_supported_hooks:
-                            if n0_hook == "checksum":
-                                n0_normal_oshape = n0.get_normal_output_shape()
-                                n0_folded_oshape = n0.get_folded_output_shape()
-                                n0_odt = n0.get_output_datatype()
-                                items_per_word = n0.get_nodeattr("PE")
-                                words_per_frame = np.prod(n0_folded_oshape[:-1])
-                                chk_otensor = oh.make_tensor_value_info(
-                                    model.make_new_valueinfo_name(),
-                                    TensorProto.FLOAT,
-                                    n0_normal_oshape,
-                                )
-                                chk_result = oh.make_tensor_value_info(
-                                    model.make_new_valueinfo_name(),
-                                    TensorProto.FLOAT,
-                                    [1],
-                                )
-                                chk_node = oh.make_node(
-                                    "checksum",
-                                    [output_name],
-                                    outputs=[chk_otensor.name, chk_result.name],
-                                    domain="finn.custom_op.fpgadataflow",
-                                    backend="fpgadataflow",
-                                    words_per_frame=words_per_frame,
-                                    items_per_word=items_per_word,
-                                    inputDataType=str(n0_odt.name),
-                                    folded_shape=n0_folded_oshape,
-                                )
-                                # insert dwc
-                                graph.node.insert(node_ind + 1, chk_node)
-                                # set dwc output tensor as new input tensor of second node
-                                for idx, inp in enumerate(consumer.input):
-                                    if inp == output_name:
-                                        consumer.input[idx] = chk_otensor.name
+                    n0 = getCustomOp(n)
+                    n0_hook = n0.get_nodeattr("output_hook")
+                    if n0_hook in list_supported_hooks:
+                        if n0_hook == "checksum":
+                            if len(consumers) == 1:
+                                if consumers[0].op_type == "checksum":
+                                    continue
+                            n0_normal_oshape = n0.get_normal_output_shape()
+                            n0_folded_oshape = n0.get_folded_output_shape()
+                            n0_odt = n0.get_output_datatype()
+                            items_per_word = n0.get_nodeattr("PE")
+                            words_per_frame = np.prod(n0_folded_oshape[:-1])
+                            chk_otensor = oh.make_tensor_value_info(
+                                model.make_new_valueinfo_name(),
+                                TensorProto.FLOAT,
+                                n0_normal_oshape,
+                            )
+                            chk_result = oh.make_tensor_value_info(
+                                model.make_new_valueinfo_name(),
+                                TensorProto.FLOAT,
+                                [1],
+                            )
+                            chk_node = oh.make_node(
+                                "checksum",
+                                [output_name],
+                                outputs=[chk_otensor.name, chk_result.name],
+                                domain="finn.custom_op.fpgadataflow",
+                                backend="fpgadataflow",
+                                words_per_frame=words_per_frame,
+                                items_per_word=items_per_word,
+                                inputDataType=str(n0_odt.name),
+                                folded_shape=n0_folded_oshape,
+                            )
+                            # insert checksum node
+                            graph.node.insert(node_ind + 1, chk_node)
+                            # set chk output tensor as new input tensor of second node
+                            if len(consumers) == 1:
+                                consumers[0].input[0] = chk_otensor.name
+                            else:
+                                model.graph.output.pop()
+                                model.graph.output.append(chk_otensor)
+                            model = model.transform(GiveUniqueNodeNames())
+                            model = model.transform(GiveReadableTensorNames())
+                            graph_modified = True
+                            return (model, graph_modified)
 
 
         return (model, graph_modified)