diff --git a/src/finn/transformation/fpgadataflow/insert_hook.py b/src/finn/transformation/fpgadataflow/insert_hook.py
index 35a1fb81996efae34118e005543ce2d8d8514166..22050c008b2991671b8483404e1a6e8772de691e 100644
--- a/src/finn/transformation/fpgadataflow/insert_hook.py
+++ b/src/finn/transformation/fpgadataflow/insert_hook.py
@@ -4,6 +4,7 @@ from onnx import helper as oh
 
 from finn.custom_op.registry import getCustomOp
 from finn.transformation.base import Transformation
+from finn.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames
 from finn.util.fpgadataflow import is_fpgadataflow_node
 
 
@@ -44,50 +45,55 @@ class InsertHook(Transformation):
             if _suitable_node(n):
                 for output_name in n.output:
                     consumers = model.find_consumers(output_name)
-                    if consumers == []:
-                        continue
-                    assert len(consumers) == 1, (
+                    assert len(consumers) <= 1, (
                         n.name
                         + ": HLS node with fan-out higher than 1 cannot be stitched"
                     )
-                    consumer = consumers[0]
-                    if _suitable_node(consumer) is True:
-                        n0 = getCustomOp(n)
-                        n0_hook = n0.get_nodeattr("output_hook")
-                        if n0_hook in list_supported_hooks:
-                            if n0_hook == "checksum":
-                                n0_normal_oshape = n0.get_normal_output_shape()
-                                n0_folded_oshape = n0.get_folded_output_shape()
-                                n0_odt = n0.get_output_datatype()
-                                items_per_word = n0.get_nodeattr("PE")
-                                words_per_frame = np.prod(n0_folded_oshape[:-1])
-                                chk_otensor = oh.make_tensor_value_info(
-                                    model.make_new_valueinfo_name(),
-                                    TensorProto.FLOAT,
-                                    n0_normal_oshape,
-                                )
-                                chk_result = oh.make_tensor_value_info(
-                                    model.make_new_valueinfo_name(),
-                                    TensorProto.FLOAT,
-                                    [1],
-                                )
-                                chk_node = oh.make_node(
-                                    "checksum",
-                                    [output_name],
-                                    outputs=[chk_otensor.name, chk_result.name],
-                                    domain="finn.custom_op.fpgadataflow",
-                                    backend="fpgadataflow",
-                                    words_per_frame=words_per_frame,
-                                    items_per_word=items_per_word,
-                                    inputDataType=str(n0_odt.name),
-                                    folded_shape=n0_folded_oshape,
-                                )
-                            # insert dwc
+                    n0 = getCustomOp(n)
+                    n0_hook = n0.get_nodeattr("output_hook")
+                    if n0_hook in list_supported_hooks:
+                        if n0_hook == "checksum":
+                            if len(consumers) == 1:
+                                if consumers[0].op_type == "checksum":
+                                    continue
+                            n0_normal_oshape = n0.get_normal_output_shape()
+                            n0_folded_oshape = n0.get_folded_output_shape()
+                            n0_odt = n0.get_output_datatype()
+                            items_per_word = n0.get_nodeattr("PE")
+                            words_per_frame = np.prod(n0_folded_oshape[:-1])
+                            chk_otensor = oh.make_tensor_value_info(
+                                model.make_new_valueinfo_name(),
+                                TensorProto.FLOAT,
+                                n0_normal_oshape,
+                            )
+                            chk_result = oh.make_tensor_value_info(
+                                model.make_new_valueinfo_name(),
+                                TensorProto.FLOAT,
+                                [1],
+                            )
+                            chk_node = oh.make_node(
+                                "checksum",
+                                [output_name],
+                                outputs=[chk_otensor.name, chk_result.name],
+                                domain="finn.custom_op.fpgadataflow",
+                                backend="fpgadataflow",
+                                words_per_frame=words_per_frame,
+                                items_per_word=items_per_word,
+                                inputDataType=str(n0_odt.name),
+                                folded_shape=n0_folded_oshape,
+                            )
+                            # insert checksum node
                             graph.node.insert(node_ind + 1, chk_node)
 
-                            # set dwc output tensor as new input tensor of second node
-                            for idx, inp in enumerate(consumer.input):
-                                if inp == output_name:
-                                    consumer.input[idx] = chk_otensor.name
+                            # set chk output tensor as new input tensor of second node
+                            if len(consumers) == 1:
+                                consumers[0].input[0] = chk_otensor.name
+                            else:
+                                model.graph.output.pop()
+                                model.graph.output.append(chk_otensor)
+                                model = model.transform(GiveUniqueNodeNames())
+                                model = model.transform(GiveReadableTensorNames())
+                            graph_modified = True
+                            return (model, graph_modified)
 
         return (model, graph_modified)