Skip to content
Snippets Groups Projects
Commit d6feac78 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Pad] add FMPadding_rtl to op registry, fixes to its CustomOp

parent e2cc28d3
No related branches found
No related tags found
No related merge requests found
......@@ -43,6 +43,7 @@ from finn.custom_op.fpgadataflow.downsampler import DownSampler
from finn.custom_op.fpgadataflow.duplicatestreams_batch import DuplicateStreams_Batch
from finn.custom_op.fpgadataflow.eltwise import StreamingEltwise
from finn.custom_op.fpgadataflow.fmpadding_batch import FMPadding_Batch
from finn.custom_op.fpgadataflow.fmpadding_rtl import FMPadding_rtl
from finn.custom_op.fpgadataflow.globalaccpool_batch import GlobalAccPool_Batch
from finn.custom_op.fpgadataflow.iodma import IODMA
from finn.custom_op.fpgadataflow.labelselect_batch import LabelSelect_Batch
......@@ -91,3 +92,4 @@ custom_op["Lookup"] = Lookup
custom_op["StreamingConcat"] = StreamingConcat
custom_op["CheckSum"] = CheckSum
custom_op["StreamingEltwise"] = StreamingEltwise
custom_op["FMPadding_rtl"] = FMPadding_rtl
......@@ -32,6 +32,7 @@ import os
import shutil
import warnings
from qonnx.core.datatype import DataType
from qonnx.util.basic import roundup_to_integer_multiple
from finn.custom_op.fpgadataflow.hlscustomop import HLSCustomOp
from finn.util.basic import get_rtlsim_trace_depth, make_build_dir
......@@ -226,7 +227,6 @@ class FMPadding_rtl(HLSCustomOp):
)
super().reset_rtlsim(sim)
super().toggle_clk(sim)
assert False, "Need register config here until default values are implemented"
rtlsim_output = self.rtlsim(sim, rtlsim_inp)
odt = export_idt
target_bits = odt.bitwidth()
......@@ -246,25 +246,60 @@ class FMPadding_rtl(HLSCustomOp):
), """Output shape doesn't match expected shape
(1, OutputDim_H, OutputDim_W, NumChannels)."""
def generate_hdl(self):
dimY, dimX = self.get_nodeattr("ImgDim")
padT, padL, padB, padR = self.get_nodeattr("Padding")
chans = self.get_nodeattr("NumChannels")
simd = self.get_nodeattr("SIMD")
idt = self.get_nodeattr("inputDataType")
def get_template_values(self, ifm_dims, pads, chans, simd, idt):
dimY, dimX = ifm_dims
padT, padL, padB, padR = pads
y_counter_bits = int(math.log2(padT + dimY + padB))
x_counter_bits = int(math.log2(padL + dimX + padR))
topname = self.get_verilog_top_module_name()
rtlsrc = os.environ["FINN_ROOT"] + "/finn-rtllib/fmpadding/hdl"
template_path = rtlsrc + "/fmpadding_template.sv"
stream_bits = idt.bitwidth() * simd
stream_bits = int(roundup_to_integer_multiple(stream_bits, 8))
code_gen_dict = {
"XCOUNTER_BITS": x_counter_bits,
"YCOUNTER_BITS": y_counter_bits,
"NUM_CHANNELS": chans,
"SIMD": simd,
"XCOUNTER_BITS": int(x_counter_bits),
"YCOUNTER_BITS": int(y_counter_bits),
"NUM_CHANNELS": int(chans),
"SIMD": int(simd),
"ELEM_BITS": idt.bitwidth(),
"TOP_MODULE_NAME": topname,
"INIT_XON": int(padL),
"INIT_XOFF": int(padL + dimX),
"INIT_XEND": int(padL + dimX + padR),
"INIT_YON": int(padT),
"INIT_YOFF": int(padT + dimY),
"INIT_YEND": int(padT + dimY + padB),
"STREAM_BITS": int(stream_bits),
}
return code_gen_dict
def get_dynamic_config(self, ifm_dims, pads):
"""Returns a configuration dict to re-configure FM dimension and
padding amounts during runtime."""
dims = self.get_nodeattr("ImgDim")
pads = self.get_nodeattr("Padding")
chans = self.get_nodeattr("NumChannels")
simd = self.get_nodeattr("SIMD")
idt = self.get_input_datatype()
code_gen_dict = self.get_template_values(dims, pads, chans, simd, idt)
config = {
"XON": (0, (code_gen_dict["INIT_XON"])),
"XOFF": (1, (code_gen_dict["INIT_XOFF"])),
"XEND": (2, (code_gen_dict["INIT_XEND"])),
"YON": (4, (code_gen_dict["INIT_YON"])),
"YOFF": (5, (code_gen_dict["INIT_YOFF"])),
"YEND": (6, (code_gen_dict["INIT_YEND"])),
}
return config
def generate_hdl(self):
rtlsrc = os.environ["FINN_ROOT"] + "/finn-rtllib/fmpadding/hdl"
template_path = rtlsrc + "/fmpadding_template.sv"
dims = self.get_nodeattr("ImgDim")
pads = self.get_nodeattr("Padding")
chans = self.get_nodeattr("NumChannels")
simd = self.get_nodeattr("SIMD")
idt = self.get_input_datatype()
code_gen_dict = self.get_template_values(dims, pads, chans, simd, idt)
# save top module name so we can refer to it after this node has been renamed
# (e.g. by GiveUniqueNodeNames(prefix) during MakeZynqProject)
self.set_nodeattr("gen_top_module", self.get_verilog_top_module_name())
......@@ -275,18 +310,17 @@ class FMPadding_rtl(HLSCustomOp):
template = f.read()
for key_name in code_gen_dict:
key = "$%s$" % key_name
# transform list into long string separated by '\n'
code_gen_line = "\n".join(code_gen_dict[key])
template = template.replace(key, code_gen_line)
template = template.replace(key, str(code_gen_dict[key_name]))
with open(
os.path.join(code_gen_dir, topname + ".sv"),
os.path.join(code_gen_dir, self.get_verilog_top_module_name() + ".sv"),
"w",
) as f:
f.write(template)
shutil.copyfile(rtlsrc + "/fmpadding_axi.sv", code_gen_dir)
shutil.copyfile(rtlsrc + "/fmpadding.sv", code_gen_dir)
sv_files = ["fmpadding_axi.sv", "fmpadding.sv", "axi2we.sv"]
for sv_file in sv_files:
shutil.copy(rtlsrc + "/" + sv_file, code_gen_dir)
# set ipgen_path and ip_path so that HLS-Synth transformation
# and stich_ip transformation do not complain
self.set_nodeattr("ipgen_path", code_gen_dir)
......@@ -306,6 +340,7 @@ class FMPadding_rtl(HLSCustomOp):
verilog_files = [
"fmpadding_axi.sv",
"fmpadding.sv",
"axi2we.sv",
self.get_nodeattr("gen_top_module") + ".sv",
]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment