diff --git a/src/finn/custom_op/fpgadataflow/checksum.py b/src/finn/custom_op/fpgadataflow/checksum.py
index 24b831cac83fede2c4061c8dda7f87c2c7cbca44..dffc90a9e3c5f71b59362433ce6f914131d9d434 100644
--- a/src/finn/custom_op/fpgadataflow/checksum.py
+++ b/src/finn/custom_op/fpgadataflow/checksum.py
@@ -206,14 +206,13 @@ class checksum(HLSCustomOp):
         self.code_gen_dict["$GLOBALS$"] = ['#include "checksum.hpp"']
 
     def defines(self, var):
+        items_per_word = self.get_nodeattr("items_per_word")
+        words_per_frame = self.get_nodeattr("words_per_frame")
+        word_size = self.get_instream_width()
         my_defines = []
-        my_defines.append(
-            "#define WORDS_PER_FRAME {}".format(self.get_nodeattr("words_per_frame"))
-        )
-        my_defines.append(
-            "#define ITEMS_PER_WORD {}".format(self.get_nodeattr("items_per_word"))
-        )
-        my_defines.append("#define WORD_SIZE {}".format(self.get_instream_width()))
+        my_defines.append("#define WORDS_PER_FRAME {}".format(words_per_frame))
+        my_defines.append("#define ITEMS_PER_WORD {}".format(items_per_word))
+        my_defines.append("#define WORD_SIZE {}".format(word_size))
         self.code_gen_dict["$DEFINES$"] = my_defines
 
     def read_npy_data(self):
diff --git a/src/finn/custom_op/fpgadataflow/hlscustomop.py b/src/finn/custom_op/fpgadataflow/hlscustomop.py
index 5e6fb1124ebd98b22d9d56908c59be7782f18aa4..ed12d2f1af7fcbba019d8896bedfb67ef03847b0 100644
--- a/src/finn/custom_op/fpgadataflow/hlscustomop.py
+++ b/src/finn/custom_op/fpgadataflow/hlscustomop.py
@@ -112,6 +112,7 @@ class HLSCustomOp(CustomOp):
             # input and output FIFO depths
             "inFIFODepth": ("i", False, 2),
             "outFIFODepth": ("i", False, 2),
+            "output_hook": ("s", False, ""),
             # HLS version to be used for IP synthesis
             "hls_version": ("s", False, "vitis_hls", {"vivado_hls", "vitis_hls"}),
         }
diff --git a/src/finn/transformation/fpgadataflow/insert_hook.py b/src/finn/transformation/fpgadataflow/insert_hook.py
new file mode 100644
index 0000000000000000000000000000000000000000..35a1fb81996efae34118e005543ce2d8d8514166
--- /dev/null
+++ b/src/finn/transformation/fpgadataflow/insert_hook.py
@@ -0,0 +1,98 @@
+import numpy as np
+from onnx import TensorProto
+from onnx import helper as oh
+
+from finn.custom_op.registry import getCustomOp
+from finn.transformation.base import Transformation
+from finn.util.fpgadataflow import is_fpgadataflow_node
+
+
+def _is_hook_node(node):
+    # hook nodes themselves must never receive another hook
+    if node.op_type in ["checksum"]:
+        return True
+    else:
+        return False
+
+
+def _suitable_node(node):
+    if node is not None:
+        if is_fpgadataflow_node(node) is True:
+            if _is_hook_node(node) is False:
+                return True
+            else:
+                return False
+        else:
+            return False
+    else:
+        return False
+
+
+class InsertHook(Transformation):
+    """Inserting hook layer after each layer that has the node attribute
+    'output_hook' specified"""
+
+    def __init__(self):
+        super().__init__()
+
+    def apply(self, model):
+        list_supported_hooks = ["checksum"]
+        graph = model.graph
+        node_ind = -1
+        graph_modified = False
+        for n in graph.node:
+            node_ind += 1
+            if _suitable_node(n):
+                for output_name in n.output:
+                    consumers = model.find_consumers(output_name)
+                    # find_consumers yields None (older finn-base) or []
+                    # when the tensor has no consumers; skip either way
+                    if consumers is None or consumers == []:
+                        continue
+                    assert len(consumers) == 1, (
+                        n.name
+                        + ": HLS node with fan-out higher than 1 cannot be stitched"
+                    )
+                    consumer = consumers[0]
+                    if _suitable_node(consumer) is True:
+                        n0 = getCustomOp(n)
+                        n0_hook = n0.get_nodeattr("output_hook")
+                        if n0_hook in list_supported_hooks:
+                            if n0_hook == "checksum":
+                                n0_normal_oshape = n0.get_normal_output_shape()
+                                n0_folded_oshape = n0.get_folded_output_shape()
+                                n0_odt = n0.get_output_datatype()
+                                items_per_word = n0.get_nodeattr("PE")
+                                words_per_frame = np.prod(n0_folded_oshape[:-1])
+                                chk_otensor = oh.make_tensor_value_info(
+                                    model.make_new_valueinfo_name(),
+                                    TensorProto.FLOAT,
+                                    n0_normal_oshape,
+                                )
+                                chk_result = oh.make_tensor_value_info(
+                                    model.make_new_valueinfo_name(),
+                                    TensorProto.FLOAT,
+                                    [1],
+                                )
+                                chk_node = oh.make_node(
+                                    "checksum",
+                                    [output_name],
+                                    outputs=[chk_otensor.name, chk_result.name],
+                                    domain="finn.custom_op.fpgadataflow",
+                                    backend="fpgadataflow",
+                                    words_per_frame=words_per_frame,
+                                    items_per_word=items_per_word,
+                                    inputDataType=str(n0_odt.name),
+                                    folded_shape=n0_folded_oshape,
+                                )
+                                # insert checksum node right after its producer
+                                graph.node.insert(node_ind + 1, chk_node)
+
+                                # rewire consumer to read the checksum node's
+                                # pass-through data output
+                                for idx, inp in enumerate(consumer.input):
+                                    if inp == output_name:
+                                        consumer.input[idx] = chk_otensor.name
+                                graph_modified = True
+
+        return (model, graph_modified)