diff --git a/finn-rtllib/swg/swg_template_parallel.sv b/finn-rtllib/swg/swg_template_parallel.sv new file mode 100644 index 0000000000000000000000000000000000000000..432c37476436b95823b28ff9937594945a26ed73 --- /dev/null +++ b/finn-rtllib/swg/swg_template_parallel.sv @@ -0,0 +1,406 @@ +/****************************************************************************** + * Copyright (C) 2022, Advanced Micro Devices, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR + * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF + * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + *****************************************************************************/ +module $TOP_MODULE_NAME$_controller #( + int unsigned LOOP_H_ITERATIONS = $LOOP_H_ITERATIONS$, + int unsigned LOOP_W_ITERATIONS = $LOOP_W_ITERATIONS$, + int unsigned LOOP_KH_ITERATIONS = $LOOP_KH_ITERATIONS$, + int unsigned LOOP_KW_ITERATIONS = $LOOP_KW_ITERATIONS$, + int unsigned LOOP_SIMD_ITERATIONS = $LOOP_SIMD_ITERATIONS$, + + int unsigned INCR_BITWIDTH = $INCR_BITWIDTH$, + + bit IS_DEPTHWISE = $IS_DEPTHWISE$ +)( + input logic clk, + input logic rst_n, + + input logic advance, + output logic [INCR_BITWIDTH-1:0] addr_incr, + output logic [INCR_BITWIDTH-1:0] tail_incr +); + + // state and counters + typedef enum logic [2:0] { + STATE_START, + STATE_LOOP_SIMD, + STATE_LOOP_KW, + STATE_LOOP_KH, + STATE_LOOP_W, + STATE_LOOP_H + } state_e; + state_e State = $INNERMOST_STATE$; + state_e state_next; + + logic signed [$clog2(LOOP_H_ITERATIONS +2)+1-1:0] Counter_loop_h = LOOP_H_ITERATIONS; + logic signed [$clog2(LOOP_W_ITERATIONS +2)+1-1:0] Counter_loop_w = LOOP_W_ITERATIONS; + logic signed [$clog2(LOOP_KH_ITERATIONS +2)+1-1:0] Counter_loop_kh = LOOP_KH_ITERATIONS; + logic signed [$clog2(LOOP_KW_ITERATIONS +2)+1-1:0] Counter_loop_kw = LOOP_KW_ITERATIONS; + logic signed [$clog2(LOOP_SIMD_ITERATIONS+2)+1-1:0] Counter_loop_simd = LOOP_SIMD_ITERATIONS; + + // combinational logic for addr_incr generation + always_comb begin : blkHead + unique case (State) + 0 : addr_incr = 0; + 1 : addr_incr = $HEAD_INCR_SIMD$; + 2 : addr_incr = $HEAD_INCR_KW$; + 3 : addr_incr = $HEAD_INCR_KH$; + 4 : addr_incr = $HEAD_INCR_W$; + 5 : addr_incr = $HEAD_INCR_H$; + endcase + end + + // combinational logic for tail_incr generation + uwire tail_incr_inner_condition = IS_DEPTHWISE? (Counter_loop_kh >= 0) : 0; + assign tail_incr = + tail_incr_inner_condition? 1 : + Counter_loop_w >= 0? $TAIL_INCR_W$ : + Counter_loop_h >= 0? $TAIL_INCR_H$ : + /* else */ $TAIL_INCR_LAST$; + + // combinational next state logic + always_comb begin : blkState + state_next = State; + if(State != $INNERMOST_STATE$) state_next = $INNERMOST_STATE$; + else begin + if(Counter_loop_simd < 0) begin + state_next = + (Counter_loop_kw >= 0)? STATE_LOOP_KW : + (Counter_loop_kh >= 0)? STATE_LOOP_KH : + (Counter_loop_w >= 0)? STATE_LOOP_W : + (Counter_loop_h >= 0)? STATE_LOOP_H : + /* else */ STATE_START; + end + end + end : blkState + + // sequential logic + always_ff @ (posedge clk) begin + if(!rst_n) begin + State <= $INNERMOST_STATE$; + Counter_loop_h <= LOOP_H_ITERATIONS; + Counter_loop_w <= LOOP_W_ITERATIONS; + Counter_loop_kh <= LOOP_KH_ITERATIONS; + Counter_loop_kw <= LOOP_KW_ITERATIONS; + Counter_loop_simd <= LOOP_SIMD_ITERATIONS; + end + else if(advance) begin + State <= state_next; + if (State == $INNERMOST_STATE$) begin + if(Counter_loop_simd >= 0) Counter_loop_simd <= Counter_loop_simd-1; + else begin + Counter_loop_simd <= LOOP_SIMD_ITERATIONS; + if(Counter_loop_kw >= 0) Counter_loop_kw <= Counter_loop_kw-1; + else begin + Counter_loop_kw <= LOOP_KW_ITERATIONS; + if(Counter_loop_kh >= 0) Counter_loop_kh <= Counter_loop_kh-1; + else begin + Counter_loop_kh <= LOOP_KH_ITERATIONS; + if(Counter_loop_w >= 0) Counter_loop_w <= Counter_loop_w-1; + else begin + Counter_loop_w <= LOOP_W_ITERATIONS; + if(Counter_loop_h >= 0) Counter_loop_h <= Counter_loop_h-1; + else Counter_loop_h <= LOOP_H_ITERATIONS; + end + end + end + end + end + end + end + +endmodule : $TOP_MODULE_NAME$_controller + +module $TOP_MODULE_NAME$_reg_buffer +#( + parameter WIDTH = 1, + parameter DEPTH = 1 +) +( + CLK, + shift_enable, + shift_in, + shift_out, + data_out +); + +input CLK, shift_enable; +input [WIDTH-1:0] shift_in; +output [WIDTH-1:0] shift_out; +output [WIDTH*DEPTH-1:0] data_out; + +reg [WIDTH-1:0] data [DEPTH-1:0]; + +assign shift_out = data[DEPTH-1]; + +for (genvar e=0; e<DEPTH; e=e+1) + assign data_out[e*WIDTH +: WIDTH] = data[e]; + +always @ (posedge CLK) begin + if (shift_enable) begin + for (integer i=DEPTH-1; i>0; i=i-1) + data[i] <= data[i-1]; + data[0] <= shift_in; + end +end +endmodule : $TOP_MODULE_NAME$_reg_buffer + +module $TOP_MODULE_NAME$_ram_buffer +#( + parameter WIDTH = 1, + parameter DEPTH = 1 +) +( + CLK, + RST, + shift_enable, + shift_in, + shift_out +); + +input CLK, RST, shift_enable; +input [WIDTH-1:0] shift_in; +output [WIDTH-1:0] shift_out; + +reg [WIDTH-1:0] out_reg; +assign shift_out = out_reg; + +integer addr_w, addr_r; //TODO: minimize width + simplify + +$RAM_STYLE$ reg [WIDTH-1:0] ram [DEPTH-1:0]; + +always @(posedge CLK) begin + if (RST == 1'b0) begin + addr_w <= 0; + addr_r <= 1; + end else begin + if (shift_enable) begin + ram[addr_w] <= shift_in; + out_reg <= ram[addr_r]; + + if (addr_w == DEPTH-1) + addr_w <= 0; + else + addr_w <= addr_w + 1; + + if (addr_r == DEPTH-1) + addr_r <= 0; + else + addr_r <= addr_r + 1; + end + end +end +endmodule : $TOP_MODULE_NAME$_ram_buffer + +module $TOP_MODULE_NAME$_wb +#( + parameter IN_WIDTH = 1, //bit-width*C*MMV_in + parameter OUT_ELEM_WIDTH = 1, //bit-width*C + parameter OUT_WIDTH = 1, //bit-width*C*MMV_out + parameter BUFFER_ELEM_TOTAL = 1 +) +( + CLK, + RST, + data_in, + shift_enable, + data_out +); + +input CLK, RST; +input [IN_WIDTH-1:0] data_in; +input shift_enable; +output [OUT_WIDTH-1:0] data_out; + +$GENERATE_REG_FIFOS$ + +$GENERATE_BRAM_FIFOS$ + +//Fixed interconnect between linear buffers +$GENERATE_BUFFER_CONNECTION$ + +//Fixed REG FIFO <-> output mapping +$GENERATE_OUTPUT_MAPPING$ + + +endmodule : $TOP_MODULE_NAME$_wb + +module $TOP_MODULE_NAME$_impl #( + int BIT_WIDTH, + int SIMD, + int MMV_IN, + int MMV_OUT, + int LAST_READ_ELEM = $LAST_READ_ELEM$, + int FIRST_WRITE_ELEM = $FIRST_WRITE_ELEM$, + int LAST_WRITE_ELEM = $LAST_WRITE_ELEM$, + int BUF_ELEM_TOTAL = $BUF_ELEM_TOTAL$, + int INCR_BITWIDTH = $INCR_BITWIDTH$ +)( + input logic ap_clk, + input logic ap_rst_n, + + input logic in0_V_V_TVALID, + output logic in0_V_V_TREADY, + input logic [BIT_WIDTH * SIMD * MMV_IN-1:0] in0_V_V_TDATA, + + output logic out_V_V_TVALID, + input logic out_V_V_TREADY, + output logic [BIT_WIDTH * SIMD * MMV_OUT-1:0] out_V_V_TDATA +); + // derived constants + localparam int unsigned BUF_IN_WIDTH = BIT_WIDTH * SIMD * MMV_IN; + localparam int unsigned BUF_OUT_ELEM_WIDTH = BIT_WIDTH * SIMD; + localparam int unsigned BUF_OUT_WIDTH = BIT_WIDTH * SIMD * MMV_OUT; + + //main buffer instantiation + uwire [BUF_IN_WIDTH -1:0] window_buffer_in; + uwire [BUF_OUT_WIDTH-1:0] window_buffer_out; + uwire window_buffer_shift_enable; + $TOP_MODULE_NAME$_wb + #( + .IN_WIDTH(BUF_IN_WIDTH), + .OUT_ELEM_WIDTH(BUF_OUT_ELEM_WIDTH), + .OUT_WIDTH(BUF_OUT_WIDTH), + .BUFFER_ELEM_TOTAL(BUF_ELEM_TOTAL) + ) + window_buffer_inst + ( + .CLK(ap_clk), + .RST(ap_rst_n), + .data_in(window_buffer_in), + .shift_enable(window_buffer_shift_enable), + .data_out(window_buffer_out) + ); + + //controller instantiation + uwire advance_controller; + uwire signed [INCR_BITWIDTH-1:0] addr_incr; + uwire [INCR_BITWIDTH-1:0] tail_incr; + $TOP_MODULE_NAME$_controller controller_inst ( + .clk(ap_clk), + .rst_n(ap_rst_n), + .advance(advance_controller), + .addr_incr(addr_incr), + .tail_incr(tail_incr) + ); + + // Counters/address registers + // Add a sign bit even to (most) unsigned counters and Window_buffer_read_addr_reg, + // so we can use automatic sign extension and simplify calculations w/ signed increment. + // Alternatively, we could manually sign-extend and shave off a bit here or there. + logic signed [$clog2(LAST_READ_ELEM+1)+1-1:0] Newest_buffered_elem = -1; + logic [$clog2(LAST_READ_ELEM+1)+1-1:0] Current_elem = FIRST_WRITE_ELEM; + logic [$clog2(LAST_READ_ELEM+1)+1-1:0] First_elem_next_window = 0; + + // Control signals/registers + logic Writing_done = 0; + logic write_done = 0; + + uwire write_ok = write_cmd && (out_V_V_TREADY || write_done); + uwire write_blocked = write_cmd && !out_V_V_TREADY && !write_done; + + uwire write_cmd = !($signed(Current_elem) > Newest_buffered_elem) && !Writing_done;; + + uwire reading_done = Newest_buffered_elem == LAST_READ_ELEM; + uwire read_cmd = + !reading_done && ( // if there is still an input element left to read + Writing_done || ( // if fetching is done (e.g. for skipped rows at FM end due to stride) + $signed(((Newest_buffered_elem - (BUF_ELEM_TOTAL - 1)))) < $signed(First_elem_next_window) && + $signed(((Newest_buffered_elem - (BUF_ELEM_TOTAL - 1)))) < $signed(Current_elem) + ) // (over-)write to buffer if oldest buffered element will no longer be needed + ); + uwire read_ok = read_cmd && in0_V_V_TVALID && !write_blocked; + + // includes waiting on W if W-only cycle: wait only on W no R/W to wait for + uwire advance = read_ok || (!read_cmd && write_ok) || (!read_cmd && !write_cmd); + + //assign buffer control + assign window_buffer_shift_enable = advance; + assign advance_controller = write_ok; + + //assign I/O ports + assign window_buffer_in = in0_V_V_TDATA; + assign out_V_V_TDATA = window_buffer_out; + assign in0_V_V_TREADY = ap_rst_n && read_ok; //only asserted if data is available and we can store it (allowed) + assign out_V_V_TVALID = ap_rst_n && write_cmd && !write_done; //only asserted if we have data available and it has not been read yet (don't wait for READY from sink) + + //write done logic + always_ff @(posedge ap_clk) begin + if (advance) begin + write_done <= 1'b0; //reset flag + end else if (write_ok) // successful W in this cycle, but R still outstanding + write_done <= 1'b1; //write can happen even if read is blocked, but only for the current cycle! + end + + //main process for advancing counters + always_ff @(posedge ap_clk) begin + if(!ap_rst_n) begin + Newest_buffered_elem <= -1; + Current_elem <= FIRST_WRITE_ELEM; + First_elem_next_window <= 0; + Writing_done <= 0; + end + else begin + if (read_ok) begin + Newest_buffered_elem <= Newest_buffered_elem+1; + + //check if this is the last read cycle (reading_done will be true afterwards) + if ((Newest_buffered_elem == LAST_READ_ELEM-1) && Writing_done) begin + //start processing of next FM if writing is done already (possible due to unused input elements at the tail end) + //todo: allow for read overlapping between feature maps (i.e., reading first elements from next FM while still writing last window of current FM) + Newest_buffered_elem <= -1; + Current_elem <= FIRST_WRITE_ELEM; + First_elem_next_window <= 0; + Writing_done <= 0; + end + end + + if (write_ok) begin + First_elem_next_window <= First_elem_next_window + tail_incr; + + //check if this is the last write cycle (Writing_done will be true afterwards) + if (Current_elem == LAST_WRITE_ELEM) begin + Writing_done <= 1; + + if (reading_done || (read_ok && (Newest_buffered_elem == LAST_READ_ELEM - 1))) begin + //start processing of next FM if reading is done already, or completes in the same cycle + Newest_buffered_elem <= -1; + Current_elem <= FIRST_WRITE_ELEM; + First_elem_next_window <= 0; + Writing_done <= 0; + end + end + else + Current_elem <= $signed(Current_elem) + addr_incr; + end + end + end + +endmodule : $TOP_MODULE_NAME$_impl diff --git a/src/finn/custom_op/fpgadataflow/convolutioninputgenerator_rtl.py b/src/finn/custom_op/fpgadataflow/convolutioninputgenerator_rtl.py index 1afd23d3a1709a8929a03c21a6eba0a5a8cd6ba6..1ae4022b796ba4b3968dfa27c2a5497a147f7cea 100755 --- a/src/finn/custom_op/fpgadataflow/convolutioninputgenerator_rtl.py +++ b/src/finn/custom_op/fpgadataflow/convolutioninputgenerator_rtl.py @@ -72,8 +72,8 @@ class ConvolutionInputGenerator_rtl(HLSCustomOp): "SIMD": ("i", True, 0), # additional parallelization parameter - not yet implemented "M": ("i", False, 1), - # alternative implementation style - not yet implemented - "parallel_window": ("i", False, 0, {0}), + # Enable parallel window output (requires full SIMD unfolding) + "parallel_window": ("i", False, 0, {0, 1}), "Stride": ("ints", True, []), # [H, W] = [Y, X] "Dilation": ("ints", True, []), # [H, W] = [Y, X] # FINN DataTypes for inputs, weights, outputs @@ -639,6 +639,281 @@ class ConvolutionInputGenerator_rtl(HLSCustomOp): return template_path, code_gen_dict + def prepare_codegen_parallel(self): + # Parallel implementation style for MMV_out = K: + # mix of shift-registers (for parallel read) and line buffers (BRAM or LUTRAM) + # compute a static schedule by analyzing access pattern (from im2col function) + template_path = ( + os.environ["FINN_ROOT"] + "/finn-rtllib/swg/swg_template_parallel.sv" + ) + code_gen_dict = {} + + ifm_ch = self.get_nodeattr("IFMChannels") + k = self.get_nodeattr("ConvKernelDim") + ifm_dim = self.get_nodeattr("IFMDim") + stride = self.get_nodeattr("Stride") + dilation = self.get_nodeattr("Dilation") + simd = self.get_nodeattr("SIMD") + M = self.get_nodeattr("M") + + k_h, k_w = k + h, w = ifm_dim + pad = [0, 0, 0, 0] # padding happens in separate padding node for now + stride_h, stride_w = stride + dilation_h, dilation_w = dilation + pad_h = pad[0] + pad[2] + pad_w = pad[1] + pad[3] + out_dim_h = im2col.compute_conv_output_dim(h, k_h, stride_h, pad_h, dilation_h) + out_dim_w = im2col.compute_conv_output_dim(w, k_w, stride_w, pad_w, dilation_w) + mmv_in = M * 1 + mmv_out = M * k_h * k_w + channel_factor = int(ifm_ch / simd) + + # compute minimal buffer length (assuming it holds 1 complete window) + buffer_min_size = ( + (k_h - 1) * dilation_h * w + (k_w - 1) * dilation_w + 1 + ) * channel_factor + + # buffer_actual_size = self.get_buffer_depth() # TODO: Move to this method + buffer_actual_size = buffer_min_size + 1 + code_gen_dict["$BUF_ELEM_TOTAL$"] = [str(buffer_actual_size)] + + # compute some intermediate values, e.g., kernel "width" = k_w incl. dilation + # or cols/rows that are skipped due to imperfect stride<->dim combination + kernel_width = (k_w - 1) * dilation_w + 1 + kernel_height = (k_h - 1) * dilation_h + 1 + skip_columns = w % (kernel_width + (out_dim_w - 1) * stride_w) + skip_rows = h % (kernel_height + (out_dim_h - 1) * stride_h) + + # compute address increment values for 5-loop nest #TODO: simplify + addr_incr_end_simd = 1 + addr_incr_end_window_elem = (dilation_w - 1) * channel_factor + 1 + addr_incr_end_window_row = ( + ((w - kernel_width) * channel_factor) # remaining line + + ((dilation_h - 1) * w * channel_factor) # skip lines + + 1 # wrap-around of minimally sized buffer + ) + addr_incr_end_window = -buffer_min_size + stride_w * channel_factor + 1 + addr_incr_end_row = ( + -buffer_min_size + + ((skip_columns + kernel_width) * channel_factor) # remaining line + + ((stride_h - 1) * w * channel_factor) # skip lines + + 1 + ) + + # set certain threshold indices to detect when reading/writing finishes + code_gen_dict["$LAST_READ_ELEM$"] = [str(h * w * channel_factor - 1)] + code_gen_dict["$LAST_WRITE_ELEM$"] = [ + str(((h - skip_rows - 1) * w + (w - skip_columns)) * channel_factor - 1) + ] + + # default controller loop structure: # iterations (counters) map directly + loop_h_iterations = out_dim_h + loop_w_iterations = out_dim_w # -> innermost loop + loop_kh_iterations = 1 # k_h + loop_kw_iterations = 1 # k_w + loop_simd_iterations = 1 # channel_factor + + if loop_w_iterations == 1: + code_gen_dict["$INNERMOST_STATE$"] = ["STATE_LOOP_H"] + loop_h_iterations -= 1 # -1 because state is initial state + else: + code_gen_dict["$INNERMOST_STATE$"] = ["STATE_LOOP_W"] + loop_w_iterations -= 1 # -1 because state is initial state + + tail_incr_w = addr_incr_end_window + buffer_min_size - 1 + tail_incr_h = addr_incr_end_row + buffer_min_size - 1 + tail_incr_last_window = buffer_min_size - 1 + code_gen_dict["$IS_DEPTHWISE$"] = ["0"] + + # overwrite new loop bounds: + addr_incr_end_simd = 1 + addr_incr_end_window_elem = 1 + addr_incr_end_window_row = 1 + addr_incr_end_window = tail_incr_w + addr_incr_end_row = tail_incr_h + + # add init value for CURRENT_ELEM counter = last elem of first window + code_gen_dict["$FIRST_WRITE_ELEM$"] = [str(buffer_min_size - 1)] + + cntr_bitwidth = math.ceil( + math.log2( + max( + loop_h_iterations - 2 + 1, + loop_w_iterations - 2 + 1, + loop_kh_iterations - 2 + 1, + loop_kw_iterations - 2 + 1, + loop_simd_iterations - 2 + 1, + ) + ) + ) + code_gen_dict["$CNTR_BITWIDTH$"] = [str(cntr_bitwidth)] + code_gen_dict["$LOOP_H_ITERATIONS$"] = [str(loop_h_iterations - 2)] + code_gen_dict["$LOOP_W_ITERATIONS$"] = [str(loop_w_iterations - 2)] + code_gen_dict["$LOOP_KH_ITERATIONS$"] = [str(loop_kh_iterations - 2)] + code_gen_dict["$LOOP_KW_ITERATIONS$"] = [str(loop_kw_iterations - 2)] + code_gen_dict["$LOOP_SIMD_ITERATIONS$"] = [str(loop_simd_iterations - 2)] + + incr_bitwidth = 1 + math.ceil( + math.log2( + max( + abs(addr_incr_end_simd) + 1, + abs(addr_incr_end_window_elem) + 1, + abs(addr_incr_end_window_row) + 1, + abs(addr_incr_end_window) + 1, + abs(addr_incr_end_row) + 1, + abs(tail_incr_w) + 1, + abs(tail_incr_h) + 1, + abs(tail_incr_last_window) + 1, + ) + ) + ) + code_gen_dict["$INCR_BITWIDTH$"] = [str(incr_bitwidth)] + code_gen_dict["$HEAD_INCR_SIMD$"] = [str(addr_incr_end_simd)] + code_gen_dict["$HEAD_INCR_KW$"] = [str(addr_incr_end_window_elem)] + code_gen_dict["$HEAD_INCR_KH$"] = [str(addr_incr_end_window_row)] + code_gen_dict["$HEAD_INCR_W$"] = [str(addr_incr_end_window)] + code_gen_dict["$HEAD_INCR_H$"] = [str(addr_incr_end_row)] + code_gen_dict["$TAIL_INCR_W$"] = [str(tail_incr_w)] + code_gen_dict["$TAIL_INCR_H$"] = [str(tail_incr_h)] + code_gen_dict["$TAIL_INCR_LAST$"] = [str(tail_incr_last_window)] + + code_gen_dict["$SIMD$"] = [str(simd)] + code_gen_dict["$MMV_IN$"] = [str(mmv_in)] + code_gen_dict["$MMV_OUT$"] = [str(mmv_out)] + + # prepare buffer partitioning into "reg_fifos" and "bram_fifos" + # use normalized ([H,W]=[1,W]) dimensions for 1D case + ( + ifm_ch, + [ifm_dim_h, ifm_dim_w], + [ofm_dim_h, ofm_dim_w], + [k_h, k_w], + [stride_h, stride_w], + [dilation_h, dilation_w], + ) = self.get_1d_conv_attrs_normalized() + + reg_fifos = [] + bram_fifos_depth = [] + + px_idx = 0 + for ky in range(k_h): + reg_fifo = [] + for kx in range(k_w): + reg_fifo.append(px_idx) + px_idx += 1 + if kx < (k_w - 1): + reg_fifo.extend([-1] * (dilation_w - 1)) + px_idx += dilation_w - 1 + reg_fifos.append(reg_fifo) + + if ky < (k_h - 1): + line_buffer_len = (w - kernel_width) + w * (dilation_h - 1) + bram_fifos_depth.append(line_buffer_len) + px_idx += line_buffer_len + + code_gen_dict["$GENERATE_REG_FIFOS$"] = [] + for i, reg_fifo in enumerate(reg_fifos): + code_gen_dict["$GENERATE_REG_FIFOS$"].append( + """ + wire [IN_WIDTH-1:0] reg_fifo_{id}_in; + wire [IN_WIDTH-1:0] reg_fifo_{id}_out; + wire [IN_WIDTH*{len}-1:0] reg_fifo_{id}; + {name}_reg_buffer + #( + .WIDTH(IN_WIDTH), + .DEPTH({len}) + ) + reg_buffer_inst_{id} + ( + .CLK(CLK), + .shift_enable(shift_enable), + .shift_in(reg_fifo_{id}_in), + .shift_out(reg_fifo_{id}_out), + .data_out(reg_fifo_{id}) + );""".format( + name=self.get_verilog_top_module_name(), + id=i, + len=len(reg_fifo), + ) + ) + + code_gen_dict["$GENERATE_BRAM_FIFOS$"] = [] + for i, bram_fifo_depth in enumerate(bram_fifos_depth): + code_gen_dict["$GENERATE_BRAM_FIFOS$"].append( + """ + wire [IN_WIDTH-1:0] bram_fifo_{id}_in; + wire [IN_WIDTH-1:0] bram_fifo_{id}_out; + {name}_ram_buffer + #( + .WIDTH(IN_WIDTH), + .DEPTH({len}) + ) + ram_buffer_inst_{id} + ( + .CLK(CLK), + .RST(RST), + .shift_enable(shift_enable), + .shift_in(bram_fifo_{id}_in), + .shift_out(bram_fifo_{id}_out) + );""".format( + name=self.get_verilog_top_module_name(), + id=i, + len=bram_fifo_depth, + ) + ) + + code_gen_dict["$GENERATE_OUTPUT_MAPPING$"] = [] + out_idx = mmv_out - 1 + for fifo_id, reg_fifo in enumerate(reg_fifos): + for fifo_idx, access_idx in enumerate(reg_fifo): + if access_idx != -1: + code_gen_dict["$GENERATE_OUTPUT_MAPPING$"].append( + """assign data_out[OUT_ELEM_WIDTH*{out_idx}+:OUT_ELEM_WIDTH] + = reg_fifo_{fifo_id}[{access_idx}*{mmv}*OUT_ELEM_WIDTH+ + OUT_ELEM_WIDTH*{mmv_idx}+:OUT_ELEM_WIDTH];""".format( + out_idx=out_idx, + fifo_id=fifo_id, + access_idx=len(reg_fifo) + - 1 + - int((max(reg_fifo) - access_idx) / M), + mmv_idx=(max(reg_fifo) - access_idx) % M, + mmv=M, + ) + ) + # reversal: out_idx=0 -> oldest buffer element -> highest access_idx + out_idx = out_idx - 1 + assert out_idx == -1, "ERROR: Not all output vector elements connected" + + code_gen_dict["$GENERATE_BUFFER_CONNECTION$"] = [] + for i in range(len(reg_fifos)): + if i == 0: + # first FIFO containing newest elements -> input comes from input reg + code_gen_dict["$GENERATE_BUFFER_CONNECTION$"].append( + """assign reg_fifo_{fifo_id}_in = data_in;""".format( + fifo_id=i, + ) + ) + else: + # other REG FIFOs -> input comes from connected BRAM FIFO (line buffer) + input_fifo_id = i - 1 + code_gen_dict["$GENERATE_BUFFER_CONNECTION$"].append( + """assign reg_fifo_{fifo_id}_in = bram_fifo_{input_fifo_id}_out; + """.format( + fifo_id=i, input_fifo_id=input_fifo_id + ) + ) + for i in range(len(bram_fifos_depth)): + input_fifo_id = i + code_gen_dict["$GENERATE_BUFFER_CONNECTION$"].append( + """assign bram_fifo_{fifo_id}_in = reg_fifo_{input_fifo_id}_out; + """.format( + fifo_id=i, input_fifo_id=input_fifo_id + ) + ) + + return template_path, code_gen_dict + def select_impl_style(self): simd = self.get_nodeattr("SIMD") M = self.get_nodeattr("M") @@ -685,9 +960,6 @@ class ConvolutionInputGenerator_rtl(HLSCustomOp): else: impl_style = "default" - assert ( - impl_style == "default" - ), "ERROR: Parallel window mode not yet implemented" return impl_style def generate_hdl(self): @@ -696,6 +968,8 @@ class ConvolutionInputGenerator_rtl(HLSCustomOp): # prepare code generation by filling out dictionaries if impl_style == "default": template_path, code_gen_dict = self.prepare_codegen_default() + elif impl_style == "parallel": + template_path, code_gen_dict = self.prepare_codegen_parallel() else: raise Exception("Requested impl. style not implemented")