diff --git a/finn-rtllib/swg/swg_template_default.sv b/finn-rtllib/swg/swg_template_default.sv index fc4c96d1c32485810f0a50844a8003978716a708..2d255a35edf97e28053545b89512a4b1415b6f57 100644 --- a/finn-rtllib/swg/swg_template_default.sv +++ b/finn-rtllib/swg/swg_template_default.sv @@ -6,7 +6,9 @@ module $TOP_MODULE_NAME$_controller #( int unsigned LOOP_SIMD_ITERATIONS = $LOOP_SIMD_ITERATIONS$, int unsigned INCR_BITWIDTH = $INCR_BITWIDTH$, - bit [INCR_BITWIDTH-1:0] ADDR_INCREMENT_MAP[6] = $ADDR_INCREMENT_MAP$ + bit [INCR_BITWIDTH-1:0] ADDR_INCREMENT_MAP[6] = $ADDR_INCREMENT_MAP$, + + bit IS_DEPTHWISE = $IS_DEPTHWISE$ )( input logic clk, input logic rst_n, @@ -16,7 +18,7 @@ module $TOP_MODULE_NAME$_controller #( output logic [INCR_BITWIDTH-1:0] tail_incr ); - //State and counters + // state and counters typedef enum logic [2:0] { STATE_START, STATE_LOOP_SIMD, @@ -28,66 +30,83 @@ module $TOP_MODULE_NAME$_controller #( state_e State = $INNERMOST_STATE$; state_e state_next; - logic signed [$clog2(LOOP_H_ITERATIONS +2)+1-1:0] counter_loop_h = LOOP_H_ITERATIONS-1; - logic signed [$clog2(LOOP_W_ITERATIONS +2)+1-1:0] counter_loop_w = LOOP_W_ITERATIONS-1; - logic signed [$clog2(LOOP_KH_ITERATIONS +2)+1-1:0] counter_loop_kh = LOOP_KH_ITERATIONS-1; - logic signed [$clog2(LOOP_KW_ITERATIONS +2)+1-1:0] counter_loop_kw = LOOP_KW_ITERATIONS-1; - logic signed [$clog2(LOOP_SIMD_ITERATIONS+2)+1-1:0] counter_loop_simd = LOOP_SIMD_ITERATIONS-1; + logic signed [$clog2(LOOP_H_ITERATIONS +2)+1-1:0] Counter_loop_h = LOOP_H_ITERATIONS-1; + logic signed [$clog2(LOOP_W_ITERATIONS +2)+1-1:0] Counter_loop_w = LOOP_W_ITERATIONS-1; + logic signed [$clog2(LOOP_KH_ITERATIONS +2)+1-1:0] Counter_loop_kh = LOOP_KH_ITERATIONS-1; + logic signed [$clog2(LOOP_KW_ITERATIONS +2)+1-1:0] Counter_loop_kw = LOOP_KW_ITERATIONS-1; + logic signed [$clog2(LOOP_SIMD_ITERATIONS+2)+1-1:0] Counter_loop_simd = LOOP_SIMD_ITERATIONS-1; logic [INCR_BITWIDTH-1:0] tail_incr_reg = 'x; assign addr_incr = ADDR_INCREMENT_MAP[State]; assign tail_incr = tail_incr_reg; - //combinational logic for tail_incr generation - $TAIL_INCR_GENERATION$ + // combinational logic for tail_incr generation + uwire tail_incr_inner_condition; + generate + if (IS_DEPTHWISE) + assign tail_incr_inner_condition = (Counter_loop_kh >= 0); + else + assign tail_incr_inner_condition = 0; + endgenerate + + always @ (tail_incr_inner_condition, Counter_loop_w, Counter_loop_h) begin + if (tail_incr_inner_condition) + tail_incr_reg = 1; + else if (Counter_loop_w >= 0) + tail_incr_reg = $TAIL_INCR_W$; + else if (Counter_loop_h >= 0) + tail_incr_reg = $TAIL_INCR_H$; + else + tail_incr_reg = $TAIL_INCR_LAST$; + end - //combinational next state logic + // combinational next state logic always_comb begin : blkState state_next = State; if(State != $INNERMOST_STATE$) state_next = $INNERMOST_STATE$; else begin - if(counter_loop_simd < 0) begin + if(Counter_loop_simd < 0) begin state_next = - (counter_loop_kw >= 0)? STATE_LOOP_KW : - (counter_loop_kh >= 0)? STATE_LOOP_KH : - (counter_loop_w >= 0)? STATE_LOOP_W : - (counter_loop_h >= 0)? STATE_LOOP_H : + (Counter_loop_kw >= 0)? STATE_LOOP_KW : + (Counter_loop_kh >= 0)? STATE_LOOP_KH : + (Counter_loop_w >= 0)? STATE_LOOP_W : + (Counter_loop_h >= 0)? 
STATE_LOOP_H : /* else */ STATE_START; end end end : blkState - //sequential logic + // sequential logic always_ff @ (posedge clk) begin if(!rst_n) begin State <= $INNERMOST_STATE$; - counter_loop_h <= LOOP_H_ITERATIONS-1; - counter_loop_w <= LOOP_W_ITERATIONS-1; - counter_loop_kh <= LOOP_KH_ITERATIONS-1; - counter_loop_kw <= LOOP_KW_ITERATIONS-1; - counter_loop_simd <= LOOP_SIMD_ITERATIONS-1; + Counter_loop_h <= LOOP_H_ITERATIONS-1; + Counter_loop_w <= LOOP_W_ITERATIONS-1; + Counter_loop_kh <= LOOP_KH_ITERATIONS-1; + Counter_loop_kw <= LOOP_KW_ITERATIONS-1; + Counter_loop_simd <= LOOP_SIMD_ITERATIONS-1; end else if(advance) begin State <= state_next; if (State == $INNERMOST_STATE$) begin - if(counter_loop_simd >= 0) counter_loop_simd <= counter_loop_simd-1; + if(Counter_loop_simd >= 0) Counter_loop_simd <= Counter_loop_simd-1; else begin - counter_loop_simd <= LOOP_SIMD_ITERATIONS-1; - if(counter_loop_kw >= 0) counter_loop_kw <= counter_loop_kw-1; + Counter_loop_simd <= LOOP_SIMD_ITERATIONS-1; + if(Counter_loop_kw >= 0) Counter_loop_kw <= Counter_loop_kw-1; else begin - counter_loop_kw <= LOOP_KW_ITERATIONS-1; - if(counter_loop_kh >= 0) counter_loop_kh <= counter_loop_kh-1; + Counter_loop_kw <= LOOP_KW_ITERATIONS-1; + if(Counter_loop_kh >= 0) Counter_loop_kh <= Counter_loop_kh-1; else begin - counter_loop_kh <= LOOP_KH_ITERATIONS-1; - if(counter_loop_w >= 0) counter_loop_w <= counter_loop_w-1; + Counter_loop_kh <= LOOP_KH_ITERATIONS-1; + if(Counter_loop_w >= 0) Counter_loop_w <= Counter_loop_w-1; else begin - counter_loop_w <= LOOP_W_ITERATIONS-1; - if(counter_loop_h >= 0) counter_loop_h <= counter_loop_h-1; - else counter_loop_h <= LOOP_H_ITERATIONS-1; - end - end + Counter_loop_w <= LOOP_W_ITERATIONS-1; + if(Counter_loop_h >= 0) Counter_loop_h <= Counter_loop_h-1; + else Counter_loop_h <= LOOP_H_ITERATIONS-1; + end + end end - end + end end end end @@ -112,7 +131,7 @@ module $TOP_MODULE_NAME$_cyclic_buffer_addressable #( $RAM_STYLE$ logic [WIDTH-1:0] Ram[DEPTH]; logic [WIDTH-1:0] Out = 'x; - always_ff @(posedge clk) begin + always_ff @(posedge clk) begin if (!rst_n) begin Out <= 'x; end @@ -126,10 +145,10 @@ module $TOP_MODULE_NAME$_cyclic_buffer_addressable #( endmodule : $TOP_MODULE_NAME$_cyclic_buffer_addressable module $TOP_MODULE_NAME$_impl #( - int BIT_WIDTH = $BIT_WIDTH$, - int SIMD = $SIMD$, - int MMV_IN = $MMV_IN$, - int MMV_OUT = $MMV_OUT$, + int BIT_WIDTH, + int SIMD, + int MMV_IN, + int MMV_OUT, int LAST_READ_ELEM = $LAST_READ_ELEM$, int LAST_WRITE_ELEM = $LAST_WRITE_ELEM$, int BUF_ELEM_TOTAL = $BUF_ELEM_TOTAL$, @@ -147,12 +166,12 @@ module $TOP_MODULE_NAME$_impl #( input logic out_V_V_TREADY, output logic [BIT_WIDTH * SIMD * MMV_OUT-1:0] out_V_V_TDATA ); - // Derived Constants + // derived Constants localparam int unsigned BUF_IN_WIDTH = BIT_WIDTH * SIMD * MMV_IN; localparam int unsigned BUF_OUT_ELEM_WIDTH = BIT_WIDTH * SIMD; localparam int unsigned BUF_OUT_WIDTH = BIT_WIDTH * SIMD * MMV_OUT; - //main buffer instantiation + // main buffer instantiation uwire [BUF_IN_WIDTH -1:0] window_buffer_in; uwire [BUF_OUT_WIDTH-1:0] window_buffer_out; uwire window_buffer_write_enable; @@ -188,38 +207,38 @@ module $TOP_MODULE_NAME$_impl #( ); // Counters/address registers - // Add a sign bit even to (most) unsigned counters and window_buffer_read_addr_reg, + // Add a sign bit even to (most) unsigned counters and Window_buffer_read_addr_reg, // so we can use automatic sign extension and simplify calculations w/ signed increment. 
// Alternatively, we could manually sign-extend and shave off a bit here or there. - logic signed [$clog2(LAST_READ_ELEM+1)+1-1:0] newest_buffered_elem = -1; - logic [$clog2(LAST_READ_ELEM+1)+1-1:0] current_elem = 0; - logic [$clog2(LAST_READ_ELEM+1)+1-1:0] first_elem_next_window = 0; - logic [$clog2(ELEM_PER_WINDOW) -1:0] k = 0; - logic [$clog2(BUF_ELEM_TOTAL)+1 -1:0] window_buffer_read_addr_reg = 0; - logic [$clog2(BUF_ELEM_TOTAL)-1:0] window_buffer_write_addr_reg = 0; + logic signed [$clog2(LAST_READ_ELEM+1)+1-1:0] Newest_buffered_elem = -1; + logic [$clog2(LAST_READ_ELEM+1)+1-1:0] Current_elem = 0; + logic [$clog2(LAST_READ_ELEM+1)+1-1:0] First_elem_next_window = 0; + logic [$clog2(ELEM_PER_WINDOW) -1:0] K = 0; + logic [$clog2(BUF_ELEM_TOTAL)+1 -1:0] Window_buffer_read_addr_reg = 0; + logic [$clog2(BUF_ELEM_TOTAL)-1:0] Window_buffer_write_addr_reg = 0; // Control signals/registers uwire read_cmd = !reading_done && ( // if there is still an input element left to read - fetching_done || ( // if fetching is done (e.g. for skipped rows at FM end due to stride) - $signed(((newest_buffered_elem - (BUF_ELEM_TOTAL - 1)))) < $signed(first_elem_next_window) && - $signed(((newest_buffered_elem - (BUF_ELEM_TOTAL - 1)))) < $signed(current_elem) - ) // (over-)write to buffer if oldest buffered element will no longer be needed - ); + Fetching_done || ( // if fetching is done (e.g. for skipped rows at FM end due to stride) + $signed(((Newest_buffered_elem - (BUF_ELEM_TOTAL - 1)))) < $signed(First_elem_next_window) && + $signed(((Newest_buffered_elem - (BUF_ELEM_TOTAL - 1)))) < $signed(Current_elem) + ) // (over-)write to buffer if oldest buffered element will no longer be needed + ); uwire read_ok = read_cmd && in0_V_V_TVALID; - uwire reading_done = newest_buffered_elem == LAST_READ_ELEM; + uwire reading_done = Newest_buffered_elem == LAST_READ_ELEM; - uwire fetch_cmd = !($signed(current_elem) > newest_buffered_elem) && !write_blocked && !fetching_done; - logic fetching_done = 0; + uwire fetch_cmd = !($signed(Current_elem) > Newest_buffered_elem) && !write_blocked && !Fetching_done; + logic Fetching_done = 0; - logic write_cmd = 0; - logic writing_done = 0; - uwire write_ok = write_cmd && out_V_V_TREADY; - uwire write_blocked = write_cmd && !out_V_V_TREADY;; + logic Write_cmd = 0; + logic Writing_done = 0; + uwire write_ok = Write_cmd && out_V_V_TREADY; + uwire write_blocked = Write_cmd && !out_V_V_TREADY;; //assign buffer control - assign window_buffer_write_addr = window_buffer_write_addr_reg; - assign window_buffer_read_addr = window_buffer_read_addr_reg; + assign window_buffer_write_addr = Window_buffer_write_addr_reg; + assign window_buffer_read_addr = Window_buffer_read_addr_reg; assign window_buffer_write_enable = read_ok; assign window_buffer_read_enable = fetch_cmd; assign advance_controller = fetch_cmd; @@ -228,87 +247,87 @@ module $TOP_MODULE_NAME$_impl #( assign window_buffer_in = in0_V_V_TDATA; assign out_V_V_TDATA = window_buffer_out; assign in0_V_V_TREADY = ap_rst_n && read_ok; //only asserted if data is available and we can store it (allowed) - assign out_V_V_TVALID = ap_rst_n && write_cmd; //only asserted if we have data available and it has not been read yet (don't wait for READY from sink) + assign out_V_V_TVALID = ap_rst_n && Write_cmd; //only asserted if we have data available and it has not been read yet (don't wait for READY from sink) //main process for advancing counters always_ff @(posedge ap_clk) begin if(!ap_rst_n) begin - newest_buffered_elem <= -1; - current_elem <= 0; - 
first_elem_next_window <= 0; - k <= 0; - window_buffer_read_addr_reg <= 0; - window_buffer_write_addr_reg <= 0; - fetching_done <= 0; - write_cmd <= 0; - writing_done <= 0; + Newest_buffered_elem <= -1; + Current_elem <= 0; + First_elem_next_window <= 0; + K <= 0; + Window_buffer_read_addr_reg <= 0; + Window_buffer_write_addr_reg <= 0; + Fetching_done <= 0; + Write_cmd <= 0; + Writing_done <= 0; end else begin if (read_ok) begin - window_buffer_write_addr_reg <= (window_buffer_write_addr_reg == BUF_ELEM_TOTAL-1)? 0 : window_buffer_write_addr_reg + 1; - newest_buffered_elem <= newest_buffered_elem+1; + Window_buffer_write_addr_reg <= (Window_buffer_write_addr_reg == BUF_ELEM_TOTAL-1)? 0 : Window_buffer_write_addr_reg + 1; + Newest_buffered_elem <= Newest_buffered_elem+1; - if (newest_buffered_elem == LAST_READ_ELEM-1) begin - window_buffer_write_addr_reg <= 0; + if (Newest_buffered_elem == LAST_READ_ELEM-1) begin + Window_buffer_write_addr_reg <= 0; end //check if this is the last read cycle (reading_done will be true afterwards) - if ((newest_buffered_elem == LAST_READ_ELEM-1) && writing_done) begin + if ((Newest_buffered_elem == LAST_READ_ELEM-1) && Writing_done) begin //start processing of next FM if writing is done already (possible due to unused input elements at the tail end) //todo: allow for read overlapping between feature maps (i.e., reading first elements from next FM while still writing last window of current FM) - newest_buffered_elem <= -1; - current_elem <= 0; - window_buffer_read_addr_reg <= 0; - first_elem_next_window <= 0; - writing_done <= 0; - fetching_done <= 0; + Newest_buffered_elem <= -1; + Current_elem <= 0; + Window_buffer_read_addr_reg <= 0; + First_elem_next_window <= 0; + Writing_done <= 0; + Fetching_done <= 0; end end - + if (fetch_cmd) begin //count up to track which element index is about to be read from the buffer, and where it is located within the buffer //use increment value calculated by controller // absolute buffer address wrap-around - automatic logic signed [$clog2(BUF_ELEM_TOTAL)+1:0] ra = $signed(window_buffer_read_addr_reg) + $signed(addr_incr); + automatic logic signed [$clog2(BUF_ELEM_TOTAL)+1:0] ra = $signed(Window_buffer_read_addr_reg) + $signed(addr_incr); automatic logic signed [$clog2(BUF_ELEM_TOTAL+1):0] ra_correct = (ra >= BUF_ELEM_TOTAL)? -BUF_ELEM_TOTAL : (ra < 0)? BUF_ELEM_TOTAL : 0; - window_buffer_read_addr_reg <= ra + ra_correct; + Window_buffer_read_addr_reg <= ra + ra_correct; //keep track where we are within a window - k <= (k != ELEM_PER_WINDOW - 1)? k+1 : 0; + K <= (K != ELEM_PER_WINDOW - 1)? 
K+1 : 0; //update first element of next window to allow buffer overwrite up until that point - if (k == 0) - first_elem_next_window <= first_elem_next_window + tail_incr; + if (K == 0) + First_elem_next_window <= First_elem_next_window + tail_incr; - //check if this is the last write cycle (writing_done will be true afterwards) - if (current_elem == LAST_WRITE_ELEM) - fetching_done <= 1; + //check if this is the last write cycle (Writing_done will be true afterwards) + if (Current_elem == LAST_WRITE_ELEM) + Fetching_done <= 1; else - current_elem <= $signed(current_elem) + addr_incr; + Current_elem <= $signed(Current_elem) + addr_incr; // determine if prefetched data will be outstanding in the next cycle // if we fetch in this cycle -> yes // if we do not fetch nor write -> do not change // if we do not fetch but write successfully-> clear outstanding data - write_cmd <= fetch_cmd; - end + Write_cmd <= fetch_cmd; + end if (write_ok) - write_cmd <= fetch_cmd; + Write_cmd <= fetch_cmd; - if (write_ok && fetching_done) begin - //check if this is the last write cycle (writing_done will be true afterwards) - if (reading_done || (read_ok && (newest_buffered_elem == LAST_READ_ELEM - 1))) begin + if (write_ok && Fetching_done) begin + //check if this is the last write cycle (Writing_done will be true afterwards) + if (reading_done || (read_ok && (Newest_buffered_elem == LAST_READ_ELEM - 1))) begin //start processing of next FM if reading is done already, or completes in the same cycle - newest_buffered_elem <= -1; - current_elem <= 0; - window_buffer_read_addr_reg <= 0; - first_elem_next_window <= 0; - fetching_done <= 0; + Newest_buffered_elem <= -1; + Current_elem <= 0; + Window_buffer_read_addr_reg <= 0; + First_elem_next_window <= 0; + Fetching_done <= 0; end else - writing_done <= 1; + Writing_done <= 1; end end end diff --git a/finn-rtllib/swg/swg_template_parallel.sv b/finn-rtllib/swg/swg_template_parallel.sv deleted file mode 100755 index 19638d8a1dc4cb0afda6319bc9e58f38fa494269..0000000000000000000000000000000000000000 --- a/finn-rtllib/swg/swg_template_parallel.sv +++ /dev/null @@ -1,409 +0,0 @@ -`timescale 1 ns / 1 ps - -module $TOP_MODULE_NAME$_controller -( - CLK, - RST, - advance, - cmd_read, - cmd_write -); - -input CLK; -input RST; -input advance; -output cmd_read; -output cmd_write; - -////code generation part: -//mapping of R/W command values to each state (START, MAIN_1, MAIN_2, INTER_1, INTER_2, END_1, END_2) -localparam [0:6] READ_CMD_MAP = $READ_CMD_MAP$; -localparam [0:6] WRITE_CMD_MAP = $WRITE_CMD_MAP$; - -localparam START_COUNTER = $START_COUNTER$; -localparam LOOP_MAIN_COUNTER = $LOOP_MAIN_COUNTER$; -localparam LOOP_MAIN_1_COUNTER = $LOOP_MAIN_1_COUNTER$; -localparam LOOP_MAIN_2_COUNTER = $LOOP_MAIN_2_COUNTER$; -localparam LOOP_INTER_COUNTER = $LOOP_INTER_COUNTER$; -localparam LOOP_INTER_1_COUNTER = $LOOP_INTER_1_COUNTER$; -localparam LOOP_INTER_2_COUNTER = $LOOP_INTER_2_COUNTER$; -localparam LOOP_END_1_COUNTER = $LOOP_END_1_COUNTER$; -localparam LOOP_END_2_COUNTER = $LOOP_END_2_COUNTER$; -//// - -//state and counters -reg [2:0] state, state_next; -parameter STATE_START = 0, STATE_LOOP_MAIN_1 = 1, STATE_LOOP_MAIN_2 = 2, STATE_LOOP_INTER_1 = 3, STATE_LOOP_INTER_2 = 4, STATE_END_1 = 5, STATE_END_2 = 6; -integer counter_current; //todo: minimize width -integer counter_loop_main; -integer counter_loop_inter; - -assign cmd_read = READ_CMD_MAP[state_next]; //read command indicates read in *upcoming* cycle, due to how schedule is constructed -assign cmd_write = 
WRITE_CMD_MAP[state]; - -//combinational next state logic -always @ (state, counter_current, counter_loop_main, counter_loop_inter) begin - state_next = state; //default - case (state) - STATE_START: - if (counter_current == START_COUNTER-1) - state_next = STATE_LOOP_MAIN_1; - - STATE_LOOP_MAIN_1: - if (counter_current == LOOP_MAIN_1_COUNTER-1) - state_next = STATE_LOOP_MAIN_2; - - STATE_LOOP_MAIN_2: begin - if (counter_current == LOOP_MAIN_2_COUNTER-1) begin - state_next = STATE_LOOP_MAIN_1; - if (counter_loop_main == LOOP_MAIN_COUNTER-1) begin - //no -1 because this counter marks the currently active iteration, not finished iterations - if ((LOOP_INTER_COUNTER != 0) && (counter_loop_inter != LOOP_INTER_COUNTER)) - state_next = STATE_LOOP_INTER_1; - else begin - //there might not be an end sequence -> restart immediately - if (LOOP_END_1_COUNTER != 0) - state_next = STATE_END_1; - else - state_next = STATE_LOOP_MAIN_2; //wait in current state until reset - end - end - end - end - - STATE_LOOP_INTER_1: begin - if (counter_current == LOOP_INTER_1_COUNTER-1) begin - if (LOOP_INTER_2_COUNTER != 0) - state_next = STATE_LOOP_INTER_2; - else - state_next = STATE_LOOP_MAIN_1; - end - end - - STATE_LOOP_INTER_2: - if (counter_current == LOOP_INTER_2_COUNTER-1) - state_next = STATE_LOOP_MAIN_1; - - STATE_END_1: begin - if (counter_current == LOOP_END_1_COUNTER-1) begin - if (LOOP_END_2_COUNTER != 0) - state_next = STATE_END_2; - else - state_next = STATE_END_1; //wait in current state until reset - end - end - - STATE_END_2: - if (counter_current == LOOP_END_2_COUNTER-1) - state_next = STATE_END_2; //wait in current state until reset - endcase -end - -//sequential logic -always @ (posedge CLK) begin - if (RST) begin - counter_current <= -1; - counter_loop_main <= 0; - counter_loop_inter <= 0; - state <= STATE_START; - end else begin - if (advance) begin - counter_current <= counter_current+1; - state <= state_next; - - if (state != state_next) begin - counter_current <= 0; - - //count up main loop upon re-entering this loop (not on first enter from start) - if ((state_next == STATE_LOOP_MAIN_1) && (state != STATE_START)) begin - if (counter_loop_main == LOOP_MAIN_COUNTER-1) begin - counter_loop_main <= 0; - end else begin - counter_loop_main <= counter_loop_main+1; - end - end - - if (state_next == STATE_LOOP_INTER_1) begin - if (counter_loop_inter == LOOP_INTER_COUNTER) begin //no -1 because this counter marks the currently active iteration, not finished iterations - counter_loop_inter <= 0; - end else begin - counter_loop_inter <= counter_loop_inter+1; - end - end - end - end - end -end -endmodule //controller - -module $TOP_MODULE_NAME$_reg_buffer -#( - parameter WIDTH = 1, - parameter DEPTH = 1 -) -( - CLK, - shift_enable, - shift_in, - shift_out, - data_out -); - -input CLK, shift_enable; -input [WIDTH-1:0] shift_in; -output [WIDTH-1:0] shift_out; -output [WIDTH*DEPTH-1:0] data_out; - -// ToDo: experiment with SRL instead of FF-based shift register -// by force or by achieving automatic SRL inference -//UG901 template for SRL inference: -// 32-bit Shift Register -// Rising edge clock -// Active high clock enable -// For-loop based template -// File: shift_registers_1.v -// -//module shift_registers_1 (clk, clken, SI, SO); -//parameter WIDTH = 32; -//input clk, clken, SI; -//output SO; -//reg [WIDTH-1:0] shreg; -// -//integer i; -//always @(posedge clk) -//begin -// if (clken) -// begin -// for (i = 0; i < WIDTH-1; i = i+1) -// shreg[i+1] <= shreg[i]; -// shreg[0] <= SI; -// end -//end 
-//assign SO = shreg[WIDTH-1]; -//endmodule - -reg [WIDTH-1:0] data [DEPTH-1:0]; - -assign shift_out = data[DEPTH-1]; - -for (genvar e=0; e<DEPTH; e=e+1) - assign data_out[e*WIDTH +: WIDTH] = data[e]; - -always @ (posedge CLK) begin - if (shift_enable) begin - for (integer i=DEPTH-1; i>0; i=i-1) - data[i] <= data[i-1]; - data[0] <= shift_in; - end -end -endmodule //reg_buffer - -module $TOP_MODULE_NAME$_ram_buffer -#( - parameter WIDTH = 1, - parameter DEPTH = 1 -) -( - CLK, - RST, - shift_enable, - shift_in, - shift_out -); - -input CLK, RST, shift_enable; -input [WIDTH-1:0] shift_in; -output [WIDTH-1:0] shift_out; - -reg [WIDTH-1:0] out_reg; -assign shift_out = out_reg; - -integer addr_w, addr_r; //todo: minimize width (as reg), make r addr depend on w - -$RAM_STYLE$ reg [WIDTH-1:0] ram [DEPTH-1:0]; - -always @(posedge CLK) begin - if (RST == 1'b0) begin - addr_w <= 0; - addr_r <= 1; - end else begin - if (shift_enable) begin - ram[addr_w] <= shift_in; - out_reg <= ram[addr_r]; - - if (addr_w == DEPTH-1) - addr_w <= 0; - else - addr_w <= addr_w + 1; - - if (addr_r == DEPTH-1) - addr_r <= 0; - else - addr_r <= addr_r + 1; - end - end -end -endmodule //ram_buffer - -module $TOP_MODULE_NAME$_wb -#( - parameter IN_WIDTH = 1, //bit-width*C*MMV_in - parameter OUT_ELEM_WIDTH = 1, //bit-width*C - parameter OUT_WIDTH = 1, //bit-width*C*MMV_out - parameter BUFFER_ELEM_TOTAL = 1 -) -( - CLK, - RST, - data_in, - shift_enable, - data_out -); - -input CLK, RST; -input [IN_WIDTH-1:0] data_in; -input shift_enable; -output [OUT_WIDTH-1:0] data_out; - -//Input REG to enable simultaneous R/W -reg [IN_WIDTH-1:0] reg_input; - -$GENERATE_REG_FIFOS$ - -$GENERATE_BRAM_FIFOS$ - -//Fixed interconnect between linear buffers -$GENERATE_BUFFER_CONNECTION$ - -//Fixed REG FIFO <-> output mapping -$GENERATE_OUTPUT_MAPPING$ - -//input register logic -integer i; -always @ (posedge CLK) begin - if (shift_enable) begin - reg_input <= data_in; - end -end - -endmodule //window_buffer - -module $TOP_MODULE_NAME$_impl ( - ap_clk, - ap_rst_n, - in0_V_V_TDATA, - in0_V_V_TVALID, - in0_V_V_TREADY, - out_V_V_TDATA, - out_V_V_TVALID, - out_V_V_TREADY -); - -parameter BIT_WIDTH = $BIT_WIDTH$; -parameter SIMD = $SIMD$; //assuming SIMD = C for this implementation style -parameter MMV_IN = $MMV_IN$; -parameter MMV_OUT = $MMV_OUT$; -parameter BUF_IN_WIDTH = BIT_WIDTH * SIMD * MMV_IN; -parameter BUF_OUT_ELEM_WIDTH = BIT_WIDTH * SIMD; -parameter BUF_OUT_WIDTH = BIT_WIDTH * SIMD * MMV_OUT; -parameter CYCLES_TOTAL = $CYCLES_TOTAL$; -parameter BUF_ELEM_TOTAL = $BUF_ELEM_TOTAL$; - -//IO ports -input ap_clk; -input ap_rst_n; -input [BUF_IN_WIDTH-1:0] in0_V_V_TDATA; -input in0_V_V_TVALID; -output in0_V_V_TREADY; -output [BUF_OUT_WIDTH-1:0] out_V_V_TDATA; -output out_V_V_TVALID; -input out_V_V_TREADY; - -//main buffer instantiation -wire [BUF_IN_WIDTH-1:0] window_buffer_in; -wire [BUF_OUT_WIDTH-1:0] window_buffer_out; -wire window_buffer_shift_enable; -$TOP_MODULE_NAME$_wb -#( - .IN_WIDTH(BUF_IN_WIDTH), - .OUT_ELEM_WIDTH(BUF_OUT_ELEM_WIDTH), - .OUT_WIDTH(BUF_OUT_WIDTH), - .BUFFER_ELEM_TOTAL(BUF_ELEM_TOTAL) -) -window_buffer_inst -( - .CLK(ap_clk), - .RST(ap_rst_n), - .data_in(window_buffer_in), - .shift_enable(window_buffer_shift_enable), - .data_out(window_buffer_out) -); - -integer cycle; //main cycle counter (where either read/write/both happen, resets for each image) -wire read_cmd; -wire write_cmd; -reg write_done; //keep track if W of current cycle was already completed, but we still wait on a R in the same cycle - -wire controller_reset; 
-wire controller_advance; - -$TOP_MODULE_NAME$_controller -controller_inst -( - .CLK(ap_clk), - .RST(controller_reset), - .advance(controller_advance), - .cmd_read(read_cmd), - .cmd_write(write_cmd) -); - -wire write_blocked; -assign write_blocked = write_cmd && !out_V_V_TREADY && !write_done; - -wire read_ok; -// with transition to next cycle: -// want to read can read source is ready (waiting on VALID allowed) -assign read_ok = read_cmd && !write_blocked && in0_V_V_TVALID; - -wire write_ok; -// with transition to next cycle: -// output is VALID sink is ready sink has already read (we are waiting on source) -assign write_ok = write_cmd && (out_V_V_TREADY || write_done); - -wire advance; -// includes waiting on W if W-only cycle: wait only on W no R/W to wait for -assign advance = read_ok || (!read_cmd && write_ok) || (!read_cmd && !write_cmd); - -//assign buffer control -//todo: if mmv_out < k: might not shift and/or write for multiple read_cmd cycles -assign window_buffer_shift_enable = advance; - -assign controller_reset = !ap_rst_n || ((cycle == CYCLES_TOTAL-1) && advance); -assign controller_advance = advance; - -//assign I/O ports -assign window_buffer_in = in0_V_V_TDATA; -assign out_V_V_TDATA = window_buffer_out; -assign in0_V_V_TREADY = ap_rst_n && read_ok; //only asserted if data is available and we can store it (allowed) -assign out_V_V_TVALID = ap_rst_n && write_cmd && !write_done; //only asserted if we have data available and it has not been read yet (don't wait for READY from sink) - -//main process for advancing cycle count -always @ (posedge ap_clk) begin - if (ap_rst_n == 1'b0) begin - cycle <= 0; - end else begin - if (advance) begin - write_done <= 1'b0; //reset flag - - //count cycle (completed R or W or both (depending on current cycle)) - if (cycle == CYCLES_TOTAL-1) - cycle <= 0; - else - cycle <= cycle+1; - - end else if (write_ok) // successful W in this cycle, but R still outstanding - write_done <= 1'b1; //write can happen even if read is blocked, but only for the current cycle! 
- end -end - -endmodule //TOP_MODULE_NAME_impl diff --git a/finn-rtllib/swg/swg_template_wrapper.v b/finn-rtllib/swg/swg_template_wrapper.v index be5a93b9e63525f79fd37c3d10b4f9828c5bf98e..1b470817d61af4140141d8478c4ae0d538678ad7 100644 --- a/finn-rtllib/swg/swg_template_wrapper.v +++ b/finn-rtllib/swg/swg_template_wrapper.v @@ -11,10 +11,13 @@ module $TOP_MODULE_NAME$ ( out_V_TREADY ); +// top-level parameters (set via code-generation) parameter BIT_WIDTH = $BIT_WIDTH$; parameter SIMD = $SIMD$; parameter MMV_IN = $MMV_IN$; parameter MMV_OUT = $MMV_OUT$; + +// derived constants parameter BUF_IN_WIDTH = BIT_WIDTH * SIMD * MMV_IN; parameter BUF_OUT_WIDTH = BIT_WIDTH * SIMD * MMV_OUT; @@ -30,7 +33,12 @@ output out_V_TVALID; input out_V_TREADY; $TOP_MODULE_NAME$_impl -#() +#( + .BIT_WIDTH(BIT_WIDTH), + .SIMD(SIMD), + .MMV_IN(MMV_IN), + .MMV_OUT(MMV_OUT) +) impl ( .ap_clk(ap_clk), diff --git a/src/finn/custom_op/fpgadataflow/convolutioninputgenerator_rtl.py b/src/finn/custom_op/fpgadataflow/convolutioninputgenerator_rtl.py index f1e0f53a7a192b6a4eb5aa86583a8636ff79b0b8..98351942b9b4abd1568c9d465710a181e9cab86c 100755 --- a/src/finn/custom_op/fpgadataflow/convolutioninputgenerator_rtl.py +++ b/src/finn/custom_op/fpgadataflow/convolutioninputgenerator_rtl.py @@ -49,128 +49,17 @@ except ModuleNotFoundError: # - Addressable cyclic buffer: to be used when out_width <= in_width # - Parallel registers + line buffers: to be used when out_width > in_width # Supports non-square, 1D, strided, dilated, and depthwise convolutions. -# Note: the actual data layout produced is different for depthwise and non-depthwise ops: +# Note: the actual data layout produced is different for depthwise and non-depthwise: # * non-depthwise SWG: (1, OFMDim_H, OFMDim_W, K_H, K_W, IFMChannels/SIMD, SIMD) # * depthwise SWG: (1, OFMDim_H, OFMDim_W, IFMChannels/SIMD, K_H, K_W, SIMD) - -# helper functions for parallel mode buffer scheduling (to be superseded by improved implementation): - - -def schedule_append(schedule, op): - if len(schedule) > 0 and schedule[-1][1] == op: - count, op_ = schedule[-1] - schedule[-1] = (count + 1, op_) - else: - schedule.append((1, op)) - return schedule - - -def schedule_map_cmds(seq): - mapping = { - "w": ("1'b1", "1'b0"), - "r": ("1'b0", "1'b1"), - "wr": ("1'b1", "1'b1"), - "n": ("1'b0", "1'b0"), - } - if seq: - if len(seq) == 2: - return (seq[0], mapping[seq[1]], 0, mapping["n"]) - if len(seq) == 4: - return (seq[0], mapping[seq[1]], seq[2], mapping[seq[3]]) - else: - return (0, mapping["n"], 0, mapping["n"]) - - -def schedule_map_controller(schedule): - # Experimental implementation to map fixed controller loop structure to R/W schedule by analyzing - # the access pattern given by Im2Col, rather than direct computation. - # TODO: Probably replace this with a directly-computed schedule, similar to the default implementation style. 
- - # leave first sequence (pre-load) as is - start_sequence = schedule[0] - loop_sequence_1_counter = 1 - loop_sequence_1 = schedule[1] - loop_counter = 0 - loop_sequence_2 = None - end_sequence = None - - i = 2 - if i < len(schedule): - loop_sequence_1 += schedule[i] - i += 1 - while i + 1 < len(schedule): - candidate = schedule[i] + schedule[i + 1] - if candidate == loop_sequence_1: - loop_sequence_1_counter += 1 - i += 2 - else: - break - - if i < len(schedule): - loop_sequence_2 = schedule[i] - i += 1 - if i + 1 < len(schedule): - candidate = schedule[i] + schedule[i + 1] - if candidate != loop_sequence_1: - loop_sequence_2 += schedule[i] - i -= 1 - loop_sequence_total_len = ( - int(len(loop_sequence_2) / 2) - ) + loop_sequence_1_counter * (int(len(loop_sequence_1) / 2)) - loop_sequence_total = ( - loop_sequence_2 + loop_sequence_1_counter * loop_sequence_1 - ) - while i + loop_sequence_total_len < len(schedule): - candidate = schedule[i] - for x in range(i + 1, i + loop_sequence_total_len): - candidate += schedule[x] - - if candidate == loop_sequence_total: - loop_counter += 1 - i += loop_sequence_total_len - else: - break - else: - if i < len(schedule): - end_sequence = loop_sequence_2 + schedule[i] - i += 1 - loop_sequence_2 = None - else: - end_sequence = loop_sequence_2 - loop_sequence_2 = None - - if i < len(schedule): - end_sequence = schedule[i] - i += 1 - if i < len(schedule): - end_sequence = end_sequence + schedule[i] - i += 1 - - assert len(start_sequence) == 1 * 2, "ERROR: invalid start sequence" - assert len(loop_sequence_1) == 2 * 2, "ERROR: invalid loop 1 sequence" - if loop_sequence_2: - assert len(loop_sequence_2) <= 2 * 2, "ERROR: invalid loop 2 sequence" - if end_sequence: - assert len(end_sequence) <= 2 * 2, "ERROR: invalid end sequence" - assert i == len(schedule), "ERROR: schedule could not be compacted %d / %d" % ( - i, - len(schedule), - ) - - return ( - start_sequence, - loop_counter, - loop_sequence_1_counter, - loop_sequence_1, - loop_sequence_2, - end_sequence, - ) +# NOTE: "Parallel" implementation style not yet implemented in this version! class ConvolutionInputGenerator_rtl(HLSCustomOp): """Class that does not correspond to one of the finn-hlslib ConvolutionInputGenerator - (sliding window) function variants! ...""" + (sliding window) function variants. 
Generates an RTL ConvolutionInputGenerator + implementation based on (System-)Verilog templates.""" def __init__(self, onnx_node): super().__init__(onnx_node) @@ -216,15 +105,9 @@ class ConvolutionInputGenerator_rtl(HLSCustomOp): ifm_dim_h, ifm_dim_w = self.get_nodeattr("IFMDim") ifm_ch = self.get_nodeattr("IFMChannels") simd = self.get_nodeattr("SIMD") - M = self.get_nodeattr("M") assert ifm_ch % simd == 0, "SIMD must divide IFMChannels" wf = int(ifm_ch / simd) - # folded_ishape = (1, ifm_dim_h, ifm_dim_w, wf, simd) - # round up to support ifm_dim % M != 0 - if ifm_dim_w == 1: - folded_ishape = (1, math.ceil(ifm_dim_h / M), ifm_dim_w, wf, int(simd * M)) - else: - folded_ishape = (1, ifm_dim_h, math.ceil(ifm_dim_w / M), wf, int(simd * M)) + folded_ishape = (1, ifm_dim_h, ifm_dim_w, wf, simd) return folded_ishape def get_normal_output_shape(self): @@ -246,30 +129,13 @@ class ConvolutionInputGenerator_rtl(HLSCustomOp): stride_h, stride_w = self.get_nodeattr("Stride") dilation_h, dilation_w = self.get_nodeattr("Dilation") simd = self.get_nodeattr("SIMD") - M = self.get_nodeattr("M") pad = 0 ofm_dim_h = compute_conv_output_dim(ifm_dim_h, k_h, stride_h, pad, dilation_h) ofm_dim_w = compute_conv_output_dim(ifm_dim_w, k_w, stride_w, pad, dilation_w) assert ifm_ch % simd == 0, "SIMD must divide IFMChannels" if self.get_nodeattr("parallel_window"): wf = int((ifm_ch) // simd) - # folded_oshape = (1, ofm_dim_h, ofm_dim_w, wf, k_h * k_w * simd) - if ofm_dim_w == 1: - folded_oshape = ( - 1, - int(ofm_dim_h / M), - ofm_dim_w, - wf, - k_h * k_w * int(simd * M), - ) - else: - folded_oshape = ( - 1, - ofm_dim_h, - int(ofm_dim_w / M), - wf, - k_h * k_w * int(simd * M), - ) + folded_oshape = (1, ofm_dim_h, ofm_dim_w, wf, k_h * k_w * simd) else: wf = int((k_h * k_w * ifm_ch) // simd) folded_oshape = (1, ofm_dim_h, ofm_dim_w, wf, simd) @@ -303,9 +169,8 @@ class ConvolutionInputGenerator_rtl(HLSCustomOp): ibits = self.get_input_datatype().bitwidth() simd = self.get_nodeattr("SIMD") ifm_ch = self.get_nodeattr("IFMChannels") - M = self.get_nodeattr("M") assert ifm_ch % simd == 0, "SIMD must divide IFMChannels" - in_width = simd * ibits * M + in_width = simd * ibits return in_width def get_outstream_width(self): @@ -327,9 +192,28 @@ class ConvolutionInputGenerator_rtl(HLSCustomOp): num_output_elems = np.prod(folded_oshape[:-1]) return num_output_elems + def get_1d_conv_attrs_normalized(self): + # normalize FM dimensions so that: + # [H, W] = [Y, X] = [1, D] or [D, 1] are always mapped to [1, D]. + # The dummy ('1') dimension is the Y-dimension. 
+ ifm_ch = self.get_nodeattr("IFMChannels") + k = self.get_nodeattr("ConvKernelDim") + ifm_dim = self.get_nodeattr("IFMDim") + ofm_dim = self.get_nodeattr("OFMDim") + stride = self.get_nodeattr("Stride") + dilation = self.get_nodeattr("Dilation") + + if ifm_dim[1] == 1: + ifm_dim = ifm_dim[::-1] + ofm_dim = ofm_dim[::-1] + k = k[::-1] + stride = stride[::-1] + dilation = dilation[::-1] + + return (ifm_ch, ifm_dim, ofm_dim, k, stride, dilation) + def get_exp_cycles(self): simd = self.get_nodeattr("SIMD") - m = self.get_nodeattr("M") ifm_ch = self.get_nodeattr("IFMChannels") k = self.get_nodeattr("ConvKernelDim") ifm_dim = self.get_nodeattr("IFMDim") @@ -342,84 +226,137 @@ class ConvolutionInputGenerator_rtl(HLSCustomOp): k_h, k_w = k stride_h, stride_w = stride dilation_h, dilation_w = dilation - k_h, k_w = k - stride_h, stride_w = stride - dilation_h, dilation_w = dilation - impl_style = self.select_impl_style() - if impl_style == "parallel": - exp_cycles = self.get_number_input_values() + 2 + channel_factor = int(ifm_ch / simd) + + if ifm_dim_h == 1 or ifm_dim_w == 1: + # 1D case + ( + ifm_ch, + [ifm_dim_h, ifm_dim_w], + [ofm_dim_h, ofm_dim_w], + [k_h, k_w], + [stride_h, stride_w], + [dilation_h, dilation_w], + ) = self.get_1d_conv_attrs_normalized() + + if depthwise: + exp_cycles = ( + +ofm_dim_w * k_w * channel_factor + + channel_factor * (k_w - 1) * (stride_w - 1) + - (k_w - 1) + + 2 + ) + else: + exp_cycles = ofm_dim_w * k_w * channel_factor + 2 else: - # based on 2D HLS SWG estimate - # FIXME: increase accuracy for newly supported parameter scenarios - cycles_write_block = (ofm_dim_w * k_w * k_h * (ifm_ch / simd)) / 1 - cycles_read_block = stride_w * ifm_dim_w * (ifm_ch / simd) + # 2D case + buffer_min_size = ( + (k_h - 1) * dilation_h * ifm_dim_w + (k_w - 1) * dilation_w + 1 + ) * channel_factor + cycles_write_block = ofm_dim_w * k_w * k_h * channel_factor + cycles_read_block = stride_w * ifm_dim_w * channel_factor max_cycles = max(cycles_write_block, cycles_read_block) - exp_cycles = ( - ifm_dim_w * k_h * dilation_h * (ifm_ch / simd) + ofm_dim_h * max_cycles - ) + if depthwise: + max_cycles += ofm_dim_w * (stride_w - 1) * (channel_factor - 1) + exp_cycles = buffer_min_size + ofm_dim_h * max_cycles # initial buffering + if depthwise: + exp_cycles += (stride_h - 1) * ifm_dim_w * channel_factor return int(exp_cycles) def bram_estimation(self): simd = self.get_nodeattr("SIMD") ram_style = self.get_nodeattr("ram_style") - impl_style = self.select_impl_style() # call codegen preparation to populate self.buffer_depth if impl_style == "default": - template_path, code_gen_dict = self.prepare_codegen_default() - elif impl_style == "parallel": - template_path, code_gen_dict = self.prepare_codegen_parallel() + self.prepare_codegen_default() + else: + raise Exception("Requested impl. style not implemented") + # NOTE: Actual BRAM usage might be lower in some cases. + # This does not account for the exact Vivado behavior yet. 
buffer_width = simd * self.get_input_datatype().bitwidth() buffer_depth = self.buffer_depth - if ram_style == "block" or ram_style == "auto": - ram_depth = buffer_depth - if ram_depth <= 512: + if buffer_depth <= 512: ram_width = 36 - elif ram_depth <= 1024: + elif buffer_depth <= 1024: ram_width = 18 - elif ram_depth <= 2048: + elif buffer_depth <= 2048: ram_width = 9 - elif ram_depth <= 4096: + elif buffer_depth <= 4096: ram_width = 4 - elif ram_depth <= 8192: + elif buffer_depth <= 8192: ram_width = 2 else: ram_width = 1 ram_cascade_depth = math.ceil(buffer_depth / 16384) ram_cascade_width = math.ceil(buffer_width / ram_width) + cascade_savings = 0 + if buffer_depth > 16384: + remainder_depth = buffer_depth % 16384 + if remainder_depth <= 512: + remainder_width = 36 + elif remainder_depth <= 1024: + remainder_width = 18 + elif remainder_depth <= 2048: + remainder_width = 9 + elif remainder_depth <= 4096: + remainder_width = 4 + elif remainder_depth <= 8192: + remainder_width = 2 + else: + remainder_width = 1 - return int(ram_cascade_depth * ram_cascade_width) + remainder_cascade_width = math.ceil(buffer_width / remainder_width) + cascade_savings = ram_cascade_width - remainder_cascade_width + + return int(ram_cascade_depth * ram_cascade_width - cascade_savings) else: return 0 def lut_estimation(self): simd = self.get_nodeattr("SIMD") ram_style = self.get_nodeattr("ram_style") - impl_style = self.select_impl_style() # call codegen preparation to populate self.buffer_depth if impl_style == "default": - template_path, code_gen_dict = self.prepare_codegen_default() - elif impl_style == "parallel": - template_path, code_gen_dict = self.prepare_codegen_parallel() + self.prepare_codegen_default() + else: + raise Exception("Requested impl. style not implemented") buffer_width = simd * self.get_input_datatype().bitwidth() buffer_depth = self.buffer_depth - if ram_style == "distributed": - ram_luts = int(buffer_width * math.ceil(buffer_depth / 32)) + ram_luts = int(buffer_width * math.ceil(buffer_depth / 38)) else: ram_luts = 0 return 300 + ram_luts def uram_estimation(self): - # TODO: implement URAM estimation - return 0 + simd = self.get_nodeattr("SIMD") + ram_style = self.get_nodeattr("ram_style") + impl_style = self.select_impl_style() + # call codegen preparation to populate self.buffer_depth + if impl_style == "default": + self.prepare_codegen_default() + else: + raise Exception("Requested impl. 
style not implemented") + + buffer_width = simd * self.get_input_datatype().bitwidth() + buffer_depth = self.buffer_depth + + if ram_style == "ultra": + ram_depth = 4096 + ram_width = 72 + ram_cascade_depth = math.ceil(buffer_depth / ram_depth) + ram_cascade_width = math.ceil(buffer_width / ram_width) + return int(ram_cascade_depth * ram_cascade_width) + else: + return 0 def execute_node(self, context, graph): mode = self.get_nodeattr("exec_mode") @@ -427,10 +364,9 @@ class ConvolutionInputGenerator_rtl(HLSCustomOp): exp_ishape = self.get_normal_input_shape() exp_oshape = self.get_normal_output_shape() folded_ishape = self.get_folded_input_shape() - folded_oshape = self.get_folded_output_shape() if mode == "cppsim": - raise Exception("""cppsim not possible for RTL SWG""".format(mode)) + raise Exception("cppsim not possible for RTL SWG") elif mode == "rtlsim": code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen") else: @@ -443,11 +379,9 @@ class ConvolutionInputGenerator_rtl(HLSCustomOp): inp = context[node.input[0]] assert str(inp.dtype) == "float32", "Input datatype is not float32" - # disable this check to allow for IFMdim % M != 0 case (see below) where input comes from MMV-output capable node - # assert ( - # inp.shape == exp_ishape - # ), """Input shape doesn't - # match expected shape (1, ifm_dim, ifm_dim, ifm_ch).""" + assert ( + inp.shape == exp_ishape + ), """Input shape doesn't match expected shape (1, ifm_dim, ifm_dim, ifm_ch).""" if self.get_input_datatype() == DataType["BIPOLAR"]: # store bipolar activations as binary inp = (inp + 1) / 2 @@ -455,20 +389,6 @@ class ConvolutionInputGenerator_rtl(HLSCustomOp): else: export_idt = self.get_input_datatype() - # pad test input stream to work when IFMdim % M != 0 - # during normal operation, the AXI Stream should not care, in the last cycle garbage elements are read but not used - # TODO: only works for 1D case - mmv_stream_padding_px = int( - (np.prod(folded_ishape) - np.prod(inp.shape)) / exp_ishape[-1] - ) - if exp_ishape[2] == 1: - inp = np.pad( - inp, ((0, 0), (0, mmv_stream_padding_px), (0, 0), (0, 0)), "constant" - ) - else: - inp = np.pad( - inp, ((0, 0), (0, 0), (0, mmv_stream_padding_px), (0, 0)), "constant" - ) # reshape input into folded form inp = inp.reshape(folded_ishape) # make copy before saving array @@ -521,7 +441,6 @@ class ConvolutionInputGenerator_rtl(HLSCustomOp): dilation = self.get_nodeattr("Dilation") depthwise = self.get_nodeattr("depthwise") simd = self.get_nodeattr("SIMD") - M = self.get_nodeattr("M") k_h, k_w = k h, w = ifm_dim @@ -532,15 +451,8 @@ class ConvolutionInputGenerator_rtl(HLSCustomOp): pad_w = pad[1] + pad[3] out_dim_h = im2col.compute_conv_output_dim(h, k_h, stride_h, pad_h, dilation_h) out_dim_w = im2col.compute_conv_output_dim(w, k_w, stride_w, pad_w, dilation_w) - - if self.get_nodeattr("parallel_window"): - mmv_in = M * 1 - mmv_out = M * k_h * k_w - else: - mmv_in = 1 - mmv_out = 1 - - # compute index/address increments for each nested loop + mmv_in = 1 + mmv_out = 1 channel_factor = int(ifm_ch / simd) # compute minimal buffer length (assuming it holds 1 complete window) @@ -549,7 +461,7 @@ class ConvolutionInputGenerator_rtl(HLSCustomOp): ) * channel_factor # add additional buffer space in case of stride > 1 - # this minimizes cycle count, as it allows an earlier pre-load of skipped input elements + # this minimizes cycle count as it allows an earlier pre-load of input elements buffer_actual_size = ( buffer_min_size + max( @@ -588,7 +500,7 @@ class 
ConvolutionInputGenerator_rtl(HLSCustomOp): + 1 ) - # re-use same controller structure -> re-assign address increments for the dw case + # re-use same controller structure -> re-assign address increments if depthwise: addr_incr_end_window_elem = dilation_w * channel_factor addr_incr_end_window_row = ( @@ -633,22 +545,7 @@ class ConvolutionInputGenerator_rtl(HLSCustomOp): tail_incr_w = addr_incr_end_window + buffer_min_size - channel_factor tail_incr_h = addr_incr_end_row + buffer_min_size - channel_factor tail_incr_last_window = buffer_min_size - 1 - code_gen_dict["$TAIL_INCR_GENERATION$"] = [ - """ - always @ (counter_loop_kh, counter_loop_w, counter_loop_h) begin - if (counter_loop_kh >= 0) - tail_incr_reg = 1; - else if (counter_loop_w >= 0) - tail_incr_reg = {}; - else if (counter_loop_h >= 0) - tail_incr_reg = {}; - else - tail_incr_reg = {}; - end - """.format( - tail_incr_w, tail_incr_h, tail_incr_last_window - ) - ] + code_gen_dict["$IS_DEPTHWISE$"] = ["1"] else: # depthwise output format is equivalent to non-depthwise if SIMD=C elem_per_window = k_h * k_w * channel_factor @@ -656,20 +553,11 @@ class ConvolutionInputGenerator_rtl(HLSCustomOp): tail_incr_w = addr_incr_end_window + buffer_min_size - 1 tail_incr_h = addr_incr_end_row + buffer_min_size - 1 tail_incr_last_window = buffer_min_size - 1 - code_gen_dict["$TAIL_INCR_GENERATION$"] = [ - """ - always @ (counter_loop_w, counter_loop_h) begin - if (counter_loop_w >= 0) - tail_incr_reg = {}; - else if (counter_loop_h >= 0) - tail_incr_reg = {}; - else - tail_incr_reg = {}; - end - """.format( - tail_incr_w, tail_incr_h, tail_incr_last_window - ) - ] + code_gen_dict["$IS_DEPTHWISE$"] = ["0"] + + code_gen_dict["$TAIL_INCR_W$"] = [str(tail_incr_w)] + code_gen_dict["$TAIL_INCR_H$"] = [str(tail_incr_h)] + code_gen_dict["$TAIL_INCR_LAST$"] = [str(tail_incr_last_window)] # support SIMD = C and k_w = 1 cases # for k = [k_h, k_w] = [1, k_w], no adjustment is needed @@ -732,373 +620,42 @@ class ConvolutionInputGenerator_rtl(HLSCustomOp): return template_path, code_gen_dict - def prepare_codegen_parallel(self): - # Parallel implementation style for MMV_out = K: - # mix of shift-registers (for parallel read) and line buffers (BRAM or LUTRAM) - # compute a static schedule by analyzing access pattern (from im2col function) - template_path = ( - os.environ["FINN_ROOT"] + "/finn-rtllib/swg/swg_template_parallel.sv" - ) - code_gen_dict = {} - + def select_impl_style(self): + simd = self.get_nodeattr("SIMD") + M = self.get_nodeattr("M") ifm_ch = self.get_nodeattr("IFMChannels") - k = self.get_nodeattr("ConvKernelDim") ifm_dim = self.get_nodeattr("IFMDim") stride = self.get_nodeattr("Stride") dilation = self.get_nodeattr("Dilation") - simd = self.get_nodeattr("SIMD") - M = self.get_nodeattr("M") - - k_h, k_w = k - h, w = ifm_dim - n = c = 1 # no need to consider fully-parallel C dimension - in_shape = (n, c, h, w) - pad = [0, 0, 0, 0] + k = self.get_nodeattr("ConvKernelDim") + ifm_dim_h, ifm_dim_w = ifm_dim stride_h, stride_w = stride dilation_h, dilation_w = dilation - in_image = np.empty(in_shape, dtype=int) - in_image_padded = np.pad( - in_image, - ((0, 0), (0, 0), (pad[0], pad[2]), (pad[1], pad[3])), - mode="constant", - constant_values=0, - ) - in_shape_padded = in_image_padded.shape - h_padded = in_shape_padded[2] - w_padded = in_shape_padded[3] - pad_h = pad[0] + pad[2] - pad_w = pad[1] + pad[3] - out_dim_h = im2col.compute_conv_output_dim(h, k_h, stride_h, pad_h, dilation_h) - out_dim_w = im2col.compute_conv_output_dim(w, k_w, stride_w, 
pad_w, dilation_w) - - if self.get_nodeattr("parallel_window"): - mmv_in = M * 1 - mmv_out = M * k_h * k_w - assert ifm_ch == simd, "Constraint violated: SIMD must be equal to C" - else: - mmv_in = 1 - mmv_out = 1 - assert ifm_ch % simd == 0, "Constraint violated: SIMD must divide C" - - # Out width > In width: Parallel implementation style using registers + line buffers - idx_c, idx_h, idx_w = im2col.get_im2col_indices_nchw( - in_shape, k_h, k_w, pad, stride_h, stride_w, dilation_h, dilation_w - ) - - cols = in_image_padded[:, idx_c, idx_h, idx_w] - cols = cols.transpose(1, 2, 0).reshape(k_h * k_w * c, -1) - # result shape is (k_H*k_W*N, out_dim_H*out_dim_W), convert to NCHW - out_image = cols.reshape(n, c, k_h, k_w, out_dim_h, out_dim_w) - # (N=0,C=1,kh=2,kw=3,H=4,W=5) -> (N=0,H=4,W=5,kh=2,kw=3,C=1) - out_image = out_image.transpose(0, 4, 5, 2, 3, 1) - out_image = out_image.reshape(n, out_dim_h, out_dim_w, k_h * k_w * c) - idx_px = idx_h * w + idx_w # sequential pixel indices - k, cycles = idx_px.shape - output_elements = mmv_out - output_cycles = int(cycles / (mmv_out / k)) - - idx_px = idx_px.transpose() - idx_px = idx_px.reshape(output_cycles, output_elements) - idx_px = idx_px.transpose() - # result: first dim is number of parallel output elements, - # second dim is the input element (pixel in case of SIMD=C) index that each output element outputs per cycle - - buffer = [] - buffer_max_size = 0 - schedule = [] - next_in_px = 0 - oldest_px = 0 - - # compute schedule and buffer read pattern (output driven) - idx_px_relative = idx_px.copy() - output_elem, output_cycles = idx_px_relative.shape - - for x in range(output_cycles): - # load missing inputs into buffer - for y in range(output_elem): - while int(idx_px_relative[y, x]) >= next_in_px: - # load M inputs at once (keep "buffer" list 1D for now, handle actual 2D buffer generation later) - for m in range(M): - buffer.append(next_in_px) - next_in_px += 1 - schedule = schedule_append(schedule, "w") - - # discard unused buffer elements - # FIXME: this is very slow for large feature maps (e.g., 4096x4096) - oldest_px = np.min(idx_px_relative[:, x:]) - # check whether M elements can be shifted out, not just the single oldest one - # while all([buffer[i] < oldest_px for i in range(M)]): - if all([buffer[i] < oldest_px for i in range(M)]): - # M buffer elements are shifted out at once - for m in range(M): - buffer.pop(0) - - # adjust relative buffer index of current x (according to last discarded buffer elements) - for y in range(output_elem): - idx_px_relative[y, x] -= oldest_px - - # read from buffer - # + simultaneously load next pixel(s) into buffer if there are any left - if next_in_px > (h_padded * w_padded - 1): - # read only (append above) - schedule = schedule_append(schedule, "r") - else: - # load M inputs at once - for m in range(M): - buffer.append(next_in_px) - next_in_px += 1 - schedule = schedule_append(schedule, "wr") - - # record max needed buffer depth - if len(buffer) > buffer_max_size: - buffer_max_size = len(buffer) - - # insert dummy write operations for data at the input FM tail-end that is never read (e.g. 
in case of stride > 1) - while next_in_px <= (h_padded * w_padded - 1): - next_in_px += 1 - schedule = schedule_append(schedule, "w") - - # add 1 extra cycle after final READ+WRITE cycle for transition b/w feature maps - if schedule[-1][1] == "wr": - schedule_append(schedule, "n") - - # find buffer access patterns - buffer_access_patterns = [] - for x in range(output_cycles): - if idx_px_relative[:, x].tolist() not in buffer_access_patterns: - buffer_access_patterns.append(idx_px_relative[:, x].tolist()) - - ### determine buffer partitioning into REG FIFOs (parallel access) and BRAM FIFOs (line buffers) - # TODO: this part doesn't fully account for M>1 for 2D buffers yet - REG_BRAM_THRESHOLD = 8 - # how many "unused" registers are allowed between buffer positions that will be accessed in parallel - # example: - # 0: only consecutive access patterns will be implemented in regs, rest in (LUTRAM/BRAM) line buffers - # 2: [0, 3, 6] access pattern is still allowed and will be implemented with one 7-position shift reg - - code_gen_dict["$BUF_ELEM_TOTAL$"] = [str(buffer_max_size)] - self.buffer_depth = buffer_max_size # for resource estimation + k_h, k_w = k + kernel_width = (k_w - 1) * dilation_w + 1 # incl. dilation + kernel_height = (k_h - 1) * dilation_h + 1 # incl. dilation + # check for valid configuration assert ( - len(buffer_access_patterns) == 1 - ), "ERROR: Buffer access pattern is not static" - buf_static_access_pattern = buffer_access_patterns[0] - reg_fifos = [] - reg_fifos_depth = [] - bram_fifos = [] - bram_fifos_depth = [] - current = [] - for i in range(len(buf_static_access_pattern)): - access_idx = buf_static_access_pattern[i] - if len(current) == 0: - current.append(access_idx) - else: - # assume non-decreasing index order in access pattern - # TODO: this assumption does not hold for M>1 for the 2D case - distance = access_idx - max(current) - if not (distance - 1 > REG_BRAM_THRESHOLD): - for i in range(distance - 1): - # insert dummy into REG FIFO (not read as part of window) - current.append(-1) - # assign this access to same REG FIFO as previous one - current.append(access_idx) - else: - # assign skipped accesses to new BRAM FIFO - bram_fifos.append([-1] * (distance - 1)) - bram_fifos_depth.append( - math.ceil((distance - 1) / M) - ) # really ceil? 
- # start with new REG FIFO - reg_fifos.append(current) - # reg_fifos_depth.append(math.ceil((max(current)+1)/M)) # allows for MMV in the 1D case - reg_fifos_depth.append(len(current)) - current = [] - current.append(access_idx) - reg_fifos.append(current) - # reg_fifos_depth.append(math.ceil((max(current)+1)/M)) # allows for MMV in the 1D case - reg_fifos_depth.append(len(current)) - - code_gen_dict["$GENERATE_REG_FIFOS$"] = [] - for i in range(len(reg_fifos)): - code_gen_dict["$GENERATE_REG_FIFOS$"].append( - """ - wire [IN_WIDTH-1:0] reg_fifo_{id}_in; - wire [IN_WIDTH-1:0] reg_fifo_{id}_out; - wire [IN_WIDTH*{len}-1:0] reg_fifo_{id}; - {name}_reg_buffer - #( - .WIDTH(IN_WIDTH), - .DEPTH({len}) - ) - reg_buffer_inst_{id} - ( - .CLK(CLK), - .shift_enable(shift_enable), - .shift_in(reg_fifo_{id}_in), - .shift_out(reg_fifo_{id}_out), - .data_out(reg_fifo_{id}) - );""".format( - name=self.get_verilog_top_module_name(), - id=i, - len=reg_fifos_depth[i], - ) - ) - - code_gen_dict["$GENERATE_BRAM_FIFOS$"] = [] - for i in range(len(bram_fifos)): - code_gen_dict["$GENERATE_BRAM_FIFOS$"].append( - """ - wire [IN_WIDTH-1:0] bram_fifo_{id}_in; - wire [IN_WIDTH-1:0] bram_fifo_{id}_out; - {name}_ram_buffer - #( - .WIDTH(IN_WIDTH), - .DEPTH({len}) - ) - ram_buffer_inst_{id} - ( - .CLK(CLK), - .RST(RST), - .shift_enable(shift_enable), - .shift_in(bram_fifo_{id}_in), - .shift_out(bram_fifo_{id}_out) - );""".format( - name=self.get_verilog_top_module_name(), - id=i, - len=bram_fifos_depth[i], - ) - ) - - code_gen_dict["$GENERATE_OUTPUT_MAPPING$"] = [] - out_idx = mmv_out - 1 - for fifo_id, reg_fifo in enumerate(reg_fifos): - for fifo_idx, access_idx in enumerate(reg_fifo): - if access_idx != -1: - code_gen_dict["$GENERATE_OUTPUT_MAPPING$"].append( - "assign data_out[OUT_ELEM_WIDTH*{out_idx}+:OUT_ELEM_WIDTH] = reg_fifo_{fifo_id}[{access_idx}*{mmv}*OUT_ELEM_WIDTH+OUT_ELEM_WIDTH*{mmv_idx}+:OUT_ELEM_WIDTH];".format( - out_idx=out_idx, - fifo_id=fifo_id, - access_idx=reg_fifos_depth[fifo_id] - - 1 - - int((max(reg_fifo) - access_idx) / M), - mmv_idx=(max(reg_fifo) - access_idx) % M, - mmv=M, - ) - ) - # reversal: out_idx=0 -> oldest buffer element -> highest access_idx - out_idx = out_idx - 1 - assert out_idx == -1, "ERROR: Not all output vector elements connected" - - code_gen_dict["$GENERATE_BUFFER_CONNECTION$"] = [] - for i in range(len(reg_fifos)): - if i == 0: - # first FIFO containing newest elements -> input comes from input reg - code_gen_dict["$GENERATE_BUFFER_CONNECTION$"].append( - """assign reg_fifo_{fifo_id}_in = reg_input;""".format( - fifo_id=i, - ) - ) - else: - # other REG FIFOs -> input comes from connected BRAM FIFO (line buffer) - input_fifo_id = i - 1 - code_gen_dict["$GENERATE_BUFFER_CONNECTION$"].append( - """assign reg_fifo_{fifo_id}_in = bram_fifo_{input_fifo_id}_out;""".format( - fifo_id=i, input_fifo_id=input_fifo_id - ) - ) - for i in range(len(bram_fifos)): - input_fifo_id = i - code_gen_dict["$GENERATE_BUFFER_CONNECTION$"].append( - """assign bram_fifo_{fifo_id}_in = reg_fifo_{input_fifo_id}_out;""".format( - fifo_id=i, input_fifo_id=input_fifo_id - ) - ) - - ( - start_sequence, - loop_counter, - loop_sequence_1_counter, - loop_sequence_1, - loop_sequence_2, - end_sequence, - ) = schedule_map_controller(schedule) - - start_sequence = schedule_map_cmds(start_sequence) - loop_sequence_1 = schedule_map_cmds(loop_sequence_1) - loop_sequence_2 = schedule_map_cmds(loop_sequence_2) - end_sequence = schedule_map_cmds(end_sequence) - - cycles_total = 0 - for t in schedule: - cycles_total 
+= t[0] - # add extra cycle if schedule ends on READ - if schedule[-1][1] == "r": - cycles_total += 1 - code_gen_dict["$CYCLES_TOTAL$"] = [str(cycles_total)] - - code_gen_dict["$START_COUNTER$"] = [str(start_sequence[0])] - code_gen_dict["$LOOP_MAIN_COUNTER$"] = [str(loop_sequence_1_counter)] - code_gen_dict["$LOOP_INTER_COUNTER$"] = [str(loop_counter)] - - code_gen_dict["$LOOP_MAIN_1_COUNTER$"] = [str(loop_sequence_1[0])] - code_gen_dict["$LOOP_MAIN_2_COUNTER$"] = [str(loop_sequence_1[2])] - - code_gen_dict["$LOOP_INTER_1_COUNTER$"] = [str(loop_sequence_2[0])] - code_gen_dict["$LOOP_INTER_2_COUNTER$"] = [str(loop_sequence_2[2])] - - code_gen_dict["$LOOP_END_1_COUNTER$"] = [str(end_sequence[0])] - code_gen_dict["$LOOP_END_2_COUNTER$"] = [str(end_sequence[2])] - - code_gen_dict["$READ_CMD_MAP$"] = [ - "{{ {}, {}, {}, {}, {}, {}, {} }}".format( - start_sequence[1][0], - loop_sequence_1[1][0], - loop_sequence_1[3][0], - loop_sequence_2[1][0], - loop_sequence_2[3][0], - end_sequence[1][0], - end_sequence[3][0], - ) - ] - code_gen_dict["$WRITE_CMD_MAP$"] = [ - "{{ {}, {}, {}, {}, {}, {}, {} }}".format( - start_sequence[1][1], - loop_sequence_1[1][1], - loop_sequence_1[3][1], - loop_sequence_2[1][1], - loop_sequence_2[3][1], - end_sequence[1][1], - end_sequence[3][1], - ) - ] + kernel_height <= ifm_dim_h + and kernel_width <= ifm_dim_w + and stride_h <= ifm_dim_h + and stride_w <= ifm_dim_w + ), "Illegal conv configuration: kernel or stride > FM dimension" - code_gen_dict["$SIMD$"] = [str(simd)] - code_gen_dict["$MMV_IN$"] = [str(mmv_in)] - code_gen_dict["$MMV_OUT$"] = [str(mmv_out)] + if k_h == 1 and k_w == 1: + assert simd == ifm_ch, "1x1 Kernel only supported in parallel mode (SIMD=C)" - return template_path, code_gen_dict - - def select_impl_style(self): - ifm_ch = self.get_nodeattr("IFMChannels") - k = self.get_nodeattr("ConvKernelDim") - simd = self.get_nodeattr("SIMD") - M = self.get_nodeattr("M") - - k_h, k_w = k # init folding config if self.get_nodeattr("parallel_window"): - mmv_in = M * 1 + # mmv_in = M * 1 mmv_out = M * k_h * k_w assert ifm_ch == simd, "Constraint violated: SIMD must be equal to C" else: - mmv_in = 1 + # mmv_in = 1 mmv_out = 1 assert ifm_ch % simd == 0, "Constraint violated: SIMD must divide C" - # TODO: check allowed hyperparams - # for 1D case: it does not matter if dummy dim is x or y - # TODO: move/duplicate these checks in corresponding convert_to_hls transformation (?) - # choose implementation style if mmv_out > 1 or (k_h == 1 and k_w == 1): impl_style = "parallel" @@ -1106,6 +663,9 @@ class ConvolutionInputGenerator_rtl(HLSCustomOp): else: impl_style = "default" + assert ( + impl_style == "default" + ), "ERROR: Parallel window mode not yet implemented" return impl_style def generate_hdl(self): @@ -1114,12 +674,12 @@ class ConvolutionInputGenerator_rtl(HLSCustomOp): # prepare code generation by filling out dictionaries if impl_style == "default": template_path, code_gen_dict = self.prepare_codegen_default() - elif impl_style == "parallel": - template_path, code_gen_dict = self.prepare_codegen_parallel() + else: + raise Exception("Requested impl. style not implemented") # add general parameters to dictionary code_gen_dict["$TOP_MODULE_NAME$"] = [self.get_verilog_top_module_name()] - # save top module name so we can refer to it even after this node has been renamed + # save top module name so we can refer to it after this node has been renamed # (e.g. 
         self.set_nodeattr("gen_top_module", self.get_verilog_top_module_name())
         code_gen_dict["$BIT_WIDTH$"] = [str(self.get_input_datatype().bitwidth())]
