From 318d2b3dba0b57f142d0b77a102da4c4d7a75440 Mon Sep 17 00:00:00 2001
From: auphelia <jakobapk@web.de>
Date: Wed, 18 May 2022 19:03:54 +0100
Subject: [PATCH] [Transformation] Add insert hook transformation

---
 src/finn/custom_op/fpgadataflow/checksum.py   | 13 ++-
 .../custom_op/fpgadataflow/hlscustomop.py     |  2 ++
 .../fpgadataflow/insert_hook.py               | 91 +++++++++++++++++++
 3 files changed, 99 insertions(+), 7 deletions(-)
 create mode 100644 src/finn/transformation/fpgadataflow/insert_hook.py

diff --git a/src/finn/custom_op/fpgadataflow/checksum.py b/src/finn/custom_op/fpgadataflow/checksum.py
index 24b831cac..dffc90a9e 100644
--- a/src/finn/custom_op/fpgadataflow/checksum.py
+++ b/src/finn/custom_op/fpgadataflow/checksum.py
@@ -206,14 +206,13 @@ class checksum(HLSCustomOp):
         self.code_gen_dict["$GLOBALS$"] = ['#include "checksum.hpp"']
 
     def defines(self, var):
+        items_per_word = self.get_nodeattr("items_per_word")
+        words_per_frame = self.get_nodeattr("words_per_frame")
+        word_size = self.get_instream_width()
         my_defines = []
-        my_defines.append(
-            "#define WORDS_PER_FRAME {}".format(self.get_nodeattr("words_per_frame"))
-        )
-        my_defines.append(
-            "#define ITEMS_PER_WORD {}".format(self.get_nodeattr("items_per_word"))
-        )
-        my_defines.append("#define WORD_SIZE {}".format(self.get_instream_width()))
+        my_defines.append("#define WORDS_PER_FRAME {}".format(words_per_frame))
+        my_defines.append("#define ITEMS_PER_WORD {}".format(items_per_word))
+        my_defines.append("#define WORD_SIZE {}".format(word_size))
         self.code_gen_dict["$DEFINES$"] = my_defines
 
     def read_npy_data(self):
diff --git a/src/finn/custom_op/fpgadataflow/hlscustomop.py b/src/finn/custom_op/fpgadataflow/hlscustomop.py
index 5e6fb1124..ed12d2f1a 100644
--- a/src/finn/custom_op/fpgadataflow/hlscustomop.py
+++ b/src/finn/custom_op/fpgadataflow/hlscustomop.py
@@ -112,6 +112,8 @@ class HLSCustomOp(CustomOp):
             # input and output FIFO depths
             "inFIFODepth": ("i", False, 2),
             "outFIFODepth": ("i", False, 2),
+            "output_hook": ("s", False, ""),
             # HLS version to be used for IP synthesis
             "hls_version": ("s", False, "vitis_hls", {"vivado_hls", "vitis_hls"}),
         }
diff --git a/src/finn/transformation/fpgadataflow/insert_hook.py b/src/finn/transformation/fpgadataflow/insert_hook.py
new file mode 100644
index 000000000..35a1fb819
--- /dev/null
+++ b/src/finn/transformation/fpgadataflow/insert_hook.py
@@ -0,0 +1,91 @@
+import numpy as np
+from onnx import TensorProto
+from onnx import helper as oh
+
+from finn.custom_op.registry import getCustomOp
+from finn.transformation.base import Transformation
+from finn.util.fpgadataflow import is_fpgadataflow_node
+
+
+def _is_hook_node(node):
+    return node.op_type in ["checksum"]
+
+
+def _suitable_node(node):
+    if node is None:
+        return False
+    return is_fpgadataflow_node(node) and not _is_hook_node(node)
+
+
+class InsertHook(Transformation):
+    """Inserting hook layer after each layer that has the node attribute
+    'output_hook' specified"""
+
+    def __init__(self):
+        super().__init__()
+
+    def apply(self, model):
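+        # hook op_types this transformation knows how to insert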
+        list_supported_hooks = ["checksum"]
+        graph = model.graph
+        node_ind = -1
+        graph_modified = False
+        for n in graph.node:
+            node_ind += 1
+            if _suitable_node(n):
+                for output_name in n.output:
+                    consumers = model.find_consumers(output_name)
+                    if not consumers:
+                        continue
+                    assert len(consumers) == 1, (
+                        n.name
+                        + ": HLS node with fan-out higher than 1 cannot be stitched"
+                    )
+                    consumer = consumers[0]
+                    if _suitable_node(consumer):
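+                        # check whether the producer requests an output hook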
+                        n0 = getCustomOp(n)
+                        n0_hook = n0.get_nodeattr("output_hook")
+                        if n0_hook in list_supported_hooks:
+                            if n0_hook == "checksum":
+                                n0_normal_oshape = n0.get_normal_output_shape()
+                                n0_folded_oshape = n0.get_folded_output_shape()
+                                n0_odt = n0.get_output_datatype()
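+                                # one stream word carries PE items; a frame
+                                # spans all words of one folded output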
+                                items_per_word = n0.get_nodeattr("PE")
+                                words_per_frame = np.prod(n0_folded_oshape[:-1])
+                                chk_otensor = oh.make_tensor_value_info(
+                                    model.make_new_valueinfo_name(),
+                                    TensorProto.FLOAT,
+                                    n0_normal_oshape,
+                                )
+                                chk_result = oh.make_tensor_value_info(
+                                    model.make_new_valueinfo_name(),
+                                    TensorProto.FLOAT,
+                                    [1],
+                                )
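+                                # first output carries the data on to the
+                                # consumer, second output holds the checksum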
+                                chk_node = oh.make_node(
+                                    "checksum",
+                                    [output_name],
+                                    outputs=[chk_otensor.name, chk_result.name],
+                                    domain="finn.custom_op.fpgadataflow",
+                                    backend="fpgadataflow",
+                                    words_per_frame=words_per_frame,
+                                    items_per_word=items_per_word,
+                                    inputDataType=str(n0_odt.name),
+                                    folded_shape=n0_folded_oshape,
+                                )
+                            # insert the checksum node after the producer node
+                            graph.node.insert(node_ind + 1, chk_node)
+
+                            # rewire the consumer to read from the checksum
+                            # node's pass-through output instead
+                            for idx, inp in enumerate(consumer.input):
+                                if inp == output_name:
+                                    consumer.input[idx] = chk_otensor.name
+                            graph_modified = True
+
+        return (model, graph_modified)
-- 
GitLab