diff --git a/finn-rtllib/swg/swg_template_default.sv b/finn-rtllib/swg/swg_template_default.sv new file mode 100644 index 0000000000000000000000000000000000000000..97517438a0c261e4488b74a677a352f9dc51743b --- /dev/null +++ b/finn-rtllib/swg/swg_template_default.sv @@ -0,0 +1,351 @@ +/****************************************************************************** + * Copyright (C) 2022, Advanced Micro Devices, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR + * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF + * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + *****************************************************************************/ +module $TOP_MODULE_NAME$_controller #( + int unsigned LOOP_H_ITERATIONS = $LOOP_H_ITERATIONS$, + int unsigned LOOP_W_ITERATIONS = $LOOP_W_ITERATIONS$, + int unsigned LOOP_KH_ITERATIONS = $LOOP_KH_ITERATIONS$, + int unsigned LOOP_KW_ITERATIONS = $LOOP_KW_ITERATIONS$, + int unsigned LOOP_SIMD_ITERATIONS = $LOOP_SIMD_ITERATIONS$, + + int unsigned INCR_BITWIDTH = $INCR_BITWIDTH$, + bit [INCR_BITWIDTH-1:0] ADDR_INCREMENT_MAP[6] = $ADDR_INCREMENT_MAP$, + + bit IS_DEPTHWISE = $IS_DEPTHWISE$ +)( + input logic clk, + input logic rst_n, + + input logic advance, + output logic [INCR_BITWIDTH-1:0] addr_incr, + output logic [INCR_BITWIDTH-1:0] tail_incr +); + + // state and counters + typedef enum logic [2:0] { + STATE_START, + STATE_LOOP_SIMD, + STATE_LOOP_KW, + STATE_LOOP_KH, + STATE_LOOP_W, + STATE_LOOP_H + } state_e; + state_e State = $INNERMOST_STATE$; + state_e state_next; + + logic signed [$clog2(LOOP_H_ITERATIONS +2)+1-1:0] Counter_loop_h = LOOP_H_ITERATIONS-1; + logic signed [$clog2(LOOP_W_ITERATIONS +2)+1-1:0] Counter_loop_w = LOOP_W_ITERATIONS-1; + logic signed [$clog2(LOOP_KH_ITERATIONS +2)+1-1:0] Counter_loop_kh = LOOP_KH_ITERATIONS-1; + logic signed [$clog2(LOOP_KW_ITERATIONS +2)+1-1:0] Counter_loop_kw = LOOP_KW_ITERATIONS-1; + logic signed [$clog2(LOOP_SIMD_ITERATIONS+2)+1-1:0] Counter_loop_simd = LOOP_SIMD_ITERATIONS-1; + + assign addr_incr = ADDR_INCREMENT_MAP[State]; + + // combinational logic for tail_incr generation + uwire tail_incr_inner_condition = IS_DEPTHWISE? (Counter_loop_kh >= 0) : 0; + always_comb begin : blkTail + if (tail_incr_inner_condition) + tail_incr = 1; + else if (Counter_loop_w >= 0) + tail_incr = $TAIL_INCR_W$; + else if (Counter_loop_h >= 0) + tail_incr = $TAIL_INCR_H$; + else + tail_incr = $TAIL_INCR_LAST$; + end + + // combinational next state logic + always_comb begin : blkState + state_next = State; + if(State != $INNERMOST_STATE$) state_next = $INNERMOST_STATE$; + else begin + if(Counter_loop_simd < 0) begin + state_next = + (Counter_loop_kw >= 0)? STATE_LOOP_KW : + (Counter_loop_kh >= 0)? STATE_LOOP_KH : + (Counter_loop_w >= 0)? STATE_LOOP_W : + (Counter_loop_h >= 0)? STATE_LOOP_H : + /* else */ STATE_START; + end + end + end : blkState + + // sequential logic + always_ff @ (posedge clk) begin + if(!rst_n) begin + State <= $INNERMOST_STATE$; + Counter_loop_h <= LOOP_H_ITERATIONS-1; + Counter_loop_w <= LOOP_W_ITERATIONS-1; + Counter_loop_kh <= LOOP_KH_ITERATIONS-1; + Counter_loop_kw <= LOOP_KW_ITERATIONS-1; + Counter_loop_simd <= LOOP_SIMD_ITERATIONS-1; + end + else if(advance) begin + State <= state_next; + if (State == $INNERMOST_STATE$) begin + if(Counter_loop_simd >= 0) Counter_loop_simd <= Counter_loop_simd-1; + else begin + Counter_loop_simd <= LOOP_SIMD_ITERATIONS-1; + if(Counter_loop_kw >= 0) Counter_loop_kw <= Counter_loop_kw-1; + else begin + Counter_loop_kw <= LOOP_KW_ITERATIONS-1; + if(Counter_loop_kh >= 0) Counter_loop_kh <= Counter_loop_kh-1; + else begin + Counter_loop_kh <= LOOP_KH_ITERATIONS-1; + if(Counter_loop_w >= 0) Counter_loop_w <= Counter_loop_w-1; + else begin + Counter_loop_w <= LOOP_W_ITERATIONS-1; + if(Counter_loop_h >= 0) Counter_loop_h <= Counter_loop_h-1; + else Counter_loop_h <= LOOP_H_ITERATIONS-1; + end + end + end + end + end + end + end + +endmodule : $TOP_MODULE_NAME$_controller + +module $TOP_MODULE_NAME$_cyclic_buffer_addressable #( + int unsigned WIDTH, + int unsigned DEPTH +)( + input logic clk, + input logic rst_n, + + input logic write_enable, + input logic [$clog2(DEPTH)-1:0] write_addr, + input logic [WIDTH-1:0] data_in, + + input logic read_enable, + input logic [$clog2(DEPTH)-1:0] read_addr, // absolute (!) read address of cyclic buffer + output logic [WIDTH-1:0] data_out +); + + $RAM_STYLE$ logic [WIDTH-1:0] Ram[DEPTH]; + logic [WIDTH-1:0] Out = 'x; + always_ff @(posedge clk) begin + if (read_enable) Out <= Ram[read_addr]; + if (write_enable) Ram[write_addr] <= data_in; + end + assign data_out = Out; + +endmodule : $TOP_MODULE_NAME$_cyclic_buffer_addressable + +module $TOP_MODULE_NAME$_impl #( + int BIT_WIDTH, + int SIMD, + int MMV_IN, + int MMV_OUT, + int LAST_READ_ELEM = $LAST_READ_ELEM$, + int LAST_WRITE_ELEM = $LAST_WRITE_ELEM$, + int BUF_ELEM_TOTAL = $BUF_ELEM_TOTAL$, + int ELEM_PER_WINDOW = $ELEM_PER_WINDOW$, + int INCR_BITWIDTH = $INCR_BITWIDTH$ +)( + input logic ap_clk, + input logic ap_rst_n, + + input logic in0_V_V_TVALID, + output logic in0_V_V_TREADY, + input logic [BIT_WIDTH * SIMD * MMV_IN-1:0] in0_V_V_TDATA, + + output logic out_V_V_TVALID, + input logic out_V_V_TREADY, + output logic [BIT_WIDTH * SIMD * MMV_OUT-1:0] out_V_V_TDATA +); + // derived Constants + localparam int unsigned BUF_IN_WIDTH = BIT_WIDTH * SIMD * MMV_IN; + localparam int unsigned BUF_OUT_ELEM_WIDTH = BIT_WIDTH * SIMD; + localparam int unsigned BUF_OUT_WIDTH = BIT_WIDTH * SIMD * MMV_OUT; + + // main buffer instantiation + uwire [BUF_IN_WIDTH -1:0] window_buffer_in; + uwire [BUF_OUT_WIDTH-1:0] window_buffer_out; + uwire window_buffer_write_enable; + uwire window_buffer_read_enable; + uwire [$clog2(BUF_ELEM_TOTAL)-1:0] window_buffer_write_addr; + uwire [$clog2(BUF_ELEM_TOTAL)-1:0] window_buffer_read_addr; + $TOP_MODULE_NAME$_cyclic_buffer_addressable #( + .WIDTH(BUF_IN_WIDTH), + .DEPTH(BUF_ELEM_TOTAL) + ) window_buffer_inst ( + .clk(ap_clk), + .rst_n(ap_rst_n), + + .write_enable(window_buffer_write_enable), + .write_addr(window_buffer_write_addr), + .data_in(window_buffer_in), + + .read_enable(window_buffer_read_enable), + .read_addr(window_buffer_read_addr), + .data_out(window_buffer_out) + ); + + //controller instantiation + uwire advance_controller; + uwire signed [INCR_BITWIDTH-1:0] addr_incr; + uwire [INCR_BITWIDTH-1:0] tail_incr; + $TOP_MODULE_NAME$_controller controller_inst ( + .clk(ap_clk), + .rst_n(ap_rst_n), + .advance(advance_controller), + .addr_incr(addr_incr), + .tail_incr(tail_incr) + ); + + // Counters/address registers + // Add a sign bit even to (most) unsigned counters and Window_buffer_read_addr_reg, + // so we can use automatic sign extension and simplify calculations w/ signed increment. + // Alternatively, we could manually sign-extend and shave off a bit here or there. + logic signed [$clog2(LAST_READ_ELEM+1)+1-1:0] Newest_buffered_elem = -1; + logic [$clog2(LAST_READ_ELEM+1)+1-1:0] Current_elem = 0; + logic [$clog2(LAST_READ_ELEM+1)+1-1:0] First_elem_next_window = 0; + logic [$clog2(ELEM_PER_WINDOW) -1:0] Position_in_window = 0; + logic [$clog2(BUF_ELEM_TOTAL)+1 -1:0] Window_buffer_read_addr_reg = 0; + logic [$clog2(BUF_ELEM_TOTAL)-1:0] Window_buffer_write_addr_reg = 0; + + // Control signals/registers + uwire read_cmd = + !reading_done && ( // if there is still an input element left to read + Fetching_done || ( // if fetching is done (e.g. for skipped rows at FM end due to stride) + $signed(((Newest_buffered_elem - (BUF_ELEM_TOTAL - 1)))) < $signed(First_elem_next_window) && + $signed(((Newest_buffered_elem - (BUF_ELEM_TOTAL - 1)))) < $signed(Current_elem) + ) // (over-)write to buffer if oldest buffered element will no longer be needed + ); + uwire read_ok = read_cmd && in0_V_V_TVALID; + uwire reading_done = Newest_buffered_elem == LAST_READ_ELEM; + + uwire fetch_cmd = !($signed(Current_elem) > Newest_buffered_elem) && !write_blocked && !Fetching_done; + logic Fetching_done = 0; + + logic Write_cmd = 0; + logic Writing_done = 0; + uwire write_ok = Write_cmd && out_V_V_TREADY; + uwire write_blocked = Write_cmd && !out_V_V_TREADY;; + + //assign buffer control + assign window_buffer_write_addr = Window_buffer_write_addr_reg; + assign window_buffer_read_addr = Window_buffer_read_addr_reg; + assign window_buffer_write_enable = read_ok; + assign window_buffer_read_enable = fetch_cmd; + assign advance_controller = fetch_cmd; + + //assign I/O ports + assign window_buffer_in = in0_V_V_TDATA; + assign out_V_V_TDATA = window_buffer_out; + assign in0_V_V_TREADY = ap_rst_n && read_ok; //only asserted if data is available and we can store it (allowed) + assign out_V_V_TVALID = ap_rst_n && Write_cmd; //only asserted if we have data available and it has not been read yet (don't wait for READY from sink) + + //main process for advancing counters + always_ff @(posedge ap_clk) begin + if(!ap_rst_n) begin + Newest_buffered_elem <= -1; + Current_elem <= 0; + First_elem_next_window <= 0; + Position_in_window <= 0; + Window_buffer_read_addr_reg <= 0; + Window_buffer_write_addr_reg <= 0; + Fetching_done <= 0; + Write_cmd <= 0; + Writing_done <= 0; + end + else begin + if (read_ok) begin + Window_buffer_write_addr_reg <= (Window_buffer_write_addr_reg == BUF_ELEM_TOTAL-1)? 0 : Window_buffer_write_addr_reg + 1; + Newest_buffered_elem <= Newest_buffered_elem+1; + + if (Newest_buffered_elem == LAST_READ_ELEM-1) begin + Window_buffer_write_addr_reg <= 0; + end + //check if this is the last read cycle (reading_done will be true afterwards) + if ((Newest_buffered_elem == LAST_READ_ELEM-1) && Writing_done) begin + //start processing of next FM if writing is done already (possible due to unused input elements at the tail end) + //todo: allow for read overlapping between feature maps (i.e., reading first elements from next FM while still writing last window of current FM) + Newest_buffered_elem <= -1; + Current_elem <= 0; + Window_buffer_read_addr_reg <= 0; + First_elem_next_window <= 0; + Writing_done <= 0; + Fetching_done <= 0; + end + end + + if (fetch_cmd) begin + //count up to track which element index is about to be read from the buffer, and where it is located within the buffer + //use increment value calculated by controller + + // absolute buffer address wrap-around + automatic logic signed [$clog2(BUF_ELEM_TOTAL)+1:0] ra = $signed(Window_buffer_read_addr_reg) + $signed(addr_incr); + automatic logic signed [$clog2(BUF_ELEM_TOTAL+1):0] ra_correct = + (ra >= BUF_ELEM_TOTAL)? -BUF_ELEM_TOTAL : + (ra < 0)? BUF_ELEM_TOTAL : 0; + Window_buffer_read_addr_reg <= ra + ra_correct; + + //keep track where we are within a window + Position_in_window <= (Position_in_window != ELEM_PER_WINDOW - 1)? Position_in_window+1 : 0; + + //update first element of next window to allow buffer overwrite up until that point + if (Position_in_window == 0) + First_elem_next_window <= First_elem_next_window + tail_incr; + + //check if this is the last write cycle (Writing_done will be true afterwards) + if (Current_elem == LAST_WRITE_ELEM) + Fetching_done <= 1; + else + Current_elem <= $signed(Current_elem) + addr_incr; + + // determine if prefetched data will be outstanding in the next cycle + // if we fetch in this cycle -> yes + // if we do not fetch nor write -> do not change + // if we do not fetch but write successfully-> clear outstanding data + Write_cmd <= fetch_cmd; + end + + if (write_ok) + Write_cmd <= fetch_cmd; + + if (write_ok && Fetching_done) begin + //check if this is the last write cycle (Writing_done will be true afterwards) + if (reading_done || (read_ok && (Newest_buffered_elem == LAST_READ_ELEM - 1))) begin + //start processing of next FM if reading is done already, or completes in the same cycle + Newest_buffered_elem <= -1; + Current_elem <= 0; + Window_buffer_read_addr_reg <= 0; + First_elem_next_window <= 0; + Fetching_done <= 0; + end else + Writing_done <= 1; + end + end + end + +endmodule : $TOP_MODULE_NAME$_impl diff --git a/finn-rtllib/swg/swg_template_wrapper.v b/finn-rtllib/swg/swg_template_wrapper.v new file mode 100644 index 0000000000000000000000000000000000000000..0cc3579a255fddaf1a470d440b9e8ac245abe486 --- /dev/null +++ b/finn-rtllib/swg/swg_template_wrapper.v @@ -0,0 +1,75 @@ +/****************************************************************************** + * Copyright (C) 2022, Advanced Micro Devices, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR + * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF + * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + *****************************************************************************/ +`timescale 1 ns / 1 ps + +module $TOP_MODULE_NAME$ ( +(* X_INTERFACE_PARAMETER = "ASSOCIATED_BUSIF in0_V:out_V" *) +input ap_clk, +(* X_INTERFACE_PARAMETER = "ASSOCIATED_BUSIF in0_V:out_V" *) +input ap_rst_n, +input [BUF_IN_WIDTH-1:0] in0_V_TDATA, +input in0_V_TVALID, +output in0_V_TREADY, +output [BUF_OUT_WIDTH-1:0] out_V_TDATA, +output out_V_TVALID, +input out_V_TREADY +); + +// top-level parameters (set via code-generation) +parameter BIT_WIDTH = $BIT_WIDTH$; +parameter SIMD = $SIMD$; +parameter MMV_IN = $MMV_IN$; +parameter MMV_OUT = $MMV_OUT$; + +// derived constants +parameter BUF_IN_WIDTH = BIT_WIDTH * SIMD * MMV_IN; +parameter BUF_OUT_WIDTH = BIT_WIDTH * SIMD * MMV_OUT; + +$TOP_MODULE_NAME$_impl +#( + .BIT_WIDTH(BIT_WIDTH), + .SIMD(SIMD), + .MMV_IN(MMV_IN), + .MMV_OUT(MMV_OUT) +) +impl +( + .ap_clk(ap_clk), + .ap_rst_n(ap_rst_n), + .in0_V_V_TDATA(in0_V_TDATA), + .in0_V_V_TVALID(in0_V_TVALID), + .in0_V_V_TREADY(in0_V_TREADY), + .out_V_V_TDATA(out_V_TDATA), + .out_V_V_TVALID(out_V_TVALID), + .out_V_V_TREADY(out_V_TREADY) +); + +endmodule //TOP_MODULE_NAME diff --git a/src/finn/builder/build_dataflow_config.py b/src/finn/builder/build_dataflow_config.py index 92263bd82ce291833c6868847876ac7e3b68e6f8..e16711f63b954707bc7ad9050dd7627ca1ce99c1 100644 --- a/src/finn/builder/build_dataflow_config.py +++ b/src/finn/builder/build_dataflow_config.py @@ -258,6 +258,10 @@ class DataflowBuildConfig: #: Which memory mode will be used for compute layers default_mem_mode: Optional[ComputeEngineMemMode] = ComputeEngineMemMode.DECOUPLED + #: Force inference of RTL ConvolutionInputGenerator over HLS implementation + #: If set to False, falls back to the default behavior of InferConvInpGen() + force_rtl_conv_inp_gen: Optional[bool] = False + #: Which Vitis platform will be used. #: Only relevant when `shell_flow_type = ShellFlowType.VITIS_ALVEO` #: e.g. "xilinx_u250_xdma_201830_2" diff --git a/src/finn/builder/build_dataflow_steps.py b/src/finn/builder/build_dataflow_steps.py index 59f77650da5c3c3f9db0ea65e2288544b376bec3..e77f17d7c27f4be08aa6725e5803a1ea566c9443 100644 --- a/src/finn/builder/build_dataflow_steps.py +++ b/src/finn/builder/build_dataflow_steps.py @@ -302,7 +302,10 @@ def step_convert_to_hls(model: ModelWrapper, cfg: DataflowBuildConfig): # needed for convolutions -- TODO always exec? need_conv = len(model.get_nodes_by_op_type("Im2Col")) > 0 if need_conv: - model = model.transform(to_hls.InferConvInpGen()) + if cfg.force_rtl_conv_inp_gen: + model = model.transform(to_hls.InferConvInpGen(use_rtl_variant=True)) + else: + model = model.transform(to_hls.InferConvInpGen()) model = model.transform(to_hls.InferStreamingMaxPool()) model = model.transform(RemoveCNVtoFCFlatten()) # get rid of Tranpose -> Tranpose identity seq diff --git a/src/finn/custom_op/fpgadataflow/__init__.py b/src/finn/custom_op/fpgadataflow/__init__.py index 2c7c86c64ea1279cb18cf8342aa20fb2792bdaf5..49577fbf1b5774e33b63674242aed69c1d12a53e 100644 --- a/src/finn/custom_op/fpgadataflow/__init__.py +++ b/src/finn/custom_op/fpgadataflow/__init__.py @@ -36,6 +36,9 @@ from finn.custom_op.fpgadataflow.convolutioninputgenerator import ( from finn.custom_op.fpgadataflow.convolutioninputgenerator1d import ( ConvolutionInputGenerator1D, ) +from finn.custom_op.fpgadataflow.convolutioninputgenerator_rtl import ( + ConvolutionInputGenerator_rtl, +) from finn.custom_op.fpgadataflow.downsampler import DownSampler from finn.custom_op.fpgadataflow.duplicatestreams_batch import DuplicateStreams_Batch from finn.custom_op.fpgadataflow.fmpadding_batch import FMPadding_Batch @@ -67,6 +70,7 @@ custom_op["StreamingMaxPool_Batch"] = StreamingMaxPool_Batch custom_op["MatrixVectorActivation"] = MatrixVectorActivation custom_op["ConvolutionInputGenerator"] = ConvolutionInputGenerator custom_op["ConvolutionInputGenerator1D"] = ConvolutionInputGenerator1D +custom_op["ConvolutionInputGenerator_rtl"] = ConvolutionInputGenerator_rtl custom_op["TLastMarker"] = TLastMarker custom_op["StreamingDataWidthConverter_Batch"] = StreamingDataWidthConverter_Batch custom_op["StreamingFIFO"] = StreamingFIFO diff --git a/src/finn/custom_op/fpgadataflow/convolutioninputgenerator_rtl.py b/src/finn/custom_op/fpgadataflow/convolutioninputgenerator_rtl.py new file mode 100755 index 0000000000000000000000000000000000000000..399b36e15021af6f449df3e9ba2acdc699a27647 --- /dev/null +++ b/src/finn/custom_op/fpgadataflow/convolutioninputgenerator_rtl.py @@ -0,0 +1,834 @@ +# Copyright (C) 2022, Advanced Micro Devices, Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of FINN nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import math +import numpy as np +import os +from math import copysign +from qonnx.core.datatype import DataType +from qonnx.custom_op.general import im2col +from qonnx.custom_op.general.im2col import compute_conv_output_dim + +from finn.custom_op.fpgadataflow.hlscustomop import HLSCustomOp +from finn.util.basic import get_rtlsim_trace_depth, make_build_dir +from finn.util.data_packing import npy_to_rtlsim_input, rtlsim_output_to_npy + +try: + from pyverilator import PyVerilator +except ModuleNotFoundError: + PyVerilator = None + +# RTL Convolution Input Generator / Sliding Window Generator (SWG) +# Matches and extends the functionality of all ConvolutionInputGenerator_* functions +# in finn-hlslib by generating HDL code for two different implementation styles: +# - Addressable cyclic buffer: to be used when out_width <= in_width +# - Parallel registers + line buffers: to be used when out_width > in_width +# Supports non-square, 1D, strided, dilated, and depthwise convolutions. +# Note: the actual data layout produced is different for depthwise and non-depthwise: +# * non-depthwise SWG: (1, OFMDim_H, OFMDim_W, K_H, K_W, IFMChannels/SIMD, SIMD) +# * depthwise SWG: (1, OFMDim_H, OFMDim_W, IFMChannels/SIMD, K_H, K_W, SIMD) + +# NOTE: "Parallel" implementation style not yet implemented in this version! + + +class ConvolutionInputGenerator_rtl(HLSCustomOp): + """Class that does not correspond to one of the finn-hlslib ConvolutionInputGenerator + (sliding window) function variants. Generates an RTL ConvolutionInputGenerator + implementation based on (System-)Verilog templates, defined in finn-rtllib/swg.""" + + def __init__(self, onnx_node): + super().__init__(onnx_node) + + def get_nodeattr_types(self): + my_attrs = { + "ConvKernelDim": ("ints", True, []), # [H, W] = [Y, X] + "IFMChannels": ("i", True, 0), + "IFMDim": ("ints", True, []), # [H, W] = [Y, X] + "OFMDim": ("ints", True, []), # [H, W] = [Y, X] + "SIMD": ("i", True, 0), + # additional parallelization parameter - not yet implemented + "M": ("i", False, 1), + # alternative implementation style - not yet implemented + "parallel_window": ("i", False, 0, {0}), + "Stride": ("ints", True, []), # [H, W] = [Y, X] + "Dilation": ("ints", True, []), # [H, W] = [Y, X] + # FINN DataTypes for inputs, weights, outputs + "inputDataType": ("s", True, ""), + "outputDataType": ("s", True, ""), + "depthwise": ("i", False, 0, {0, 1}), + # FPGA resource type for ConvolutionInputGenerator input buffer + # auto -- let Vivado decide + # block -- use BRAM + # distributed -- use LUTRAM + # ultra -- use URAM + "ram_style": ( + "s", + False, + "auto", + {"auto", "block", "distributed", "ultra"}, + ), + # attribute to save top module name - not user configurable + "gen_top_module": ("s", False, ""), + } + my_attrs.update(super().get_nodeattr_types()) + return my_attrs + + def get_normal_input_shape(self): + ifm_dim_h, ifm_dim_w = self.get_nodeattr("IFMDim") + ifm_ch = self.get_nodeattr("IFMChannels") + ishape = (1, ifm_dim_h, ifm_dim_w, ifm_ch) + return ishape + + def get_folded_input_shape(self): + ifm_dim_h, ifm_dim_w = self.get_nodeattr("IFMDim") + ifm_ch = self.get_nodeattr("IFMChannels") + simd = self.get_nodeattr("SIMD") + assert ifm_ch % simd == 0, "SIMD must divide IFMChannels" + wf = int(ifm_ch / simd) + folded_ishape = (1, ifm_dim_h, ifm_dim_w, wf, simd) + return folded_ishape + + def get_normal_output_shape(self): + k_h, k_w = self.get_nodeattr("ConvKernelDim") + ifm_dim_h, ifm_dim_w = self.get_nodeattr("IFMDim") + ifm_ch = self.get_nodeattr("IFMChannels") + stride_h, stride_w = self.get_nodeattr("Stride") + dilation_h, dilation_w = self.get_nodeattr("Dilation") + pad = 0 + ofm_dim_h = compute_conv_output_dim(ifm_dim_h, k_h, stride_h, pad, dilation_h) + ofm_dim_w = compute_conv_output_dim(ifm_dim_w, k_w, stride_w, pad, dilation_w) + oshape = (1, ofm_dim_h, ofm_dim_w, k_h * k_w * ifm_ch) + return oshape + + def get_folded_output_shape(self): + k_h, k_w = self.get_nodeattr("ConvKernelDim") + ifm_dim_h, ifm_dim_w = self.get_nodeattr("IFMDim") + ifm_ch = self.get_nodeattr("IFMChannels") + stride_h, stride_w = self.get_nodeattr("Stride") + dilation_h, dilation_w = self.get_nodeattr("Dilation") + simd = self.get_nodeattr("SIMD") + pad = 0 + ofm_dim_h = compute_conv_output_dim(ifm_dim_h, k_h, stride_h, pad, dilation_h) + ofm_dim_w = compute_conv_output_dim(ifm_dim_w, k_w, stride_w, pad, dilation_w) + assert ifm_ch % simd == 0, "SIMD must divide IFMChannels" + if self.get_nodeattr("parallel_window"): + wf = int((ifm_ch) // simd) + folded_oshape = (1, ofm_dim_h, ofm_dim_w, wf, k_h * k_w * simd) + else: + wf = int((k_h * k_w * ifm_ch) // simd) + folded_oshape = (1, ofm_dim_h, ofm_dim_w, wf, simd) + return folded_oshape + + def make_shape_compatible_op(self, model): + exp_ishape = self.get_normal_input_shape() + oshape = self.get_normal_output_shape() + ishape = tuple(model.get_tensor_shape(self.onnx_node.input[0])) + assert ishape == exp_ishape, "Unexpect input shape for ConvInpGen." + return super().make_const_shape_op(oshape) + + def infer_node_datatype(self, model): + node = self.onnx_node + # data type stays the same + dtype = model.get_tensor_datatype(node.input[0]) + model.set_tensor_datatype(node.output[0], dtype) + + def verify_node(self): + pass + + def get_input_datatype(self): + """Returns FINN DataType of input.""" + return DataType[self.get_nodeattr("inputDataType")] + + def get_output_datatype(self): + """Returns FINN DataType of output.""" + return DataType[self.get_nodeattr("outputDataType")] + + def get_instream_width(self): + ibits = self.get_input_datatype().bitwidth() + simd = self.get_nodeattr("SIMD") + ifm_ch = self.get_nodeattr("IFMChannels") + assert ifm_ch % simd == 0, "SIMD must divide IFMChannels" + in_width = simd * ibits + return in_width + + def get_outstream_width(self): + if self.get_nodeattr("parallel_window"): + # feed all window pixels in parallel + k_h, k_w = self.get_nodeattr("ConvKernelDim") + return self.get_instream_width() * k_h * k_w + else: + # if parallel variant not in use: same width for output and input stream + return self.get_instream_width() + + def get_number_input_values(self): + folded_ishape = self.get_folded_input_shape() + num_input_elems = np.prod(folded_ishape[:-1]) + return num_input_elems + + def get_number_output_values(self): + folded_oshape = self.get_folded_output_shape() + num_output_elems = np.prod(folded_oshape[:-1]) + return num_output_elems + + def get_1d_conv_attrs_normalized(self): + # normalize FM dimensions so that: + # [H, W] = [Y, X] = [1, D] or [D, 1] are always mapped to [1, D]. + # The dummy ('1') dimension is the Y-dimension. + ifm_ch = self.get_nodeattr("IFMChannels") + k = self.get_nodeattr("ConvKernelDim") + ifm_dim = self.get_nodeattr("IFMDim") + ofm_dim = self.get_nodeattr("OFMDim") + stride = self.get_nodeattr("Stride") + dilation = self.get_nodeattr("Dilation") + + if ifm_dim[1] == 1: + ifm_dim = ifm_dim[::-1] + ofm_dim = ofm_dim[::-1] + k = k[::-1] + stride = stride[::-1] + dilation = dilation[::-1] + + return (ifm_ch, ifm_dim, ofm_dim, k, stride, dilation) + + def get_buffer_depth(self): + ifm_ch = self.get_nodeattr("IFMChannels") + k = self.get_nodeattr("ConvKernelDim") + ifm_dim = self.get_nodeattr("IFMDim") + stride = self.get_nodeattr("Stride") + dilation = self.get_nodeattr("Dilation") + simd = self.get_nodeattr("SIMD") + + k_h, k_w = k + h, w = ifm_dim + stride_h, stride_w = stride + dilation_h, dilation_w = dilation + mmv_in = 1 + mmv_out = 1 + channel_factor = int(ifm_ch / simd) + + impl_style = self.select_impl_style() + if impl_style == "default": + # compute minimal buffer length (assuming it holds 1 complete window) + buffer_min_size = ( + (k_h - 1) * dilation_h * w + (k_w - 1) * dilation_w + 1 + ) * channel_factor + + # add additional buffer space in case of stride > 1 + # this minimizes cycle count as it allows an earlier pre-load of inputs + buffer_depth = ( + buffer_min_size + + max( + 0, + ((stride_w - 1) - (int(mmv_out * k_h * k_w / mmv_in))) + * channel_factor, + ) + + max( + 0, + ((stride_h - 1) * w - (int(mmv_out * k_h * k_w / mmv_in))) + * channel_factor, + ) + ) + else: + buffer_depth = 0 + raise Exception("Requested impl. style not implemented") + return buffer_depth + + def get_exp_cycles(self): + simd = self.get_nodeattr("SIMD") + ifm_ch = self.get_nodeattr("IFMChannels") + k = self.get_nodeattr("ConvKernelDim") + ifm_dim = self.get_nodeattr("IFMDim") + ofm_dim = self.get_nodeattr("OFMDim") + stride = self.get_nodeattr("Stride") + dilation = self.get_nodeattr("Dilation") + depthwise = self.get_nodeattr("depthwise") + ifm_dim_h, ifm_dim_w = ifm_dim + ofm_dim_h, ofm_dim_w = ofm_dim + k_h, k_w = k + stride_h, stride_w = stride + dilation_h, dilation_w = dilation + + channel_factor = int(ifm_ch / simd) + + if ifm_dim_h == 1 or ifm_dim_w == 1: + # 1D case + ( + ifm_ch, + [ifm_dim_h, ifm_dim_w], + [ofm_dim_h, ofm_dim_w], + [k_h, k_w], + [stride_h, stride_w], + [dilation_h, dilation_w], + ) = self.get_1d_conv_attrs_normalized() + + if depthwise: + exp_cycles = ( + +ofm_dim_w * k_w * channel_factor + + channel_factor * (k_w - 1) * (stride_w - 1) + - (k_w - 1) + + 2 + ) + else: + exp_cycles = ofm_dim_w * k_w * channel_factor + 2 + else: + # 2D case + buffer_min_size = ( + (k_h - 1) * dilation_h * ifm_dim_w + (k_w - 1) * dilation_w + 1 + ) * channel_factor + cycles_write_block = ofm_dim_w * k_w * k_h * channel_factor + cycles_read_block = stride_w * ifm_dim_w * channel_factor + max_cycles = max(cycles_write_block, cycles_read_block) + if depthwise: + max_cycles += ofm_dim_w * (stride_w - 1) * (channel_factor - 1) + exp_cycles = buffer_min_size + ofm_dim_h * max_cycles # initial buffering + if depthwise: + exp_cycles += (stride_h - 1) * ifm_dim_w * channel_factor + + return int(exp_cycles) + + def bram_estimation(self): + simd = self.get_nodeattr("SIMD") + ram_style = self.get_nodeattr("ram_style") + + # NOTE: Actual BRAM usage might be lower in some cases. + # This does not account for the exact Vivado behavior yet. + buffer_width = simd * self.get_input_datatype().bitwidth() + buffer_depth = self.get_buffer_depth() + if ram_style == "block" or ram_style == "auto": + if buffer_depth <= 512: + ram_width = 36 + elif buffer_depth <= 1024: + ram_width = 18 + elif buffer_depth <= 2048: + ram_width = 9 + elif buffer_depth <= 4096: + ram_width = 4 + elif buffer_depth <= 8192: + ram_width = 2 + else: + ram_width = 1 + + ram_cascade_depth = math.ceil(buffer_depth / 16384) + ram_cascade_width = math.ceil(buffer_width / ram_width) + cascade_savings = 0 + if buffer_depth > 16384: + remainder_depth = buffer_depth % 16384 + if remainder_depth <= 512: + remainder_width = 36 + elif remainder_depth <= 1024: + remainder_width = 18 + elif remainder_depth <= 2048: + remainder_width = 9 + elif remainder_depth <= 4096: + remainder_width = 4 + elif remainder_depth <= 8192: + remainder_width = 2 + else: + remainder_width = 1 + + remainder_cascade_width = math.ceil(buffer_width / remainder_width) + cascade_savings = ram_cascade_width - remainder_cascade_width + + return int(ram_cascade_depth * ram_cascade_width - cascade_savings) + else: + return 0 + + def lut_estimation(self): + simd = self.get_nodeattr("SIMD") + ram_style = self.get_nodeattr("ram_style") + buffer_width = simd * self.get_input_datatype().bitwidth() + buffer_depth = self.get_buffer_depth() + if ram_style == "distributed": + ram_luts = int(buffer_width * math.ceil(buffer_depth / 38)) + else: + ram_luts = 0 + return 300 + ram_luts + + def uram_estimation(self): + simd = self.get_nodeattr("SIMD") + ram_style = self.get_nodeattr("ram_style") + buffer_width = simd * self.get_input_datatype().bitwidth() + buffer_depth = self.get_buffer_depth() + + if ram_style == "ultra": + ram_depth = 4096 + ram_width = 72 + ram_cascade_depth = math.ceil(buffer_depth / ram_depth) + ram_cascade_width = math.ceil(buffer_width / ram_width) + return int(ram_cascade_depth * ram_cascade_width) + else: + return 0 + + def execute_node(self, context, graph): + mode = self.get_nodeattr("exec_mode") + node = self.onnx_node + exp_ishape = self.get_normal_input_shape() + exp_oshape = self.get_normal_output_shape() + folded_ishape = self.get_folded_input_shape() + + if mode == "cppsim": + raise Exception( + "cppsim not possible for RTL SWG, please set exec_mode to rtlsim" + ) + elif mode == "rtlsim": + code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen") + else: + raise Exception( + """Invalid value for attribute exec_mode! Is currently set to: {} + has to be set to one of the following value ("cppsim", "rtlsim")""".format( + mode + ) + ) + + inp = context[node.input[0]] + assert str(inp.dtype) == "float32", "Input datatype is not float32" + assert ( + inp.shape == exp_ishape + ), """Input shape doesn't match expected shape (1, ifm_dim, ifm_dim, ifm_ch).""" + if self.get_input_datatype() == DataType["BIPOLAR"]: + # store bipolar activations as binary + inp = (inp + 1) / 2 + export_idt = DataType["BINARY"] + else: + export_idt = self.get_input_datatype() + + # reshape input into folded form + inp = inp.reshape(folded_ishape) + # make copy before saving array + reshaped_input = inp.copy() + np.save(os.path.join(code_gen_dir, "input_0.npy"), reshaped_input) + + sim = self.get_rtlsim() + nbits = self.get_instream_width() + rtlsim_inp = npy_to_rtlsim_input( + "{}/input_0.npy".format(code_gen_dir), export_idt, nbits + ) + super().reset_rtlsim(sim) + super().toggle_clk(sim) + rtlsim_output = self.rtlsim(sim, rtlsim_inp) + odt = export_idt + target_bits = odt.bitwidth() + packed_bits = self.get_outstream_width() + out_npy_path = "{}/output.npy".format(code_gen_dir) + out_shape = self.get_folded_output_shape() + rtlsim_output_to_npy( + rtlsim_output, out_npy_path, odt, out_shape, packed_bits, target_bits + ) + # load and reshape output + output = np.load(out_npy_path) + output = np.asarray([output], dtype=np.float32).reshape(*exp_oshape) + context[node.output[0]] = output + + # binary -> bipolar if needed + if self.get_output_datatype() == DataType["BIPOLAR"]: + out = context[node.output[0]] + out = 2 * out - 1 + context[node.output[0]] = out + assert ( + context[node.output[0]].shape == exp_oshape + ), """Output + shape doesn't match expected shape (1, ofm_dim_h, ofm_dim_w, k_h*k_w*ifm_ch).""" + + def prepare_codegen_default(self): + # Default implementation style for MMV_out = 1: addressable cyclic buffer + # Computing incremental addressing scheme directly.. + template_path = ( + os.environ["FINN_ROOT"] + "/finn-rtllib/swg/swg_template_default.sv" + ) + code_gen_dict = {} + + ifm_ch = self.get_nodeattr("IFMChannels") + k = self.get_nodeattr("ConvKernelDim") + ifm_dim = self.get_nodeattr("IFMDim") + stride = self.get_nodeattr("Stride") + dilation = self.get_nodeattr("Dilation") + depthwise = self.get_nodeattr("depthwise") + simd = self.get_nodeattr("SIMD") + + k_h, k_w = k + h, w = ifm_dim + pad = [0, 0, 0, 0] # padding happens in separate padding node for now + stride_h, stride_w = stride + dilation_h, dilation_w = dilation + pad_h = pad[0] + pad[2] + pad_w = pad[1] + pad[3] + out_dim_h = im2col.compute_conv_output_dim(h, k_h, stride_h, pad_h, dilation_h) + out_dim_w = im2col.compute_conv_output_dim(w, k_w, stride_w, pad_w, dilation_w) + mmv_in = 1 + mmv_out = 1 + channel_factor = int(ifm_ch / simd) + + # compute minimal buffer length (assuming it holds 1 complete window) + buffer_min_size = ( + (k_h - 1) * dilation_h * w + (k_w - 1) * dilation_w + 1 + ) * channel_factor + + buffer_actual_size = self.get_buffer_depth() + code_gen_dict["$BUF_ELEM_TOTAL$"] = [str(buffer_actual_size)] + + # compute some intermediate values, e.g., kernel "width" = k_w incl. dilation + # or cols/rows that are skipped due to imperfect stride<->dim combination + kernel_width = (k_w - 1) * dilation_w + 1 + kernel_height = (k_h - 1) * dilation_h + 1 + skip_columns = w % (kernel_width + (out_dim_w - 1) * stride_w) + skip_rows = h % (kernel_height + (out_dim_h - 1) * stride_h) + + # compute address increment values for 5-loop nest + addr_incr_end_simd = 1 + addr_incr_end_window_elem = (dilation_w - 1) * channel_factor + 1 + addr_incr_end_window_row = ( + ((w - kernel_width) * channel_factor) # remaining line + + ((dilation_h - 1) * w * channel_factor) # skip lines + + 1 # wrap-around of minimally sized buffer + ) + addr_incr_end_window = -buffer_min_size + stride_w * channel_factor + 1 + addr_incr_end_row = ( + -buffer_min_size + + ((skip_columns + kernel_width) * channel_factor) # remaining line + + ((stride_h - 1) * w * channel_factor) # skip lines + + 1 + ) + + # re-use same controller structure -> re-assign address increments + if depthwise: + addr_incr_end_window_elem = dilation_w * channel_factor + addr_incr_end_window_row = ( + channel_factor + + (w - kernel_width) * channel_factor + + (dilation_h - 1) * w * channel_factor + ) + addr_incr_end_simd = -buffer_min_size + (channel_factor + 1) + + # sanity check + assert not ( + abs(addr_incr_end_window) > buffer_actual_size + ), "ERROR: W increment > buffer size, wrap logic doesn't account for this" + assert not ( + abs(addr_incr_end_row) > buffer_actual_size + ), "ERROR: H increment > buffer size, wrap logic doesn't account for this" + + # set certain threshold indices to detect when reading/writing finishes + code_gen_dict["$LAST_READ_ELEM$"] = [str(h * w * channel_factor - 1)] + code_gen_dict["$LAST_WRITE_ELEM$"] = [ + str(((h - skip_rows - 1) * w + (w - skip_columns)) * channel_factor - 1) + ] + + # default controller loop structure: # iterations (counters) map directly + loop_h_iterations = out_dim_h + loop_w_iterations = out_dim_w + loop_kh_iterations = k_h + loop_kw_iterations = k_w + loop_simd_iterations = channel_factor + + if depthwise and channel_factor > 1: + # re-arrange existing controller loop structure for depthwise convolutions + loop_kh_iterations = channel_factor + loop_kw_iterations = k_h + loop_simd_iterations = k_w + addr_incr_end_simd_ = addr_incr_end_simd + addr_incr_end_simd = addr_incr_end_window_elem + addr_incr_end_window_elem = addr_incr_end_window_row + addr_incr_end_window_row = addr_incr_end_simd_ + elem_per_window = k_h * k_w + + tail_incr_w = addr_incr_end_window + buffer_min_size - channel_factor + tail_incr_h = addr_incr_end_row + buffer_min_size - channel_factor + tail_incr_last_window = buffer_min_size - 1 + code_gen_dict["$IS_DEPTHWISE$"] = ["1"] + else: + # depthwise output format is equivalent to non-depthwise if SIMD=C + elem_per_window = k_h * k_w * channel_factor + + tail_incr_w = addr_incr_end_window + buffer_min_size - 1 + tail_incr_h = addr_incr_end_row + buffer_min_size - 1 + tail_incr_last_window = buffer_min_size - 1 + code_gen_dict["$IS_DEPTHWISE$"] = ["0"] + + code_gen_dict["$TAIL_INCR_W$"] = [str(tail_incr_w)] + code_gen_dict["$TAIL_INCR_H$"] = [str(tail_incr_h)] + code_gen_dict["$TAIL_INCR_LAST$"] = [str(tail_incr_last_window)] + + # support SIMD = IFMChannels and k_w = 1 cases + # for k = [k_h, k_w] = [1, k_w], no adjustment is needed + # for k = [k_h, k_w] = [1, 1], do not use this impl. style (mmv_out=K=1) + # innermost loop is executed at least once -> adjust if needed + if loop_simd_iterations == 1: + # skip innermost SIMD loop completely + if loop_kw_iterations == 1: + # skip innermost KW loop completely + code_gen_dict["$INNERMOST_STATE$"] = ["STATE_LOOP_KH"] + loop_kh_iterations -= 1 # -1 because state is initial state + else: + code_gen_dict["$INNERMOST_STATE$"] = ["STATE_LOOP_KW"] + loop_kw_iterations -= 1 # -1 because state is initial state + else: + code_gen_dict["$INNERMOST_STATE$"] = ["STATE_LOOP_SIMD"] + loop_simd_iterations -= 1 # -1 because state is initial state + + code_gen_dict["$LOOP_H_ITERATIONS$"] = [str(loop_h_iterations - 1)] + code_gen_dict["$LOOP_W_ITERATIONS$"] = [str(loop_w_iterations - 1)] + code_gen_dict["$LOOP_KH_ITERATIONS$"] = [str(loop_kh_iterations - 1)] + code_gen_dict["$LOOP_KW_ITERATIONS$"] = [str(loop_kw_iterations - 1)] + code_gen_dict["$LOOP_SIMD_ITERATIONS$"] = [str(loop_simd_iterations - 1)] + + incr_bitwidth = 1 + math.ceil( + math.log2( + max( + abs(addr_incr_end_simd) + 1, + abs(addr_incr_end_window_elem) + 1, + abs(addr_incr_end_window_row) + 1, + abs(addr_incr_end_window) + 1, + abs(addr_incr_end_row) + 1, + abs(tail_incr_w) + 1, + abs(tail_incr_h) + 1, + abs(tail_incr_last_window) + 1, + ) + ) + ) + code_gen_dict["$INCR_BITWIDTH$"] = [str(incr_bitwidth)] + code_gen_dict["$ADDR_INCREMENT_MAP$"] = [ + "'{{ {}'d0, {}'d{}, {}'d{}, {}'d{}, {}'d{}, {}'d{}}}".format( + incr_bitwidth, + int(copysign(incr_bitwidth, addr_incr_end_simd)), + abs(addr_incr_end_simd), + int(copysign(incr_bitwidth, addr_incr_end_window_elem)), + abs(addr_incr_end_window_elem), + int(copysign(incr_bitwidth, addr_incr_end_window_row)), + abs(addr_incr_end_window_row), + int(copysign(incr_bitwidth, addr_incr_end_window)), + abs(addr_incr_end_window), + int(copysign(incr_bitwidth, addr_incr_end_row)), + abs(addr_incr_end_row), + ) + ] + + code_gen_dict["$ELEM_PER_WINDOW$"] = [str(elem_per_window)] + code_gen_dict["$SIMD$"] = [str(simd)] + code_gen_dict["$MMV_IN$"] = [str(mmv_in)] + code_gen_dict["$MMV_OUT$"] = [str(mmv_out)] + + return template_path, code_gen_dict + + def select_impl_style(self): + simd = self.get_nodeattr("SIMD") + M = self.get_nodeattr("M") + ifm_ch = self.get_nodeattr("IFMChannels") + ifm_dim = self.get_nodeattr("IFMDim") + stride = self.get_nodeattr("Stride") + dilation = self.get_nodeattr("Dilation") + k = self.get_nodeattr("ConvKernelDim") + ifm_dim_h, ifm_dim_w = ifm_dim + stride_h, stride_w = stride + dilation_h, dilation_w = dilation + k_h, k_w = k + kernel_width = (k_w - 1) * dilation_w + 1 # incl. dilation + kernel_height = (k_h - 1) * dilation_h + 1 # incl. dilation + + # check for valid configuration + assert ( + kernel_height <= ifm_dim_h + and kernel_width <= ifm_dim_w + and stride_h <= ifm_dim_h + and stride_w <= ifm_dim_w + ), "Illegal conv configuration: kernel or stride > FM dimension" + + # init folding config + if self.get_nodeattr("parallel_window"): + # mmv_in = M * 1 + mmv_out = M * k_h * k_w + assert ( + ifm_ch == simd + ), "Constraint violated: SIMD must be equal to IFMChannels" + else: + # mmv_in = 1 + mmv_out = 1 + assert ( + ifm_ch % simd == 0 + ), "Constraint violated: SIMD must divide IFMChannels" + + # choose implementation style + if mmv_out > 1 or (k_h == 1 and k_w == 1): + impl_style = "parallel" + assert ( + ifm_ch == simd + ), "Constraint violated: SIMD must be equal to IFMChannels" + else: + impl_style = "default" + + assert ( + impl_style == "default" + ), "ERROR: Parallel window mode not yet implemented" + return impl_style + + def generate_hdl(self): + impl_style = self.select_impl_style() + + # prepare code generation by filling out dictionaries + if impl_style == "default": + template_path, code_gen_dict = self.prepare_codegen_default() + else: + raise Exception("Requested impl. style not implemented") + + # add general parameters to dictionary + code_gen_dict["$TOP_MODULE_NAME$"] = [self.get_verilog_top_module_name()] + # save top module name so we can refer to it after this node has been renamed + # (e.g. by GiveUniqueNodeNames(prefix) during MakeZynqProject) + self.set_nodeattr("gen_top_module", self.get_verilog_top_module_name()) + code_gen_dict["$BIT_WIDTH$"] = [str(self.get_input_datatype().bitwidth())] + ram_style = self.get_nodeattr("ram_style") + if ram_style == "auto": + code_gen_dict["$RAM_STYLE$"] = [""] + else: + code_gen_dict["$RAM_STYLE$"] = ['(* ram_style = "{}" *)'.format(ram_style)] + + # apply code generation to templates + code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen") + with open(template_path, "r") as f: + template = f.read() + with open( + os.environ["FINN_ROOT"] + "/finn-rtllib/swg/swg_template_wrapper.v", "r" + ) as f: + template_wrapper = f.read() + for key in code_gen_dict: + # transform list into long string separated by '\n' + code_gen_line = "\n".join(code_gen_dict[key]) + template = template.replace(key, code_gen_line) + template_wrapper = template_wrapper.replace(key, code_gen_line) + with open( + os.path.join( + code_gen_dir, self.get_nodeattr("gen_top_module") + "_impl.sv" + ), + "w", + ) as f: + f.write(template) + with open( + os.path.join( + code_gen_dir, self.get_nodeattr("gen_top_module") + "_wrapper.v" + ), + "w", + ) as f: + f.write(template_wrapper) + + # set ipgen_path and ip_path so that HLS-Synth transformation + # and stich_ip transformation do not complain + self.set_nodeattr("ipgen_path", code_gen_dir) + self.set_nodeattr("ip_path", code_gen_dir) + + def prepare_rtlsim(self): + """Creates a Verilator emulation library for the RTL code generated + for this node, sets the rtlsim_so attribute to its path and returns + a PyVerilator wrapper around it.""" + # Modified to use generated (System-)Verilog instead of HLS output products + + if PyVerilator is None: + raise ImportError("Installation of PyVerilator is required.") + + code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen") + verilog_paths = [code_gen_dir] + verilog_files = [ + self.get_nodeattr("gen_top_module") + "_wrapper.v", + self.get_nodeattr("gen_top_module") + "_impl.sv", + ] + + # build the Verilator emu library + sim = PyVerilator.build( + verilog_files, + build_dir=make_build_dir("pyverilator_" + self.onnx_node.name + "_"), + verilog_path=verilog_paths, + trace_depth=get_rtlsim_trace_depth(), + top_module_name=self.get_verilog_top_module_name(), + ) + # save generated lib filename in attribute + self.set_nodeattr("rtlsim_so", sim.lib._name) + return sim + + def code_generation_ipi(self): + """Constructs and returns the TCL for node instantiation in Vivado IPI.""" + code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen") + + cmd = [ + "add_files -norecurse %s" + % ( + os.path.join( + code_gen_dir, self.get_nodeattr("gen_top_module") + "_wrapper.v" + ) + ), + "add_files -norecurse %s" + % ( + os.path.join( + code_gen_dir, self.get_nodeattr("gen_top_module") + "_impl.sv" + ) + ), + "create_bd_cell -type module -reference %s %s" + % (self.get_nodeattr("gen_top_module"), self.onnx_node.name), + ] + + return cmd + + def code_generation_ipgen(self, model, fpgapart, clk): + """Normally: Generates C++ code and tcl script for IP generation. + Here: Generates (System-)Verilog code for IP generation.""" + self.generate_hdl() + + def ipgen_singlenode_code(self): + """Normally: Builds the bash script for IP generation.""" + pass + + def code_generation_cppsim(self, model): + """Normally: Generates C++ code for simulation (cppsim).""" + pass + + def compile_singlenode_code(self): + pass + + def global_includes(self): + pass + + def defines(self, var): + pass + + def read_npy_data(self): + pass + + def strm_decl(self): + pass + + def docompute(self): + pass + + def dataoutstrm(self): + pass + + def save_as_npy(self): + pass + + def blackboxfunction(self): + pass + + def pragmas(self): + pass diff --git a/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py b/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py index 429bc34ffc59b5d98bb559f36ac557de4dbba92f..540c217cbca8c47243a080ac493f19bd1c72abc8 100644 --- a/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py +++ b/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py @@ -48,6 +48,10 @@ from finn.transformation.fpgadataflow.minimize_accumulator_width import ( class InferConvInpGen(Transformation): """Convert Im2Col layers to ConvolutionInputGenerator layers.""" + def __init__(self, use_rtl_variant=False): + super().__init__() + self.use_rtl_variant = use_rtl_variant + def apply(self, model): graph = model.graph node_ind = 0 @@ -128,105 +132,141 @@ class InferConvInpGen(Transformation): ) graph.node.insert(node_ind, padding_node) - # Ensure that only supported HLS nodes are inserted + is_kernel_pointwise = k_h == 1 and k_w == 1 is_square_image = ConvInpGen_idim_h == ConvInpGen_idim_w is_square_kernel = k_h == k_w - is_kernel_pointwise = k_h == 1 and k_w == 1 is_equal_stride = stride_h == stride_w is_1d_convolution = (k_h == 1 and k_w > 1 and ifm_dim_h == 1) or ( k_h > 1 and k_w == 1 and ifm_dim_w == 1 ) - if (stride_h > 1 or stride_w > 1) and is_kernel_pointwise: - assert is_square_image, ( - "%s : DownSampler currently only supports square input images." - % n.name - ) - assert is_equal_stride, ( - """%s : DownSampler currently only supports equal stride value - along different axes.""" - % n.name - ) - ConvInpGen_idim = ConvInpGen_idim_h - stride = stride_h - # create DownSampler node + # Ensure that RTL variant is not inserted for unsupported configuration + is_rtl_variant_compatible = True + if is_kernel_pointwise: + is_rtl_variant_compatible = False + if self.use_rtl_variant: + warnings.warn( + """%s : RTL ConvInpGen requested for unsupported + configuration. Falling back to HLS implementation.""" + % n.name + ) + + if self.use_rtl_variant and is_rtl_variant_compatible: ConvInpGen_node = helper.make_node( - "DownSampler", + "ConvolutionInputGenerator_rtl", [ConvInpGen_input], [i2c_output], domain="finn.custom_op.fpgadataflow", backend="fpgadataflow", - ImgDim=ConvInpGen_idim, - NumChannels=ifm_ch, + ConvKernelDim=[k_h, k_w], + IFMChannels=ifm_ch, + IFMDim=[ConvInpGen_idim_h, ConvInpGen_idim_w], + OFMDim=[ofm_dim_h, ofm_dim_w], SIMD=ifm_ch, - Stride=stride, + M=1, + parallel_window=0, + Stride=[stride_h, stride_w], + Dilation=[dilation_h, dilation_w], inputDataType=dt.name, - name="DownSampler_" + n.name, + outputDataType=dt.name, + depthwise=depthwise, + name="ConvolutionInputGenerator_rtl_" + n.name, ) graph.node.insert(ConvInpGen_node_idx, ConvInpGen_node) else: - # create equivalent ConvolutionInputGenerator node - if ( - is_square_image and is_square_kernel - ): # square images and square kernels - assert is_equal_stride, ( - """%s: Non-equal strides along different axes is not supported - for (non-)square convolutions""" + # Ensure that only supported HLS nodes are inserted + if (stride_h > 1 or stride_w > 1) and is_kernel_pointwise: + assert is_square_image, ( + """%s : DownSampler currently only supports square + input images.""" % n.name ) - assert dilation_h == 1 and dilation_w == 1, ( - """%s: Dilation value != 1 is not supported - for square convolutions""" + assert is_equal_stride, ( + """%s : DownSampler currently only supports equal stride + value along different axes.""" % n.name ) + ConvInpGen_idim = ConvInpGen_idim_h + stride = stride_h + # create DownSampler node ConvInpGen_node = helper.make_node( - "ConvolutionInputGenerator", + "DownSampler", [ConvInpGen_input], [i2c_output], domain="finn.custom_op.fpgadataflow", backend="fpgadataflow", - ConvKernelDim=[k_h, k_w], - IFMChannels=ifm_ch, - IFMDim=[ConvInpGen_idim_h, ConvInpGen_idim_w], - OFMDim=[ofm_dim_h, ofm_dim_w], + ImgDim=ConvInpGen_idim, + NumChannels=ifm_ch, SIMD=ifm_ch, - Stride=[stride_h, stride_w], - Dilation=[dilation_h, dilation_w], + Stride=stride, inputDataType=dt.name, - outputDataType=dt.name, - depthwise=depthwise, - name="ConvolutionInputGenerator_" + n.name, - ) - else: # 1D images and/or kernels - assert is_1d_convolution, ( - "%s: ConvolutionInputGenerator1D works only for 1D convs" - % n.name + name="DownSampler_" + n.name, ) - if dilation_h > 1 or dilation_w > 1: - assert depthwise == 1, ( - """%s: Dilation value > 1 is only supported for - 1D depthwise separable convolutions""" + graph.node.insert(ConvInpGen_node_idx, ConvInpGen_node) + else: + # create equivalent ConvolutionInputGenerator node + if ( + is_square_image and is_square_kernel + ): # square images and square kernels + assert is_equal_stride, ( + """%s: Non-equal strides along different axes is not supported + for (non-)square convolutions""" % n.name ) - ConvInpGen_node = helper.make_node( - "ConvolutionInputGenerator1D", - [ConvInpGen_input], - [i2c_output], - domain="finn.custom_op.fpgadataflow", - backend="fpgadataflow", - ConvKernelDim=[k_h, k_w], - IFMChannels=ifm_ch, - IFMDim=[ConvInpGen_idim_h, ConvInpGen_idim_w], - OFMDim=[ofm_dim_h, ofm_dim_w], - SIMD=ifm_ch, - Stride=[stride_h, stride_w], - Dilation=[dilation_h, dilation_w], - inputDataType=dt.name, - outputDataType=dt.name, - depthwise=depthwise, - name="ConvolutionInputGenerator1D_" + n.name, - ) - graph.node.insert(ConvInpGen_node_idx, ConvInpGen_node) + assert dilation_h == 1 and dilation_w == 1, ( + """%s: Dilation value != 1 is not supported + for square convolutions""" + % n.name + ) + ConvInpGen_node = helper.make_node( + "ConvolutionInputGenerator", + [ConvInpGen_input], + [i2c_output], + domain="finn.custom_op.fpgadataflow", + backend="fpgadataflow", + ConvKernelDim=[k_h, k_w], + IFMChannels=ifm_ch, + IFMDim=[ConvInpGen_idim_h, ConvInpGen_idim_w], + OFMDim=[ofm_dim_h, ofm_dim_w], + SIMD=ifm_ch, + Stride=[stride_h, stride_w], + Dilation=[dilation_h, dilation_w], + inputDataType=dt.name, + outputDataType=dt.name, + depthwise=depthwise, + name="ConvolutionInputGenerator_" + n.name, + ) + else: # 1D images and/or kernels + assert is_1d_convolution, ( + """%s: ConvolutionInputGenerator1D works only + for 1D convs""" + % n.name + ) + if dilation_h > 1 or dilation_w > 1: + assert depthwise == 1, ( + """%s: Dilation value > 1 is only supported for + 1D depthwise separable convolutions""" + % n.name + ) + ConvInpGen_node = helper.make_node( + "ConvolutionInputGenerator1D", + [ConvInpGen_input], + [i2c_output], + domain="finn.custom_op.fpgadataflow", + backend="fpgadataflow", + ConvKernelDim=[k_h, k_w], + IFMChannels=ifm_ch, + IFMDim=[ConvInpGen_idim_h, ConvInpGen_idim_w], + OFMDim=[ofm_dim_h, ofm_dim_w], + SIMD=ifm_ch, + Stride=[stride_h, stride_w], + Dilation=[dilation_h, dilation_w], + inputDataType=dt.name, + outputDataType=dt.name, + depthwise=depthwise, + name="ConvolutionInputGenerator1D_" + n.name, + ) + graph.node.insert(ConvInpGen_node_idx, ConvInpGen_node) # remove old nodes graph.node.remove(n) graph_modified = True diff --git a/src/finn/transformation/fpgadataflow/set_folding.py b/src/finn/transformation/fpgadataflow/set_folding.py index 23943084ab99d6ab880a69975e0b4a49756905a7..e24e24f1f8ebb2873c81617884cd333311d8aea9 100644 --- a/src/finn/transformation/fpgadataflow/set_folding.py +++ b/src/finn/transformation/fpgadataflow/set_folding.py @@ -109,6 +109,7 @@ class SetFolding(Transformation): "FMPadding_Batch", "ConvolutionInputGenerator", "ConvolutionInputGenerator1D", + "ConvolutionInputGenerator_rtl", ] # these ops are preceded by depthwise SWG and have special behavior, # as explained in the SetFolding docstring @@ -171,10 +172,7 @@ class SetFolding(Transformation): "Expected SWU on DW op input, found " + swu_node.op_type ) elif op_type in simd_ops: - if op_type in [ - "ConvolutionInputGenerator", - "ConvolutionInputGenerator1D", - ]: + if op_type.startswith("ConvolutionInputGenerator"): depthwise = node_inst.get_nodeattr("depthwise") if depthwise == 0: max_simd = node_inst.get_nodeattr("IFMChannels") diff --git a/tests/fpgadataflow/test_convert_to_hls_1d_conv_layer.py b/tests/fpgadataflow/test_convert_to_hls_1d_conv_layer.py index 5bbaefac2d3e5f800fbb9471df6469235271c2f3..7b3e20616410f54e4718290baec9a510a0d49c0d 100644 --- a/tests/fpgadataflow/test_convert_to_hls_1d_conv_layer.py +++ b/tests/fpgadataflow/test_convert_to_hls_1d_conv_layer.py @@ -66,11 +66,12 @@ from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode ], ) @pytest.mark.parametrize("depthwise", [False, True]) +@pytest.mark.parametrize("use_rtl_swg", [False, True]) @pytest.mark.parametrize("exec_mode", ["cppsim", "rtlsim"]) @pytest.mark.fpgadataflow @pytest.mark.slow @pytest.mark.vivado -def test_convert_to_hls_1d_conv_layer(conv_config, depthwise, exec_mode): +def test_convert_to_hls_1d_conv_layer(conv_config, depthwise, use_rtl_swg, exec_mode): pad, kernel_size, stride, dilation = conv_config np.random.seed(0) idt = DataType["UINT4"] @@ -84,6 +85,9 @@ def test_convert_to_hls_1d_conv_layer(conv_config, depthwise, exec_mode): pad_h = pad[0] + pad[2] pad_w = pad[1] + pad[3] + if use_rtl_swg and exec_mode == "cppsim": + pytest.skip("cppsim not supported for RTL SWG") + if depthwise is True: group = out_chn = in_chn conv_param_shape = [out_chn, 1, k_h, k_w] @@ -139,7 +143,7 @@ def test_convert_to_hls_1d_conv_layer(conv_config, depthwise, exec_mode): model = model.transform(InferDataTypes()) new_model = model.transform(LowerConvsToMatMul()) - new_model = new_model.transform(to_hls.InferConvInpGen()) + new_model = new_model.transform(to_hls.InferConvInpGen(use_rtl_variant=use_rtl_swg)) if depthwise is True: new_model = new_model.transform(to_hls.InferVectorVectorActivation()) else: diff --git a/tests/fpgadataflow/test_convert_to_hls_conv_layer.py b/tests/fpgadataflow/test_convert_to_hls_conv_layer.py index 55dc77cafb898ead28a7cbb9641e0b40db276919..8c9f110c315089ec03354863bf2213963197217a 100644 --- a/tests/fpgadataflow/test_convert_to_hls_conv_layer.py +++ b/tests/fpgadataflow/test_convert_to_hls_conv_layer.py @@ -57,11 +57,12 @@ from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode "conv_config", [(1, 2, 0), (1, 3, 0), (3, 2, 1), (3, 1, 0), (3, 1, 1), (5, 2, 1)] ) @pytest.mark.parametrize("depthwise", [False, True]) +@pytest.mark.parametrize("use_rtl_swg", [False, True]) @pytest.mark.parametrize("exec_mode", ["cppsim", "rtlsim"]) @pytest.mark.fpgadataflow @pytest.mark.slow @pytest.mark.vivado -def test_convert_to_hls_conv_layer(conv_config, depthwise, exec_mode): +def test_convert_to_hls_conv_layer(conv_config, depthwise, use_rtl_swg, exec_mode): kernel_size, stride, pad = conv_config np.random.seed(0) idt = DataType["UINT4"] @@ -69,6 +70,12 @@ def test_convert_to_hls_conv_layer(conv_config, depthwise, exec_mode): in_feature_dim = 7 in_chn = 16 + if use_rtl_swg and exec_mode == "cppsim": + pytest.skip("cppsim not supported for RTL SWG") + + if use_rtl_swg and kernel_size == 1: + pytest.skip("1x1 kernel not supported by current RTL SWG") + if depthwise is True: group = out_chn = in_chn conv_param_shape = [out_chn, 1, kernel_size, kernel_size] @@ -122,7 +129,7 @@ def test_convert_to_hls_conv_layer(conv_config, depthwise, exec_mode): model = model.transform(InferDataTypes()) new_model = model.transform(LowerConvsToMatMul()) - new_model = new_model.transform(to_hls.InferConvInpGen()) + new_model = new_model.transform(to_hls.InferConvInpGen(use_rtl_variant=use_rtl_swg)) if depthwise is True: new_model = new_model.transform(to_hls.InferVectorVectorActivation()) else: @@ -156,6 +163,7 @@ def test_convert_to_hls_conv_layer(conv_config, depthwise, exec_mode): x = gen_finn_dt_tensor(idt, input_shape) inp_dict = {model.graph.input[0].name: x} assert oxe.compare_execution(model, new_model, inp_dict) + if kernel_size == 1 and stride > 1 and pad == 0: assert new_model.graph.node[1].op_type == "DownSampler" if exec_mode == "rtlsim": diff --git a/tests/fpgadataflow/test_fpgadataflow_convinputgenerator_rtl.py b/tests/fpgadataflow/test_fpgadataflow_convinputgenerator_rtl.py new file mode 100755 index 0000000000000000000000000000000000000000..007360a5fd0b74ee49d54c84f332061dd5f3a114 --- /dev/null +++ b/tests/fpgadataflow/test_fpgadataflow_convinputgenerator_rtl.py @@ -0,0 +1,260 @@ +# Copyright (C) 2022, Advanced Micro Devices, Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of FINN nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import pytest + +from onnx import TensorProto, helper +from qonnx.core.datatype import DataType +from qonnx.core.modelwrapper import ModelWrapper +from qonnx.custom_op.general.im2col import compute_conv_output_dim +from qonnx.transformation.general import GiveUniqueNodeNames +from qonnx.util.basic import gen_finn_dt_tensor + +import finn.core.onnx_exec as oxe +from finn.transformation.fpgadataflow.prepare_ip import PrepareIP +from finn.transformation.fpgadataflow.prepare_rtlsim import PrepareRTLSim +from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode + + +def make_single_im2col_modelwrapper(k, ifm_ch, ifm_dim, ofm_dim, stride, dilation, idt): + k_h, k_w = k + ifm_dim_h, ifm_dim_w = ifm_dim + stride_h, stride_w = stride + dilation_h, dilation_w = dilation + ofm_dim_h, ofm_dim_w = ofm_dim + + odt = idt + inp = helper.make_tensor_value_info( + "inp", TensorProto.FLOAT, [1, ifm_dim_h, ifm_dim_w, ifm_ch] + ) + outp = helper.make_tensor_value_info( + "outp", TensorProto.FLOAT, [1, ofm_dim_h, ofm_dim_w, k_h * k_w * ifm_ch] + ) + + im2col_node = helper.make_node( + "Im2Col", + ["inp"], + ["outp"], + domain="finn.custom_op.general", + stride=[stride_h, stride_w], + kernel_size=[k_h, k_w], + input_shape=str((1, ifm_dim_h, ifm_dim_w, ifm_ch)), + dilations=[dilation_h, dilation_w], + pad_amount=[0, 0, 0, 0], + pad_value=0, + ) + graph = helper.make_graph( + nodes=[im2col_node], name="im2col_graph", inputs=[inp], outputs=[outp] + ) + + model = helper.make_model(graph, producer_name="im2col-model") + model = ModelWrapper(model) + + model.set_tensor_datatype("inp", idt) + model.set_tensor_datatype("outp", odt) + + return model + + +def make_single_slidingwindow_modelwrapper( + k, ifm_ch, ifm_dim, ofm_dim, simd, m, parallel_window, stride, dilation, idt, dw=0 +): + k_h, k_w = k + ifm_dim_h, ifm_dim_w = ifm_dim + stride_h, stride_w = stride + dilation_h, dilation_w = dilation + ofm_dim_h, ofm_dim_w = ofm_dim + + odt = idt + inp = helper.make_tensor_value_info( + "inp", TensorProto.FLOAT, [1, ifm_dim_h, ifm_dim_w, ifm_ch] + ) + outp = helper.make_tensor_value_info( + "outp", TensorProto.FLOAT, [1, ofm_dim_h, ofm_dim_w, k_h * k_w * ifm_ch] + ) + + SlidingWindow_node = helper.make_node( + "ConvolutionInputGenerator_rtl", + ["inp"], + ["outp"], + domain="finn.custom_op.fpgadataflow", + backend="fpgadataflow", + ConvKernelDim=[k_h, k_w], + IFMChannels=ifm_ch, + IFMDim=[ifm_dim_h, ifm_dim_w], + OFMDim=[ofm_dim_h, ofm_dim_w], + SIMD=simd, + M=m, + parallel_window=parallel_window, + Stride=[stride_h, stride_w], + Dilation=[dilation_h, dilation_w], + inputDataType=idt.name, + outputDataType=odt.name, + depthwise=dw, + ) + graph = helper.make_graph( + nodes=[SlidingWindow_node], + name="slidingwindow_graph", + inputs=[inp], + outputs=[outp], + ) + + model = helper.make_model(graph, producer_name="slidingwindow-model") + model = ModelWrapper(model) + + model.set_tensor_datatype("inp", idt) + model.set_tensor_datatype("outp", odt) + + return model + + +def prepare_inputs(input_tensor): + return {"inp": input_tensor} + + +# input datatype +@pytest.mark.parametrize("idt", [DataType["UINT4"]]) +# kernel size +@pytest.mark.parametrize("k", [[2, 2], [3, 3], [1, 3]]) +# input dimension +@pytest.mark.parametrize("ifm_dim", [[24, 24], [15, 6], [13, 13], [1, 14]]) +# input channels +@pytest.mark.parametrize("ifm_ch", [6]) +# Stride +@pytest.mark.parametrize("stride", [[1, 1], [2, 2]]) +# Dilation +@pytest.mark.parametrize("dilation", [[1, 1], [2, 2]]) +# depthwise +@pytest.mark.parametrize("dw", [0, 1]) +# input channel parallelism ("SIMD") +@pytest.mark.parametrize("simd", [1, 2, 3, 6]) +# parallel_window enable (MMV_out = M*K) +@pytest.mark.parametrize("parallel_window", [0]) +# in/out MMV ("M") +@pytest.mark.parametrize("m", [1]) +# Flip dimensions +@pytest.mark.parametrize("flip", [False]) +@pytest.mark.slow +@pytest.mark.vivado +@pytest.mark.fpgadataflow +def test_fpgadataflow_slidingwindow_rtl( + idt, k, ifm_dim, ifm_ch, stride, dilation, dw, simd, m, parallel_window, flip +): + if flip: + if ( + ifm_dim[0] == ifm_dim[1] + and k[0] == k[1] + and stride[0] == stride[1] + and dilation[0] == dilation[1] + ): + pytest.skip("Dimension flip would have no effect") + k = k[::-1] + ifm_dim = ifm_dim[::-1] + stride = stride[::-1] + dilation = dilation[::-1] + + k_h, k_w = k + ifm_dim_h, ifm_dim_w = ifm_dim + stride_h, stride_w = stride + dilation_h, dilation_w = dilation + + kernel_width = (k_w - 1) * dilation_w + 1 # incl. dilation + kernel_height = (k_h - 1) * dilation_h + 1 # incl. dilation + + if simd > ifm_ch: + pytest.skip("SIMD cannot be larger than number of input channels") + if ifm_ch % simd != 0: + pytest.skip("SIMD must divide number of input channels") + if kernel_height > ifm_dim_h or stride_h > ifm_dim_h: + pytest.skip( + "Illegal convolution configuration: kernel or stride > FM dimension" + ) + if kernel_width > ifm_dim_w or stride_w > ifm_dim_w: + pytest.skip( + "Illegal convolution configuration: kernel or stride > FM dimension" + ) + if (k_h == 1 and (stride_h != 1 or dilation_h != 1)) or ( + k_w == 1 and (stride_w != 1 or dilation_w != 1) + ): + pytest.skip( + """Illegal convolution configuration: + stride or dilation defined for unitary kernel dim""" + ) + if k_h == 1 and k_w == 1 and simd != ifm_ch: + pytest.skip("1x1 Kernel only supported in parallel mode (SIMD=C)") + if parallel_window and simd != ifm_ch: + pytest.skip("Parallel window requires SIMD=C") + + ofm_dim_h = compute_conv_output_dim(ifm_dim_h, k_h, stride_h, 0, dilation_h) + ofm_dim_w = compute_conv_output_dim(ifm_dim_w, k_w, stride_w, 0, dilation_w) + ofm_dim = [ofm_dim_h, ofm_dim_w] + + x = gen_finn_dt_tensor(idt, (1, ifm_dim_h, ifm_dim_w, ifm_ch)) + model = make_single_slidingwindow_modelwrapper( + k=k, + ifm_ch=ifm_ch, + ifm_dim=ifm_dim, + ofm_dim=ofm_dim, + simd=simd, + m=m, + parallel_window=parallel_window, + stride=stride, + dilation=dilation, + idt=idt, + dw=dw, + ) + + model = model.transform(SetExecMode("rtlsim")) + model = model.transform(GiveUniqueNodeNames()) + model = model.transform(PrepareIP("xc7z020clg400-1", 5)) + model = model.transform(PrepareRTLSim()) + + # prepare input data + input_dict = prepare_inputs(x) + # execute model + y_produced = oxe.execute_onnx(model, input_dict)["outp"] + golden = make_single_im2col_modelwrapper( + k=k, + ifm_ch=ifm_ch, + ifm_dim=ifm_dim, + ofm_dim=ofm_dim, + stride=stride, + dilation=dilation, + idt=idt, + ) + y_expected = oxe.execute_onnx(golden, input_dict)["outp"] + + if dw == 0: + assert (y_produced == y_expected).all() + else: + y_expected = y_expected.reshape( + 1, ofm_dim_h, ofm_dim_w, k_h * k_w, ifm_ch // simd, simd + ) + y_expected = y_expected.transpose(0, 1, 2, 4, 3, 5) + y_expected = y_expected.reshape(1, ofm_dim_h, ofm_dim_w, ifm_ch * k_h * k_w) + assert (y_produced == y_expected).all()