@@ -1157,7 +717,8 @@ class ConvolutionInputGenerator_rtl(HLSCustomOp):
         ) as f:
             f.write(template_wrapper)

-        # set ipgen_path and ip_path so that HLS-Synth transformation and stich_ip transformation do not complain
+        # set ipgen_path and ip_path so that HLS-Synth transformation
+        # and stitch_ip transformation do not complain
         self.set_nodeattr("ipgen_path", code_gen_dir)
         self.set_nodeattr("ip_path", code_gen_dir)

@@ -1191,7 +752,6 @@ class ConvolutionInputGenerator_rtl(HLSCustomOp):

     def code_generation_ipi(self):
         """Constructs and returns the TCL for node instantiation in Vivado IPI."""
-        vlnv = self.get_nodeattr("ip_vlnv")
         code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen")

         cmd = [
diff --git a/src/finn/transformation/fpgadataflow/set_folding.py b/src/finn/transformation/fpgadataflow/set_folding.py
index 23943084ab99d6ab880a69975e0b4a49756905a7..5c94272bad52fd6265a9bb0054fae87a3d77b93b 100644
--- a/src/finn/transformation/fpgadataflow/set_folding.py
+++ b/src/finn/transformation/fpgadataflow/set_folding.py
@@ -109,6 +109,7 @@ class SetFolding(Transformation):
             "FMPadding_Batch",
             "ConvolutionInputGenerator",
             "ConvolutionInputGenerator1D",
+            "ConvolutionInputGenerator_rtl",
         ]
         # these ops are preceded by depthwise SWG and have special behavior,
         # as explained in the SetFolding docstring
@@ -174,6 +175,7 @@ class SetFolding(Transformation):
                 if op_type in [
                     "ConvolutionInputGenerator",
                     "ConvolutionInputGenerator1D",
+                    "ConvolutionInputGenerator_rtl",
                 ]:
                     depthwise = node_inst.get_nodeattr("depthwise")
                     if depthwise == 0:
diff --git a/tests/fpgadataflow/test_convert_to_hls_conv_layer.py b/tests/fpgadataflow/test_convert_to_hls_conv_layer.py
index 7dcae82afe29056cccf8d980e2206d6faab07bfb..56438ac6b6c5ac835ca35d9e66073042e467224f 100644
--- a/tests/fpgadataflow/test_convert_to_hls_conv_layer.py
+++ b/tests/fpgadataflow/test_convert_to_hls_conv_layer.py
@@ -73,6 +73,9 @@ def test_convert_to_hls_conv_layer(conv_config, depthwise, use_rtl_swg, exec_mod
     if use_rtl_swg and exec_mode == "cppsim":
         pytest.skip("cppsim not supported for RTL SWG")

+    if use_rtl_swg and kernel_size == 1:
+        pytest.skip("1x1 kernel not supported by current RTL SWG")
+
     if depthwise is True:
         group = out_chn = in_chn
         conv_param_shape = [out_chn, 1, kernel_size, kernel_size]
diff --git a/tests/fpgadataflow/test_fpgadataflow_convinputgenerator_rtl.py b/tests/fpgadataflow/test_fpgadataflow_convinputgenerator_rtl.py
index d3ea9d117c88f57c81bb2c26bc059261d1ed49e5..eeeb09329448f546f4a668fde3d32ffaa36f5aaf 100755
--- a/tests/fpgadataflow/test_fpgadataflow_convinputgenerator_rtl.py
+++ b/tests/fpgadataflow/test_fpgadataflow_convinputgenerator_rtl.py
@@ -28,17 +28,14 @@

 import pytest

-import numpy as np
 from onnx import TensorProto, helper
 from qonnx.core.datatype import DataType
 from qonnx.core.modelwrapper import ModelWrapper
 from qonnx.custom_op.general.im2col import compute_conv_output_dim
-from qonnx.custom_op.registry import getCustomOp
 from qonnx.transformation.general import GiveUniqueNodeNames
 from qonnx.util.basic import gen_finn_dt_tensor

 import finn.core.onnx_exec as oxe
-from finn.analysis.fpgadataflow.exp_cycles_per_layer import exp_cycles_per_layer
 from finn.transformation.fpgadataflow.prepare_ip import PrepareIP
 from finn.transformation.fpgadataflow.prepare_rtlsim import PrepareRTLSim
 from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode
@@ -133,11 +130,6 @@ def make_single_slidingwindow_modelwrapper(
     model.set_tensor_datatype("inp", idt)
     model.set_tensor_datatype("outp", odt)

-    # DEBUG
-    # swg_node = model.get_nodes_by_op_type("ConvolutionInputGenerator_rtl")[0]
-    # swg_inst = getCustomOp(swg_node)
-    # swg_inst.set_nodeattr("rtlsim_trace", "/home/felixj/WD/finn/finn-rtllib/swg/swg_test_trace.vcd")
-
     return model


@@ -147,38 +139,24 @@ def prepare_inputs(input_tensor):

 # input datatype
 @pytest.mark.parametrize("idt", [DataType["UINT4"]])
-
-# @pytest.mark.parametrize(
-#     "conv_config",
-#     [
-#         [[12,12], [3, 3], [1, 1], [1, 1]],
-#         [[13,13], [3, 3], [1, 1], [1, 1]],
-#         [[12,12], [3, 3], [2, 2], [1, 1]],
-#         [[13,13], [3, 3], [2, 2], [1, 1]],
-#     ],
-# )
 # kernel size
-@pytest.mark.parametrize("k", [[1, 1], [2, 2], [3, 3], [1, 2], [1, 3]])
+@pytest.mark.parametrize("k", [[2, 2], [3, 3], [1, 3]])
 # input dimension
-@pytest.mark.parametrize(
-    "ifm_dim", [[8, 8], [13, 13], [1, 11], [1, 12], [1, 13], [1, 14]]
-)
+@pytest.mark.parametrize("ifm_dim", [[24, 24], [13, 13], [1, 14]])
 # input channels
 @pytest.mark.parametrize("ifm_ch", [6])
 # Stride
-@pytest.mark.parametrize("stride", [[1, 1], [2, 2], [1, 2]])
+@pytest.mark.parametrize("stride", [[1, 1], [2, 2]])
 # Dilation
-@pytest.mark.parametrize("dilation", [[1, 1], [2, 2], [1, 3]])
+@pytest.mark.parametrize("dilation", [[1, 1], [2, 2]])
 # depthwise
 @pytest.mark.parametrize("dw", [0, 1])
-
 # input channel parallelism ("SIMD")
 @pytest.mark.parametrize("simd", [1, 2, 3, 6])
 # parallel_window enable (MMV_out = M*K)
-@pytest.mark.parametrize("parallel_window", [0, 1])
+@pytest.mark.parametrize("parallel_window", [0])
 # in/out MMV ("M")
 @pytest.mark.parametrize("m", [1])
-
 # Flip dimensions
 @pytest.mark.parametrize("flip", [False])
 @pytest.mark.slow
@@ -186,11 +164,6 @@ def prepare_inputs(input_tensor):
 def test_fpgadataflow_slidingwindow_rtl(
     idt, k, ifm_dim, ifm_ch, stride, dilation, dw, simd, m, parallel_window, flip
 ):
-    # ifm_dim = conv_config[0]
-    # k = conv_config[1]
-    # stride = conv_config[2]
-    # dilation= conv_config[3]
-
     if flip:
         if (
             ifm_dim[0] == ifm_dim[1]
@@ -228,7 +201,8 @@ def test_fpgadataflow_slidingwindow_rtl(
         k_w == 1 and (stride_w != 1 or dilation_w != 1)
     ):
         pytest.skip(
-            "Illegal convolution configuration: stride or dilation defined for unitary kernel dim"
+            """Illegal convolution configuration:
+            stride or dilation defined for unitary kernel dim"""
         )
     if k_h == 1 and k_w == 1 and simd != ifm_ch:
         pytest.skip("1x1 Kernel only supported in parallel mode (SIMD=C)")
@@ -274,17 +248,6 @@ def test_fpgadataflow_slidingwindow_rtl(
     )
     y_expected = oxe.execute_onnx(golden, input_dict)["outp"]

-    # DEBUG
-    print("-------expected:")
-    print(y_expected)
-    print("--------produced:")
-    print(y_produced)
-
-    node = model.get_nodes_by_op_type("ConvolutionInputGenerator_rtl")[0]
-    inst = getCustomOp(node)
-    cycles_rtlsim = inst.get_nodeattr("cycles_rtlsim")
-    print("RTLSIM cycles: %d" % cycles_rtlsim)
-
     if dw == 0:
         assert (y_produced == y_expected).all()
     else:
@@ -294,9 +257,3 @@ def test_fpgadataflow_slidingwindow_rtl(
         y_expected = y_expected.transpose(0, 1, 2, 4, 3, 5)
         y_expected = y_expected.reshape(1, ofm_dim_h, ofm_dim_w, ifm_ch * k_h * k_w)
         assert (y_produced == y_expected).all()
-
-
-# exp_cycles_dict = model.analysis(exp_cycles_per_layer)
-# exp_cycles = exp_cycles_dict[node.name]
-# assert np.isclose(exp_cycles, cycles_rtlsim, atol=10)
-# assert exp_cycles != 0