Skip to content
Snippets Groups Projects
Commit 318d2b3d authored by auphelia's avatar auphelia
Browse files

[Transformation] Add insert hook transformation

parent f491f0a1
No related branches found
No related tags found
No related merge requests found
......@@ -206,14 +206,13 @@ class checksum(HLSCustomOp):
self.code_gen_dict["$GLOBALS$"] = ['#include "checksum.hpp"']
def defines(self, var):
    """Generate the C preprocessor #define directives for the checksum HLS
    kernel and store them under code_gen_dict["$DEFINES$"].

    Emits exactly one define each for WORDS_PER_FRAME, ITEMS_PER_WORD and
    WORD_SIZE, taken from the node attributes / input stream width.
    (The scraped diff contained both the pre- and post-refactor append
    lines, which would have emitted every define twice.)

    Parameters
    ----------
    var : unused here; part of the HLSCustomOp codegen interface.
    """
    items_per_word = self.get_nodeattr("items_per_word")
    words_per_frame = self.get_nodeattr("words_per_frame")
    word_size = self.get_instream_width()
    my_defines = []
    my_defines.append("#define WORDS_PER_FRAME {}".format(words_per_frame))
    my_defines.append("#define ITEMS_PER_WORD {}".format(items_per_word))
    my_defines.append("#define WORD_SIZE {}".format(word_size))
    self.code_gen_dict["$DEFINES$"] = my_defines
def read_npy_data(self):
......
......@@ -112,6 +112,7 @@ class HLSCustomOp(CustomOp):
# input and output FIFO depths
"inFIFODepth": ("i", False, 2),
"outFIFODepth": ("i", False, 2),
"output_hook": ("s", False, ""),
# HLS version to be used for IP synthesis
"hls_version": ("s", False, "vitis_hls", {"vivado_hls", "vitis_hls"}),
}
......
import numpy as np
from onnx import TensorProto
from onnx import helper as oh
from finn.custom_op.registry import getCustomOp
from finn.transformation.base import Transformation
from finn.util.fpgadataflow import is_fpgadataflow_node
def _is_hook_node(node):
if node.op_type in ["checksum"]:
return True
else:
return False
def _suitable_node(node):
    """Return True when *node* is a non-None fpgadataflow node that is not
    itself a hook node; False in every other case."""
    # guard clauses replace the original nested if/else pyramid
    if node is None:
        return False
    if is_fpgadataflow_node(node) is not True:
        return False
    return _is_hook_node(node) is False
class InsertHook(Transformation):
    """Insert a hook layer after each fpgadataflow layer that has the node
    attribute 'output_hook' set to a supported hook type (currently only
    "checksum").

    The hook node is spliced between the producer and its single consumer:
    it takes the producer's output tensor as input, passes the data through
    on its first output (which the consumer is rewired to), and emits the
    hook result (a scalar checksum) on its second output."""

    def __init__(self):
        super().__init__()

    def apply(self, model):
        # op_types this transformation knows how to insert as hooks
        list_supported_hooks = ["checksum"]
        graph = model.graph
        node_ind = -1
        # NOTE(review): graph_modified is never set to True, so the enclosing
        # transform loop will treat this as a single-pass transformation even
        # when nodes were inserted -- confirm this is intended.
        graph_modified = False
        for n in graph.node:
            node_ind += 1
            if _suitable_node(n):
                for output_name in n.output:
                    consumers = model.find_consumers(output_name)
                    # no consumer (e.g. feeds a graph output): no hook inserted
                    if consumers == []:
                        continue
                    assert len(consumers) == 1, (
                        n.name
                        + ": HLS node with fan-out higher than 1 cannot be stitched"
                    )
                    consumer = consumers[0]
                    # only hook between two suitable (fpgadataflow, non-hook) nodes
                    if _suitable_node(consumer) is True:
                        n0 = getCustomOp(n)
                        n0_hook = n0.get_nodeattr("output_hook")
                        if n0_hook in list_supported_hooks:
                            if n0_hook == "checksum":
                                # shape/datatype of the producer's output,
                                # used to parameterize the checksum node
                                n0_normal_oshape = n0.get_normal_output_shape()
                                n0_folded_oshape = n0.get_folded_output_shape()
                                n0_odt = n0.get_output_datatype()
                                # one folded word carries PE items; the frame is
                                # the product of all folded dims except the last
                                items_per_word = n0.get_nodeattr("PE")
                                words_per_frame = np.prod(n0_folded_oshape[:-1])
                                # pass-through data output of the checksum node
                                chk_otensor = oh.make_tensor_value_info(
                                    model.make_new_valueinfo_name(),
                                    TensorProto.FLOAT,
                                    n0_normal_oshape,
                                )
                                # scalar checksum result output
                                chk_result = oh.make_tensor_value_info(
                                    model.make_new_valueinfo_name(),
                                    TensorProto.FLOAT,
                                    [1],
                                )
                                chk_node = oh.make_node(
                                    "checksum",
                                    [output_name],
                                    outputs=[chk_otensor.name, chk_result.name],
                                    domain="finn.custom_op.fpgadataflow",
                                    backend="fpgadataflow",
                                    words_per_frame=words_per_frame,
                                    items_per_word=items_per_word,
                                    inputDataType=str(n0_odt.name),
                                    folded_shape=n0_folded_oshape,
                                )
                                # insert checksum node directly after producer
                                # (comment said "dwc" -- copy-paste leftover)
                                graph.node.insert(node_ind + 1, chk_node)
                                # rewire consumer to read the checksum node's
                                # pass-through output instead of the producer
                                for idx, inp in enumerate(consumer.input):
                                    if inp == output_name:
                                        consumer.input[idx] = chk_otensor.name

        return (model, graph_modified)
0% Loading — or an error occurred.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment