diff --git a/finn-rtllib/swg/swg_template_axilite.v b/finn-rtllib/swg/swg_template_axilite.v new file mode 100644 index 0000000000000000000000000000000000000000..9479c7f80d7d82b27141dbe5abcce442049237bd --- /dev/null +++ b/finn-rtllib/swg/swg_template_axilite.v @@ -0,0 +1,567 @@ + +`timescale 1 ns / 1 ps + +module $TOP_MODULE_NAME$_axilite # +( + // Users to add parameters here + + // User parameters ends + // Do not modify the parameters beyond this line + + // Width of S_AXI data bus + parameter integer C_S_AXI_DATA_WIDTH = 32, + // Width of S_AXI address bus + parameter integer C_S_AXI_ADDR_WIDTH = 6 +) +( + // User ports: cfg_reg0..cfg_reg15 expose the slave registers slv_reg0..slv_reg15 + output wire [C_S_AXI_DATA_WIDTH-1:0] cfg_reg0, + output wire [C_S_AXI_DATA_WIDTH-1:0] cfg_reg1, + output wire [C_S_AXI_DATA_WIDTH-1:0] cfg_reg2, + output wire [C_S_AXI_DATA_WIDTH-1:0] cfg_reg3, + output wire [C_S_AXI_DATA_WIDTH-1:0] cfg_reg4, + output wire [C_S_AXI_DATA_WIDTH-1:0] cfg_reg5, + output wire [C_S_AXI_DATA_WIDTH-1:0] cfg_reg6, + output wire [C_S_AXI_DATA_WIDTH-1:0] cfg_reg7, + output wire [C_S_AXI_DATA_WIDTH-1:0] cfg_reg8, + output wire [C_S_AXI_DATA_WIDTH-1:0] cfg_reg9, + output wire [C_S_AXI_DATA_WIDTH-1:0] cfg_reg10, + output wire [C_S_AXI_DATA_WIDTH-1:0] cfg_reg11, + output wire [C_S_AXI_DATA_WIDTH-1:0] cfg_reg12, + output wire [C_S_AXI_DATA_WIDTH-1:0] cfg_reg13, + output wire [C_S_AXI_DATA_WIDTH-1:0] cfg_reg14, + output wire [C_S_AXI_DATA_WIDTH-1:0] cfg_reg15, + + // User ports ends + // Do not modify the ports beyond this line + + // Global Clock Signal + input wire S_AXI_ACLK, + // Global Reset Signal. This Signal is Active LOW + input wire S_AXI_ARESETN, + // Write address (issued by master, accepted by Slave) + input wire [C_S_AXI_ADDR_WIDTH-1 : 0] S_AXI_AWADDR, + // Write channel Protection type. This signal indicates the + // privilege and security level of the transaction, and whether + // the transaction is a data access or an instruction access. + input wire [2 : 0] S_AXI_AWPROT, + // Write address valid. 
This signal indicates that the master is signaling + // valid write address and control information. + input wire S_AXI_AWVALID, + // Write address ready. This signal indicates that the slave is ready + // to accept an address and associated control signals. + output wire S_AXI_AWREADY, + // Write data (issued by master, accepted by Slave) + input wire [C_S_AXI_DATA_WIDTH-1 : 0] S_AXI_WDATA, + // Write strobes. This signal indicates which byte lanes hold + // valid data. There is one write strobe bit for each eight + // bits of the write data bus. + input wire [(C_S_AXI_DATA_WIDTH/8)-1 : 0] S_AXI_WSTRB, + // Write valid. This signal indicates that valid write + // data and strobes are available. + input wire S_AXI_WVALID, + // Write ready. This signal indicates that the slave + // can accept the write data. + output wire S_AXI_WREADY, + // Write response. This signal indicates the status + // of the write transaction. + output wire [1 : 0] S_AXI_BRESP, + // Write response valid. This signal indicates that the channel + // is signaling a valid write response. + output wire S_AXI_BVALID, + // Response ready. This signal indicates that the master + // can accept a write response. + input wire S_AXI_BREADY, + // Read address (issued by master, accepted by Slave) + input wire [C_S_AXI_ADDR_WIDTH-1 : 0] S_AXI_ARADDR, + // Protection type. This signal indicates the privilege + // and security level of the transaction, and whether the + // transaction is a data access or an instruction access. + input wire [2 : 0] S_AXI_ARPROT, + // Read address valid. This signal indicates that the channel + // is signaling valid read address and control information. + input wire S_AXI_ARVALID, + // Read address ready. This signal indicates that the slave is + // ready to accept an address and associated control signals. + output wire S_AXI_ARREADY, + // Read data (issued by slave) + output wire [C_S_AXI_DATA_WIDTH-1 : 0] S_AXI_RDATA, + // Read response. 
This signal indicates the status of the + // read transfer. + output wire [1 : 0] S_AXI_RRESP, + // Read valid. This signal indicates that the channel is + // signaling the required read data. + output wire S_AXI_RVALID, + // Read ready. This signal indicates that the master can + // accept the read data and response information. + input wire S_AXI_RREADY +); + +// AXI4LITE signals +reg [C_S_AXI_ADDR_WIDTH-1 : 0] axi_awaddr; +reg axi_awready; +reg axi_wready; +reg [1 : 0] axi_bresp; +reg axi_bvalid; +reg [C_S_AXI_ADDR_WIDTH-1 : 0] axi_araddr; +reg axi_arready; +reg [C_S_AXI_DATA_WIDTH-1 : 0] axi_rdata; +reg [1 : 0] axi_rresp; +reg axi_rvalid; + +// Example-specific design signals +// local parameter for addressing 32 bit / 64 bit C_S_AXI_DATA_WIDTH +// ADDR_LSB is used for addressing 32/64 bit registers/memories +// ADDR_LSB = 2 for 32 bits (n downto 2) +// ADDR_LSB = 3 for 64 bits (n downto 3) +localparam integer ADDR_LSB = (C_S_AXI_DATA_WIDTH/32) + 1; +localparam integer OPT_MEM_ADDR_BITS = 3; // register index is addr[ADDR_LSB+OPT_MEM_ADDR_BITS:ADDR_LSB], i.e. 4 bits +//---------------------------------------------- +//-- Signals for user logic register space example +//------------------------------------------------ +//-- Number of Slave Registers 16 +reg [C_S_AXI_DATA_WIDTH-1:0] slv_reg0; +reg [C_S_AXI_DATA_WIDTH-1:0] slv_reg1; +reg [C_S_AXI_DATA_WIDTH-1:0] slv_reg2; +reg [C_S_AXI_DATA_WIDTH-1:0] slv_reg3; +reg [C_S_AXI_DATA_WIDTH-1:0] slv_reg4; +reg [C_S_AXI_DATA_WIDTH-1:0] slv_reg5; +reg [C_S_AXI_DATA_WIDTH-1:0] slv_reg6; +reg [C_S_AXI_DATA_WIDTH-1:0] slv_reg7; +reg [C_S_AXI_DATA_WIDTH-1:0] slv_reg8; +reg [C_S_AXI_DATA_WIDTH-1:0] slv_reg9; +reg [C_S_AXI_DATA_WIDTH-1:0] slv_reg10; +reg [C_S_AXI_DATA_WIDTH-1:0] slv_reg11; +reg [C_S_AXI_DATA_WIDTH-1:0] slv_reg12; +reg [C_S_AXI_DATA_WIDTH-1:0] slv_reg13; +reg [C_S_AXI_DATA_WIDTH-1:0] slv_reg14; +reg [C_S_AXI_DATA_WIDTH-1:0] slv_reg15; +wire slv_reg_rden; +wire slv_reg_wren; +reg [C_S_AXI_DATA_WIDTH-1:0] reg_data_out; +integer byte_index; +reg aw_en; // high while a new write address may be accepted + +// I/O Connections assignments + 
+assign S_AXI_AWREADY = axi_awready; +assign S_AXI_WREADY = axi_wready; +assign S_AXI_BRESP = axi_bresp; +assign S_AXI_BVALID = axi_bvalid; +assign S_AXI_ARREADY = axi_arready; +assign S_AXI_RDATA = axi_rdata; +assign S_AXI_RRESP = axi_rresp; +assign S_AXI_RVALID = axi_rvalid; +// Implement axi_awready generation +// axi_awready is asserted for one S_AXI_ACLK clock cycle when both +// S_AXI_AWVALID and S_AXI_WVALID are asserted. axi_awready is +// de-asserted when reset is low. + +always @( posedge S_AXI_ACLK ) +begin + if ( S_AXI_ARESETN == 1'b0 ) + begin + axi_awready <= 1'b0; + aw_en <= 1'b1; + end + else + begin + if (~axi_awready && S_AXI_AWVALID && S_AXI_WVALID && aw_en) + begin + // slave is ready to accept write address when + // there is a valid write address and write data + // on the write address and data bus. This design + // expects no outstanding transactions: aw_en is cleared here to block further addresses. + axi_awready <= 1'b1; + aw_en <= 1'b0; + end + else if (S_AXI_BREADY && axi_bvalid) // write response taken by the master: allow the next address + begin + aw_en <= 1'b1; + axi_awready <= 1'b0; + end + else + begin + axi_awready <= 1'b0; + end + end +end + +// Implement axi_awaddr latching +// This process is used to latch the address when both +// S_AXI_AWVALID and S_AXI_WVALID are valid. + +always @( posedge S_AXI_ACLK ) +begin + if ( S_AXI_ARESETN == 1'b0 ) + begin + axi_awaddr <= 0; + end + else + begin + if (~axi_awready && S_AXI_AWVALID && S_AXI_WVALID && aw_en) + begin + // Write Address latching + axi_awaddr <= S_AXI_AWADDR; + end + end +end + +// Implement axi_wready generation +// axi_wready is asserted for one S_AXI_ACLK clock cycle when both +// S_AXI_AWVALID and S_AXI_WVALID are asserted. axi_wready is +// de-asserted when reset is low. 
+ +always @( posedge S_AXI_ACLK ) +begin + if ( S_AXI_ARESETN == 1'b0 ) + begin + axi_wready <= 1'b0; + end + else + begin + if (~axi_wready && S_AXI_WVALID && S_AXI_AWVALID && aw_en ) + begin + // slave is ready to accept write data when + // there is a valid write address and write data + // on the write address and data bus. This design + // expects no outstanding transactions. + axi_wready <= 1'b1; + end + else + begin + axi_wready <= 1'b0; + end + end +end + +// Implement memory mapped register select and write logic generation +// The write data is accepted and written to memory mapped registers when +// axi_awready, S_AXI_AWVALID, axi_wready and S_AXI_WVALID are asserted. Write strobes are used to +// select byte enables of slave registers while writing. +// These registers are cleared when reset (active low) is applied. +// Slave register write enable is asserted when valid address and data are available +// and the slave is ready to accept the write address and write data. +assign slv_reg_wren = axi_wready && S_AXI_WVALID && axi_awready && S_AXI_AWVALID; + +always @( posedge S_AXI_ACLK ) +begin + if ( S_AXI_ARESETN == 1'b0 ) + begin + slv_reg0 <= 0; + slv_reg1 <= 0; + slv_reg2 <= 0; + slv_reg3 <= 0; + slv_reg4 <= 0; + slv_reg5 <= 0; + slv_reg6 <= 0; + slv_reg7 <= 0; + slv_reg8 <= 0; + slv_reg9 <= 0; + slv_reg10 <= 0; + slv_reg11 <= 0; + slv_reg12 <= 0; + slv_reg13 <= 0; + slv_reg14 <= 0; + slv_reg15 <= 0; + end + else begin + if (slv_reg_wren) + begin + case ( axi_awaddr[ADDR_LSB+OPT_MEM_ADDR_BITS:ADDR_LSB] ) + 4'h0: + for ( byte_index = 0; byte_index <= (C_S_AXI_DATA_WIDTH/8)-1; byte_index = byte_index+1 ) + if ( S_AXI_WSTRB[byte_index] == 1 ) begin + // Respective byte enables are asserted as per write strobes + // Slave register 0 + slv_reg0[(byte_index*8) +: 8] <= S_AXI_WDATA[(byte_index*8) +: 8]; + end + 4'h1: + for ( byte_index = 0; byte_index <= (C_S_AXI_DATA_WIDTH/8)-1; byte_index = byte_index+1 ) + if ( S_AXI_WSTRB[byte_index] == 1 ) begin + // 
Respective byte enables are asserted as per write strobes + // Slave register 1 + slv_reg1[(byte_index*8) +: 8] <= S_AXI_WDATA[(byte_index*8) +: 8]; + end + 4'h2: + for ( byte_index = 0; byte_index <= (C_S_AXI_DATA_WIDTH/8)-1; byte_index = byte_index+1 ) + if ( S_AXI_WSTRB[byte_index] == 1 ) begin + // Respective byte enables are asserted as per write strobes + // Slave register 2 + slv_reg2[(byte_index*8) +: 8] <= S_AXI_WDATA[(byte_index*8) +: 8]; + end + 4'h3: + for ( byte_index = 0; byte_index <= (C_S_AXI_DATA_WIDTH/8)-1; byte_index = byte_index+1 ) + if ( S_AXI_WSTRB[byte_index] == 1 ) begin + // Respective byte enables are asserted as per write strobes + // Slave register 3 + slv_reg3[(byte_index*8) +: 8] <= S_AXI_WDATA[(byte_index*8) +: 8]; + end + 4'h4: + for ( byte_index = 0; byte_index <= (C_S_AXI_DATA_WIDTH/8)-1; byte_index = byte_index+1 ) + if ( S_AXI_WSTRB[byte_index] == 1 ) begin + // Respective byte enables are asserted as per write strobes + // Slave register 4 + slv_reg4[(byte_index*8) +: 8] <= S_AXI_WDATA[(byte_index*8) +: 8]; + end + 4'h5: + for ( byte_index = 0; byte_index <= (C_S_AXI_DATA_WIDTH/8)-1; byte_index = byte_index+1 ) + if ( S_AXI_WSTRB[byte_index] == 1 ) begin + // Respective byte enables are asserted as per write strobes + // Slave register 5 + slv_reg5[(byte_index*8) +: 8] <= S_AXI_WDATA[(byte_index*8) +: 8]; + end + 4'h6: + for ( byte_index = 0; byte_index <= (C_S_AXI_DATA_WIDTH/8)-1; byte_index = byte_index+1 ) + if ( S_AXI_WSTRB[byte_index] == 1 ) begin + // Respective byte enables are asserted as per write strobes + // Slave register 6 + slv_reg6[(byte_index*8) +: 8] <= S_AXI_WDATA[(byte_index*8) +: 8]; + end + 4'h7: + for ( byte_index = 0; byte_index <= (C_S_AXI_DATA_WIDTH/8)-1; byte_index = byte_index+1 ) + if ( S_AXI_WSTRB[byte_index] == 1 ) begin + // Respective byte enables are asserted as per write strobes + // Slave register 7 + slv_reg7[(byte_index*8) +: 8] <= S_AXI_WDATA[(byte_index*8) +: 8]; + end + 4'h8: + for ( 
byte_index = 0; byte_index <= (C_S_AXI_DATA_WIDTH/8)-1; byte_index = byte_index+1 ) + if ( S_AXI_WSTRB[byte_index] == 1 ) begin + // Respective byte enables are asserted as per write strobes + // Slave register 8 + slv_reg8[(byte_index*8) +: 8] <= S_AXI_WDATA[(byte_index*8) +: 8]; + end + 4'h9: + for ( byte_index = 0; byte_index <= (C_S_AXI_DATA_WIDTH/8)-1; byte_index = byte_index+1 ) + if ( S_AXI_WSTRB[byte_index] == 1 ) begin + // Respective byte enables are asserted as per write strobes + // Slave register 9 + slv_reg9[(byte_index*8) +: 8] <= S_AXI_WDATA[(byte_index*8) +: 8]; + end + 4'hA: + for ( byte_index = 0; byte_index <= (C_S_AXI_DATA_WIDTH/8)-1; byte_index = byte_index+1 ) + if ( S_AXI_WSTRB[byte_index] == 1 ) begin + // Respective byte enables are asserted as per write strobes + // Slave register 10 + slv_reg10[(byte_index*8) +: 8] <= S_AXI_WDATA[(byte_index*8) +: 8]; + end + 4'hB: + for ( byte_index = 0; byte_index <= (C_S_AXI_DATA_WIDTH/8)-1; byte_index = byte_index+1 ) + if ( S_AXI_WSTRB[byte_index] == 1 ) begin + // Respective byte enables are asserted as per write strobes + // Slave register 11 + slv_reg11[(byte_index*8) +: 8] <= S_AXI_WDATA[(byte_index*8) +: 8]; + end + 4'hC: + for ( byte_index = 0; byte_index <= (C_S_AXI_DATA_WIDTH/8)-1; byte_index = byte_index+1 ) + if ( S_AXI_WSTRB[byte_index] == 1 ) begin + // Respective byte enables are asserted as per write strobes + // Slave register 12 + slv_reg12[(byte_index*8) +: 8] <= S_AXI_WDATA[(byte_index*8) +: 8]; + end + 4'hD: + for ( byte_index = 0; byte_index <= (C_S_AXI_DATA_WIDTH/8)-1; byte_index = byte_index+1 ) + if ( S_AXI_WSTRB[byte_index] == 1 ) begin + // Respective byte enables are asserted as per write strobes + // Slave register 13 + slv_reg13[(byte_index*8) +: 8] <= S_AXI_WDATA[(byte_index*8) +: 8]; + end + 4'hE: + for ( byte_index = 0; byte_index <= (C_S_AXI_DATA_WIDTH/8)-1; byte_index = byte_index+1 ) + if ( S_AXI_WSTRB[byte_index] == 1 ) begin + // Respective byte enables are 
asserted as per write strobes + // Slave register 14 + slv_reg14[(byte_index*8) +: 8] <= S_AXI_WDATA[(byte_index*8) +: 8]; + end + 4'hF: + for ( byte_index = 0; byte_index <= (C_S_AXI_DATA_WIDTH/8)-1; byte_index = byte_index+1 ) + if ( S_AXI_WSTRB[byte_index] == 1 ) begin + // Respective byte enables are asserted as per write strobes + // Slave register 15 + slv_reg15[(byte_index*8) +: 8] <= S_AXI_WDATA[(byte_index*8) +: 8]; + end + default : begin // all 16 known values of the 4-bit index are listed above; registers simply hold + slv_reg0 <= slv_reg0; + slv_reg1 <= slv_reg1; + slv_reg2 <= slv_reg2; + slv_reg3 <= slv_reg3; + slv_reg4 <= slv_reg4; + slv_reg5 <= slv_reg5; + slv_reg6 <= slv_reg6; + slv_reg7 <= slv_reg7; + slv_reg8 <= slv_reg8; + slv_reg9 <= slv_reg9; + slv_reg10 <= slv_reg10; + slv_reg11 <= slv_reg11; + slv_reg12 <= slv_reg12; + slv_reg13 <= slv_reg13; + slv_reg14 <= slv_reg14; + slv_reg15 <= slv_reg15; + end + endcase + end + end +end + +// Implement write response logic generation +// The write response and response valid signals are asserted by the slave +// when axi_awready, S_AXI_AWVALID, axi_wready and S_AXI_WVALID are asserted. +// This marks the acceptance of address and indicates the status of +// write transaction. + +always @( posedge S_AXI_ACLK ) +begin + if ( S_AXI_ARESETN == 1'b0 ) + begin + axi_bvalid <= 0; + axi_bresp <= 2'b0; + end + else + begin + if (axi_awready && S_AXI_AWVALID && ~axi_bvalid && axi_wready && S_AXI_WVALID) + begin + // indicates a valid write response is available + axi_bvalid <= 1'b1; + axi_bresp <= 2'b0; // 'OKAY' response + end // work error responses in future + else + begin + if (S_AXI_BREADY && axi_bvalid) + //check if bready is asserted while bvalid is high + //(there is a possibility that bready is always asserted high) + begin + axi_bvalid <= 1'b0; + end + end + end +end + +// Implement axi_arready generation +// axi_arready is asserted for one S_AXI_ACLK clock cycle when +// S_AXI_ARVALID is asserted. axi_arready is +// de-asserted when reset (active low) is asserted. 
+// The read address is also latched when S_AXI_ARVALID is +// asserted. axi_araddr is reset to zero on reset assertion. + +always @( posedge S_AXI_ACLK ) +begin + if ( S_AXI_ARESETN == 1'b0 ) + begin + axi_arready <= 1'b0; + axi_araddr <= 0; // unsized zero: the register is C_S_AXI_ADDR_WIDTH bits wide, not 32 + end + else + begin + if (~axi_arready && S_AXI_ARVALID) + begin + // indicates that the slave has accepted the valid read address + axi_arready <= 1'b1; + // Read address latching + axi_araddr <= S_AXI_ARADDR; + end + else + begin + axi_arready <= 1'b0; + end + end +end + +// Implement axi_rvalid generation +// axi_rvalid is asserted on the clock edge at which both +// S_AXI_ARVALID and axi_arready are asserted, and it stays asserted +// until the master accepts the data with S_AXI_RREADY. The +// assertion of axi_rvalid marks the validity of read data on the +// bus and axi_rresp indicates the status of read transaction. axi_rvalid +// is deasserted on reset (active low). axi_rresp is +// cleared to zero on reset (active low). +always @( posedge S_AXI_ACLK ) +begin + if ( S_AXI_ARESETN == 1'b0 ) + begin + axi_rvalid <= 0; + axi_rresp <= 0; + end + else + begin + if (axi_arready && S_AXI_ARVALID && ~axi_rvalid) + begin + // Valid read data is available at the read data bus + axi_rvalid <= 1'b1; + axi_rresp <= 2'b0; // 'OKAY' response + end + else if (axi_rvalid && S_AXI_RREADY) + begin + // Read data is accepted by the master + axi_rvalid <= 1'b0; + end + end +end + +// Implement memory mapped register select and read logic generation +// Slave register read enable is asserted when valid address is available +// and the slave is ready to accept the read address. 
+assign slv_reg_rden = axi_arready & S_AXI_ARVALID & ~axi_rvalid; +always @(*) +begin + // Address decoding for reading registers (combinational logic, hence blocking assignments) + case ( axi_araddr[ADDR_LSB+OPT_MEM_ADDR_BITS:ADDR_LSB] ) + 4'h0 : reg_data_out = slv_reg0; + 4'h1 : reg_data_out = slv_reg1; + 4'h2 : reg_data_out = slv_reg2; + 4'h3 : reg_data_out = slv_reg3; + 4'h4 : reg_data_out = slv_reg4; + 4'h5 : reg_data_out = slv_reg5; + 4'h6 : reg_data_out = slv_reg6; + 4'h7 : reg_data_out = slv_reg7; + 4'h8 : reg_data_out = slv_reg8; + 4'h9 : reg_data_out = slv_reg9; + 4'hA : reg_data_out = slv_reg10; + 4'hB : reg_data_out = slv_reg11; + 4'hC : reg_data_out = slv_reg12; + 4'hD : reg_data_out = slv_reg13; + 4'hE : reg_data_out = slv_reg14; + 4'hF : reg_data_out = slv_reg15; + default : reg_data_out = 0; + endcase +end + +// Output register or memory read data +always @( posedge S_AXI_ACLK ) +begin + if ( S_AXI_ARESETN == 1'b0 ) + begin + axi_rdata <= 0; + end + else + begin + // When there is a valid read address (S_AXI_ARVALID) with + // acceptance of read address by the slave (axi_arready), + // output the read data + if (slv_reg_rden) + begin + axi_rdata <= reg_data_out; // register read data + end + end +end + +// Add user logic here +assign cfg_reg0 = slv_reg0; +assign cfg_reg1 = slv_reg1; +assign cfg_reg2 = slv_reg2; +assign cfg_reg3 = slv_reg3; +assign cfg_reg4 = slv_reg4; +assign cfg_reg5 = slv_reg5; +assign cfg_reg6 = slv_reg6; +assign cfg_reg7 = slv_reg7; +assign cfg_reg8 = slv_reg8; +assign cfg_reg9 = slv_reg9; +assign cfg_reg10 = slv_reg10; +assign cfg_reg11 = slv_reg11; +assign cfg_reg12 = slv_reg12; +assign cfg_reg13 = slv_reg13; +assign cfg_reg14 = slv_reg14; +assign cfg_reg15 = slv_reg15; +// User logic ends + +endmodule diff --git a/finn-rtllib/swg/swg_template_default.sv b/finn-rtllib/swg/swg_template_default.sv index 97517438a0c261e4488b74a677a352f9dc51743b..15f1bd75dbe91e3252ef4fc2569623b9f73d3d7a 100644 --- a/finn-rtllib/swg/swg_template_default.sv +++ 
b/finn-rtllib/swg/swg_template_default.sv @@ -60,11 +60,11 @@ module $TOP_MODULE_NAME$_controller #( state_e State = $INNERMOST_STATE$; state_e state_next; - logic signed [$clog2(LOOP_H_ITERATIONS +2)+1-1:0] Counter_loop_h = LOOP_H_ITERATIONS-1; - logic signed [$clog2(LOOP_W_ITERATIONS +2)+1-1:0] Counter_loop_w = LOOP_W_ITERATIONS-1; - logic signed [$clog2(LOOP_KH_ITERATIONS +2)+1-1:0] Counter_loop_kh = LOOP_KH_ITERATIONS-1; - logic signed [$clog2(LOOP_KW_ITERATIONS +2)+1-1:0] Counter_loop_kw = LOOP_KW_ITERATIONS-1; - logic signed [$clog2(LOOP_SIMD_ITERATIONS+2)+1-1:0] Counter_loop_simd = LOOP_SIMD_ITERATIONS-1; + logic signed [$clog2(LOOP_H_ITERATIONS +2)+1-1:0] Counter_loop_h = LOOP_H_ITERATIONS; + logic signed [$clog2(LOOP_W_ITERATIONS +2)+1-1:0] Counter_loop_w = LOOP_W_ITERATIONS; + logic signed [$clog2(LOOP_KH_ITERATIONS +2)+1-1:0] Counter_loop_kh = LOOP_KH_ITERATIONS; + logic signed [$clog2(LOOP_KW_ITERATIONS +2)+1-1:0] Counter_loop_kw = LOOP_KW_ITERATIONS; + logic signed [$clog2(LOOP_SIMD_ITERATIONS+2)+1-1:0] Counter_loop_simd = LOOP_SIMD_ITERATIONS; assign addr_incr = ADDR_INCREMENT_MAP[State]; @@ -101,29 +101,29 @@ module $TOP_MODULE_NAME$_controller #( always_ff @ (posedge clk) begin if(!rst_n) begin State <= $INNERMOST_STATE$; - Counter_loop_h <= LOOP_H_ITERATIONS-1; - Counter_loop_w <= LOOP_W_ITERATIONS-1; - Counter_loop_kh <= LOOP_KH_ITERATIONS-1; - Counter_loop_kw <= LOOP_KW_ITERATIONS-1; - Counter_loop_simd <= LOOP_SIMD_ITERATIONS-1; + Counter_loop_h <= LOOP_H_ITERATIONS; + Counter_loop_w <= LOOP_W_ITERATIONS; + Counter_loop_kh <= LOOP_KH_ITERATIONS; + Counter_loop_kw <= LOOP_KW_ITERATIONS; + Counter_loop_simd <= LOOP_SIMD_ITERATIONS; end else if(advance) begin State <= state_next; if (State == $INNERMOST_STATE$) begin if(Counter_loop_simd >= 0) Counter_loop_simd <= Counter_loop_simd-1; else begin - Counter_loop_simd <= LOOP_SIMD_ITERATIONS-1; + Counter_loop_simd <= LOOP_SIMD_ITERATIONS; if(Counter_loop_kw >= 0) Counter_loop_kw <= Counter_loop_kw-1; 
else begin - Counter_loop_kw <= LOOP_KW_ITERATIONS-1; + Counter_loop_kw <= LOOP_KW_ITERATIONS; if(Counter_loop_kh >= 0) Counter_loop_kh <= Counter_loop_kh-1; else begin - Counter_loop_kh <= LOOP_KH_ITERATIONS-1; + Counter_loop_kh <= LOOP_KH_ITERATIONS; if(Counter_loop_w >= 0) Counter_loop_w <= Counter_loop_w-1; else begin - Counter_loop_w <= LOOP_W_ITERATIONS-1; + Counter_loop_w <= LOOP_W_ITERATIONS; if(Counter_loop_h >= 0) Counter_loop_h <= Counter_loop_h-1; - else Counter_loop_h <= LOOP_H_ITERATIONS-1; + else Counter_loop_h <= LOOP_H_ITERATIONS; end end end diff --git a/finn-rtllib/swg/swg_template_default_dynamic.sv b/finn-rtllib/swg/swg_template_default_dynamic.sv new file mode 100644 index 0000000000000000000000000000000000000000..96bd8cc591ef2df28a2c2ec58d8ce3a56353923a --- /dev/null +++ b/finn-rtllib/swg/swg_template_default_dynamic.sv @@ -0,0 +1,431 @@ +module $TOP_MODULE_NAME$_controller #( // drives addr_incr/tail_incr from nested down-counters that step on 'advance' + int unsigned LOOP_H_ITERATIONS = $LOOP_H_ITERATIONS$, + int unsigned LOOP_W_ITERATIONS = $LOOP_W_ITERATIONS$, + int unsigned LOOP_KH_ITERATIONS = $LOOP_KH_ITERATIONS$, + int unsigned LOOP_KW_ITERATIONS = $LOOP_KW_ITERATIONS$, + int unsigned LOOP_SIMD_ITERATIONS = $LOOP_SIMD_ITERATIONS$, + + int unsigned CNTR_BITWIDTH, + int unsigned INCR_BITWIDTH, + + bit [INCR_BITWIDTH-1:0] ADDR_INCREMENT_MAP[6] = $ADDR_INCREMENT_MAP$, + + bit IS_DEPTHWISE = $IS_DEPTHWISE$ +)( + input logic clk, + input logic rst_n, + + input logic advance, + output logic [INCR_BITWIDTH-1:0] addr_incr, + output logic [INCR_BITWIDTH-1:0] tail_incr, + + input logic cfg_valid, + input logic [CNTR_BITWIDTH-1:0] cfg_cntr_simd, + input logic [CNTR_BITWIDTH-1:0] cfg_cntr_kw, + input logic [CNTR_BITWIDTH-1:0] cfg_cntr_kh, + input logic [CNTR_BITWIDTH-1:0] cfg_cntr_w, + input logic [CNTR_BITWIDTH-1:0] cfg_cntr_h, + input logic [INCR_BITWIDTH-1:0] cfg_incr_head_simd, + input logic [INCR_BITWIDTH-1:0] cfg_incr_head_kw, + input logic [INCR_BITWIDTH-1:0] cfg_incr_head_kh, + input logic [INCR_BITWIDTH-1:0] 
cfg_incr_head_w, + input logic [INCR_BITWIDTH-1:0] cfg_incr_head_h, + input logic [INCR_BITWIDTH-1:0] cfg_incr_tail_w, + input logic [INCR_BITWIDTH-1:0] cfg_incr_tail_h, + input logic [INCR_BITWIDTH-1:0] cfg_incr_tail_last +); + + // (dynamic) configuration registers, initialized to the static template values + logic [CNTR_BITWIDTH-1:0] Cfg_cntr_simd = LOOP_SIMD_ITERATIONS; + logic [CNTR_BITWIDTH-1:0] Cfg_cntr_kw = LOOP_KW_ITERATIONS; + logic [CNTR_BITWIDTH-1:0] Cfg_cntr_kh = LOOP_KH_ITERATIONS; + logic [CNTR_BITWIDTH-1:0] Cfg_cntr_w = LOOP_W_ITERATIONS; + logic [CNTR_BITWIDTH-1:0] Cfg_cntr_h = LOOP_H_ITERATIONS; + logic [INCR_BITWIDTH-1:0] Cfg_incr_head_simd = ADDR_INCREMENT_MAP[1]; + logic [INCR_BITWIDTH-1:0] Cfg_incr_head_kw = ADDR_INCREMENT_MAP[2]; + logic [INCR_BITWIDTH-1:0] Cfg_incr_head_kh = ADDR_INCREMENT_MAP[3]; + logic [INCR_BITWIDTH-1:0] Cfg_incr_head_w = ADDR_INCREMENT_MAP[4]; + logic [INCR_BITWIDTH-1:0] Cfg_incr_head_h = ADDR_INCREMENT_MAP[5]; + logic [INCR_BITWIDTH-1:0] Cfg_incr_tail_w = $TAIL_INCR_W$; + logic [INCR_BITWIDTH-1:0] Cfg_incr_tail_h = $TAIL_INCR_H$; + logic [INCR_BITWIDTH-1:0] Cfg_incr_tail_last = $TAIL_INCR_LAST$; + + // configuration set logic: all registers are overwritten while cfg_valid is high (rst_n does not affect them) + always_ff @ (posedge clk) begin + if(cfg_valid) begin + Cfg_cntr_simd <= cfg_cntr_simd; + Cfg_cntr_kw <= cfg_cntr_kw; + Cfg_cntr_kh <= cfg_cntr_kh; + Cfg_cntr_w <= cfg_cntr_w; + Cfg_cntr_h <= cfg_cntr_h; + Cfg_incr_head_simd <= cfg_incr_head_simd; + Cfg_incr_head_kw <= cfg_incr_head_kw; + Cfg_incr_head_kh <= cfg_incr_head_kh; + Cfg_incr_head_w <= cfg_incr_head_w; + Cfg_incr_head_h <= cfg_incr_head_h; + Cfg_incr_tail_w <= cfg_incr_tail_w; + Cfg_incr_tail_h <= cfg_incr_tail_h; + Cfg_incr_tail_last <= cfg_incr_tail_last; + end + end + + // state and counters (counter widths are derived from the static LOOP_*_ITERATIONS parameters) + typedef enum logic [2:0] { + STATE_START, + STATE_LOOP_SIMD, + STATE_LOOP_KW, + STATE_LOOP_KH, + STATE_LOOP_W, + STATE_LOOP_H + } state_e; + state_e State = $INNERMOST_STATE$; + state_e state_next; + + logic signed [$clog2(LOOP_H_ITERATIONS +2)+1-1:0] Counter_loop_h = 
LOOP_H_ITERATIONS; + logic signed [$clog2(LOOP_W_ITERATIONS +2)+1-1:0] Counter_loop_w = LOOP_W_ITERATIONS; + logic signed [$clog2(LOOP_KH_ITERATIONS +2)+1-1:0] Counter_loop_kh = LOOP_KH_ITERATIONS; + logic signed [$clog2(LOOP_KW_ITERATIONS +2)+1-1:0] Counter_loop_kw = LOOP_KW_ITERATIONS; + logic signed [$clog2(LOOP_SIMD_ITERATIONS+2)+1-1:0] Counter_loop_simd = LOOP_SIMD_ITERATIONS; + + // head address increment: selected by the current state from the run-time configuration registers + always_comb begin : blkHead + case (State) + STATE_LOOP_SIMD : addr_incr = Cfg_incr_head_simd; + STATE_LOOP_KW : addr_incr = Cfg_incr_head_kw; + STATE_LOOP_KH : addr_incr = Cfg_incr_head_kh; + STATE_LOOP_W : addr_incr = Cfg_incr_head_w; + STATE_LOOP_H : addr_incr = Cfg_incr_head_h; + default : addr_incr = 0; // STATE_START and the two unused 3-bit encodings; keeps addr_incr assigned on every path + endcase + end + + // combinational logic for tail_incr generation + uwire tail_incr_inner_condition = IS_DEPTHWISE? (Counter_loop_kh >= 0) : 0; + always_comb begin : blkTail + if (tail_incr_inner_condition) + tail_incr = 1; + else if (Counter_loop_w >= 0) + tail_incr = Cfg_incr_tail_w; + else if (Counter_loop_h >= 0) + tail_incr = Cfg_incr_tail_h; + else + tail_incr = Cfg_incr_tail_last; + end + + // combinational next state logic + always_comb begin : blkState + state_next = State; + if(State != $INNERMOST_STATE$) state_next = $INNERMOST_STATE$; + else begin + if(Counter_loop_simd < 0) begin + state_next = + (Counter_loop_kw >= 0)? STATE_LOOP_KW : + (Counter_loop_kh >= 0)? STATE_LOOP_KH : + (Counter_loop_w >= 0)? STATE_LOOP_W : + (Counter_loop_h >= 0)? 
STATE_LOOP_H : + /* else */ STATE_START; + end + end + end : blkState + + // sequential logic: counters load from the Cfg_cntr_* registers on reset and reload from them once they have gone negative + always_ff @ (posedge clk) begin + if(!rst_n) begin + State <= $INNERMOST_STATE$; + Counter_loop_h <= Cfg_cntr_h; + Counter_loop_w <= Cfg_cntr_w; + Counter_loop_kh <= Cfg_cntr_kh; + Counter_loop_kw <= Cfg_cntr_kw; + Counter_loop_simd <= Cfg_cntr_simd; + end + else if(advance) begin + State <= state_next; + if (State == $INNERMOST_STATE$) begin + if(Counter_loop_simd >= 0) Counter_loop_simd <= Counter_loop_simd-1; + else begin + Counter_loop_simd <= Cfg_cntr_simd; + if(Counter_loop_kw >= 0) Counter_loop_kw <= Counter_loop_kw-1; + else begin + Counter_loop_kw <= Cfg_cntr_kw; + if(Counter_loop_kh >= 0) Counter_loop_kh <= Counter_loop_kh-1; + else begin + Counter_loop_kh <= Cfg_cntr_kh; + if(Counter_loop_w >= 0) Counter_loop_w <= Counter_loop_w-1; + else begin + Counter_loop_w <= Cfg_cntr_w; + if(Counter_loop_h >= 0) Counter_loop_h <= Counter_loop_h-1; + else Counter_loop_h <= Cfg_cntr_h; + end + end + end + end + end + end + end + +endmodule : $TOP_MODULE_NAME$_controller + +module $TOP_MODULE_NAME$_cyclic_buffer_addressable #( // simple dual-port RAM: one write port, one registered read port + int unsigned WIDTH, + int unsigned DEPTH +)( + input logic clk, + input logic rst_n, // not referenced inside this module + + input logic write_enable, + input logic [$clog2(DEPTH)-1:0] write_addr, + input logic [WIDTH-1:0] data_in, + + input logic read_enable, + input logic [$clog2(DEPTH)-1:0] read_addr, // absolute (!) 
read address of cyclic buffer + output logic [WIDTH-1:0] data_out +); + + $RAM_STYLE$ logic [WIDTH-1:0] Ram[DEPTH]; + logic [WIDTH-1:0] Out = 'x; + always_ff @(posedge clk) begin + if (read_enable) Out <= Ram[read_addr]; // registered read: data_out changes on the clock edge at which read_enable is sampled high + if (write_enable) Ram[write_addr] <= data_in; + end + assign data_out = Out; + +endmodule : $TOP_MODULE_NAME$_cyclic_buffer_addressable + +module $TOP_MODULE_NAME$_impl #( + int BIT_WIDTH, + int SIMD, + int MMV_IN, + int MMV_OUT, + int LAST_READ_ELEM = $LAST_READ_ELEM$, + int LAST_WRITE_ELEM = $LAST_WRITE_ELEM$, + int BUF_ELEM_TOTAL = $BUF_ELEM_TOTAL$, + int ELEM_PER_WINDOW = $ELEM_PER_WINDOW$, + + int unsigned CNTR_BITWIDTH, + int unsigned INCR_BITWIDTH +)( + input logic ap_clk, + input logic ap_rst_n, + + input logic in0_V_V_TVALID, + output logic in0_V_V_TREADY, + input logic [BIT_WIDTH * SIMD * MMV_IN-1:0] in0_V_V_TDATA, + + output logic out_V_V_TVALID, + input logic out_V_V_TREADY, + output logic [BIT_WIDTH * SIMD * MMV_OUT-1:0] out_V_V_TDATA, + + input logic cfg_valid, + input logic [CNTR_BITWIDTH-1:0] cfg_cntr_simd, + input logic [CNTR_BITWIDTH-1:0] cfg_cntr_kw, + input logic [CNTR_BITWIDTH-1:0] cfg_cntr_kh, + input logic [CNTR_BITWIDTH-1:0] cfg_cntr_w, + input logic [CNTR_BITWIDTH-1:0] cfg_cntr_h, + input logic [INCR_BITWIDTH-1:0] cfg_incr_head_simd, + input logic [INCR_BITWIDTH-1:0] cfg_incr_head_kw, + input logic [INCR_BITWIDTH-1:0] cfg_incr_head_kh, + input logic [INCR_BITWIDTH-1:0] cfg_incr_head_w, + input logic [INCR_BITWIDTH-1:0] cfg_incr_head_h, + input logic [INCR_BITWIDTH-1:0] cfg_incr_tail_w, + input logic [INCR_BITWIDTH-1:0] cfg_incr_tail_h, + input logic [INCR_BITWIDTH-1:0] cfg_incr_tail_last, + input logic [31:0] cfg_last_read, //todo: reduce bitwidth to $clog2(LAST_READ_ELEM+1) + input logic [31:0] cfg_last_write +); + // derived Constants + localparam int unsigned BUF_IN_WIDTH = BIT_WIDTH * SIMD * MMV_IN; + localparam int unsigned BUF_OUT_ELEM_WIDTH = BIT_WIDTH * SIMD; + localparam int unsigned BUF_OUT_WIDTH = BIT_WIDTH * 
SIMD * MMV_OUT; + + // (dynamic) configuration registers + logic [31:0] Cfg_last_read = LAST_READ_ELEM; + logic [31:0] Cfg_last_write = LAST_WRITE_ELEM; + + // configuration reset/set logic + always_ff @ (posedge ap_clk) begin + if(cfg_valid) begin + Cfg_last_read <= cfg_last_read; + Cfg_last_write <= cfg_last_write; + end + end + + // main buffer instantiation + uwire [BUF_IN_WIDTH -1:0] window_buffer_in; + uwire [BUF_OUT_WIDTH-1:0] window_buffer_out; + uwire window_buffer_write_enable; + uwire window_buffer_read_enable; + uwire [$clog2(BUF_ELEM_TOTAL)-1:0] window_buffer_write_addr; + uwire [$clog2(BUF_ELEM_TOTAL)-1:0] window_buffer_read_addr; + $TOP_MODULE_NAME$_cyclic_buffer_addressable #( + .WIDTH(BUF_IN_WIDTH), + .DEPTH(BUF_ELEM_TOTAL) + ) window_buffer_inst ( + .clk(ap_clk), + .rst_n(ap_rst_n), + + .write_enable(window_buffer_write_enable), + .write_addr(window_buffer_write_addr), + .data_in(window_buffer_in), + + .read_enable(window_buffer_read_enable), + .read_addr(window_buffer_read_addr), + .data_out(window_buffer_out) + ); + + //controller instantiation + uwire advance_controller; + uwire signed [INCR_BITWIDTH-1:0] addr_incr; + uwire [INCR_BITWIDTH-1:0] tail_incr; + $TOP_MODULE_NAME$_controller #( + .CNTR_BITWIDTH(CNTR_BITWIDTH), + .INCR_BITWIDTH(INCR_BITWIDTH) + ) controller_inst ( + .clk(ap_clk), + .rst_n(ap_rst_n), + .advance(advance_controller), + .addr_incr(addr_incr), + .tail_incr(tail_incr), + + .cfg_valid(cfg_valid), + .cfg_cntr_simd(cfg_cntr_simd), + .cfg_cntr_kw(cfg_cntr_kw), + .cfg_cntr_kh(cfg_cntr_kh), + .cfg_cntr_w(cfg_cntr_w), + .cfg_cntr_h(cfg_cntr_h), + .cfg_incr_head_simd(cfg_incr_head_simd), + .cfg_incr_head_kw(cfg_incr_head_kw), + .cfg_incr_head_kh(cfg_incr_head_kh), + .cfg_incr_head_w(cfg_incr_head_w), + .cfg_incr_head_h(cfg_incr_head_h), + .cfg_incr_tail_w(cfg_incr_tail_w), + .cfg_incr_tail_h(cfg_incr_tail_h), + .cfg_incr_tail_last(cfg_incr_tail_last) + ); + + // Counters/address registers + // Add a sign bit even to (most) unsigned 
counters and Window_buffer_read_addr_reg, + // so we can use automatic sign extension and simplify calculations w/ signed increment. + // Alternatively, we could manually sign-extend and shave off a bit here or there. + logic signed [$clog2(LAST_READ_ELEM+1)+1-1:0] Newest_buffered_elem = -1; + logic [$clog2(LAST_READ_ELEM+1)+1-1:0] Current_elem = 0; + logic [$clog2(LAST_READ_ELEM+1)+1-1:0] First_elem_next_window = 0; + logic [$clog2(ELEM_PER_WINDOW) -1:0] Position_in_window = 0; + logic [$clog2(BUF_ELEM_TOTAL)+1 -1:0] Window_buffer_read_addr_reg = 0; + logic [$clog2(BUF_ELEM_TOTAL)-1:0] Window_buffer_write_addr_reg = 0; + + // Control signals/registers + uwire read_cmd = + !reading_done && ( // if there is still an input element left to read + Fetching_done || ( // if fetching is done (e.g. for skipped rows at FM end due to stride) + $signed(((Newest_buffered_elem - (BUF_ELEM_TOTAL - 1)))) < $signed(First_elem_next_window) && + $signed(((Newest_buffered_elem - (BUF_ELEM_TOTAL - 1)))) < $signed(Current_elem) + ) // (over-)write to buffer if oldest buffered element will no longer be needed + ); + uwire read_ok = read_cmd && in0_V_V_TVALID; + uwire reading_done = Newest_buffered_elem == Cfg_last_read; + + uwire fetch_cmd = !($signed(Current_elem) > Newest_buffered_elem) && !write_blocked && !Fetching_done; + logic Fetching_done = 0; + + logic Write_cmd = 0; + logic Writing_done = 0; + uwire write_ok = Write_cmd && out_V_V_TREADY; + uwire write_blocked = Write_cmd && !out_V_V_TREADY;; + + //assign buffer control + assign window_buffer_write_addr = Window_buffer_write_addr_reg; + assign window_buffer_read_addr = Window_buffer_read_addr_reg; + assign window_buffer_write_enable = read_ok; + assign window_buffer_read_enable = fetch_cmd; + assign advance_controller = fetch_cmd; + + //assign I/O ports + assign window_buffer_in = in0_V_V_TDATA; + assign out_V_V_TDATA = window_buffer_out; + assign in0_V_V_TREADY = ap_rst_n && read_ok; //only asserted if data is available and 
we can store it (allowed) + assign out_V_V_TVALID = ap_rst_n && Write_cmd; //only asserted if we have data available and it has not been read yet (don't wait for READY from sink) + + //main process for advancing counters + always_ff @(posedge ap_clk) begin + if(!ap_rst_n) begin + Newest_buffered_elem <= -1; + Current_elem <= 0; + First_elem_next_window <= 0; + Position_in_window <= 0; + Window_buffer_read_addr_reg <= 0; + Window_buffer_write_addr_reg <= 0; + Fetching_done <= 0; + Write_cmd <= 0; + Writing_done <= 0; + end + else begin + if (read_ok) begin + Window_buffer_write_addr_reg <= (Window_buffer_write_addr_reg == BUF_ELEM_TOTAL-1)? 0 : Window_buffer_write_addr_reg + 1; + Newest_buffered_elem <= Newest_buffered_elem+1; + + if (Newest_buffered_elem == Cfg_last_read-1) begin + Window_buffer_write_addr_reg <= 0; + end + //check if this is the last read cycle (reading_done will be true afterwards) + if ((Newest_buffered_elem == Cfg_last_read-1) && Writing_done) begin + //start processing of next FM if writing is done already (possible due to unused input elements at the tail end) + //todo: allow for read overlapping between feature maps (i.e., reading first elements from next FM while still writing last window of current FM) + Newest_buffered_elem <= -1; + Current_elem <= 0; + Window_buffer_read_addr_reg <= 0; + First_elem_next_window <= 0; + Writing_done <= 0; + Fetching_done <= 0; + end + end + + if (fetch_cmd) begin + //count up to track which element index is about to be read from the buffer, and where it is located within the buffer + //use increment value calculated by controller + + // absolute buffer address wrap-around + automatic logic signed [$clog2(BUF_ELEM_TOTAL)+1:0] ra = $signed(Window_buffer_read_addr_reg) + $signed(addr_incr); + automatic logic signed [$clog2(BUF_ELEM_TOTAL+1):0] ra_correct = + (ra >= BUF_ELEM_TOTAL)? -BUF_ELEM_TOTAL : + (ra < 0)? 
BUF_ELEM_TOTAL : 0; + Window_buffer_read_addr_reg <= ra + ra_correct; + + //keep track where we are within a window + Position_in_window <= (Position_in_window != ELEM_PER_WINDOW - 1)? Position_in_window+1 : 0; + + //update first element of next window to allow buffer overwrite up until that point + if (Position_in_window == 0) + First_elem_next_window <= First_elem_next_window + tail_incr; + + //check if this is the last write cycle (Writing_done will be true afterwards) + if (Current_elem == Cfg_last_write) + Fetching_done <= 1; + else + Current_elem <= $signed(Current_elem) + addr_incr; + + // determine if prefetched data will be outstanding in the next cycle + // if we fetch in this cycle -> yes + // if we do not fetch nor write -> do not change + // if we do not fetch but write successfully-> clear outstanding data + Write_cmd <= fetch_cmd; + end + + if (write_ok) + Write_cmd <= fetch_cmd; + + if (write_ok && Fetching_done) begin + //check if this is the last write cycle (Writing_done will be true afterwards) + if (reading_done || (read_ok && (Newest_buffered_elem == Cfg_last_read - 1))) begin + //start processing of next FM if reading is done already, or completes in the same cycle + Newest_buffered_elem <= -1; + Current_elem <= 0; + Window_buffer_read_addr_reg <= 0; + First_elem_next_window <= 0; + Fetching_done <= 0; + end else + Writing_done <= 1; + end + end + end + +endmodule : $TOP_MODULE_NAME$_impl diff --git a/finn-rtllib/swg/swg_template_wrapper_dynamic.v b/finn-rtllib/swg/swg_template_wrapper_dynamic.v new file mode 100644 index 0000000000000000000000000000000000000000..d6f839de43c3c13c6d1a9a772b21d976e2348a08 --- /dev/null +++ b/finn-rtllib/swg/swg_template_wrapper_dynamic.v @@ -0,0 +1,156 @@ +`timescale 1 ns / 1 ps + +module $TOP_MODULE_NAME$ #( + // top-level parameters (set via code-generation) + parameter BIT_WIDTH = $BIT_WIDTH$, + parameter SIMD = $SIMD$, + parameter MMV_IN = $MMV_IN$, + parameter MMV_OUT = $MMV_OUT$, + + parameter 
CNTR_BITWIDTH = $CNTR_BITWIDTH$, + parameter INCR_BITWIDTH = $INCR_BITWIDTH$, + + // derived constants + parameter BUF_IN_WIDTH = BIT_WIDTH * SIMD * MMV_IN, + parameter BUF_OUT_WIDTH = BIT_WIDTH * SIMD * MMV_OUT, + + parameter integer C_s_axi_cfg_DATA_WIDTH = 32, + parameter integer C_s_axi_cfg_ADDR_WIDTH = 6 +) +( + (* X_INTERFACE_PARAMETER = "ASSOCIATED_BUSIF in0_V:out_V:s_axi_cfg" *) + input ap_clk, + (* X_INTERFACE_PARAMETER = "ASSOCIATED_BUSIF in0_V:out_V:s_axi_cfg" *) + input ap_rst_n, + input [BUF_IN_WIDTH-1:0] in0_V_TDATA, + input in0_V_TVALID, + output in0_V_TREADY, + output [BUF_OUT_WIDTH-1:0] out_V_TDATA, + output out_V_TVALID, + input out_V_TREADY, + + // Ports of Axi Slave Bus Interface s_axi_cfg + //input wire s_axi_cfg_aclk, + //input wire s_axi_cfg_aresetn, + input wire [C_s_axi_cfg_ADDR_WIDTH-1 : 0] s_axi_cfg_awaddr, + input wire [2 : 0] s_axi_cfg_awprot, + input wire s_axi_cfg_awvalid, + output wire s_axi_cfg_awready, + input wire [C_s_axi_cfg_DATA_WIDTH-1 : 0] s_axi_cfg_wdata, + input wire [(C_s_axi_cfg_DATA_WIDTH/8)-1 : 0] s_axi_cfg_wstrb, + input wire s_axi_cfg_wvalid, + output wire s_axi_cfg_wready, + output wire [1 : 0] s_axi_cfg_bresp, + output wire s_axi_cfg_bvalid, + input wire s_axi_cfg_bready, + input wire [C_s_axi_cfg_ADDR_WIDTH-1 : 0] s_axi_cfg_araddr, + input wire [2 : 0] s_axi_cfg_arprot, + input wire s_axi_cfg_arvalid, + output wire s_axi_cfg_arready, + output wire [C_s_axi_cfg_DATA_WIDTH-1 : 0] s_axi_cfg_rdata, + output wire [1 : 0] s_axi_cfg_rresp, + output wire s_axi_cfg_rvalid, + input wire s_axi_cfg_rready +); + +wire cfg_valid; +wire [CNTR_BITWIDTH-1:0] cfg_cntr_simd; +wire [CNTR_BITWIDTH-1:0] cfg_cntr_kw; +wire [CNTR_BITWIDTH-1:0] cfg_cntr_kh; +wire [CNTR_BITWIDTH-1:0] cfg_cntr_w; +wire [CNTR_BITWIDTH-1:0] cfg_cntr_h; +wire [INCR_BITWIDTH-1:0] cfg_incr_head_simd; +wire [INCR_BITWIDTH-1:0] cfg_incr_head_kw; +wire [INCR_BITWIDTH-1:0] cfg_incr_head_kh; +wire [INCR_BITWIDTH-1:0] cfg_incr_head_w; +wire [INCR_BITWIDTH-1:0] 
cfg_incr_head_h; +wire [INCR_BITWIDTH-1:0] cfg_incr_tail_w; +wire [INCR_BITWIDTH-1:0] cfg_incr_tail_h; +wire [INCR_BITWIDTH-1:0] cfg_incr_tail_last; +wire [31:0] cfg_last_read; +wire [31:0] cfg_last_write; + +// Instantiation of Axi Bus Interface s_axi_cfg +$TOP_MODULE_NAME$_axilite # ( + .C_S_AXI_DATA_WIDTH(C_s_axi_cfg_DATA_WIDTH), + .C_S_AXI_ADDR_WIDTH(C_s_axi_cfg_ADDR_WIDTH) +) axilite_cfg_inst ( + .S_AXI_ACLK(ap_clk), + .S_AXI_ARESETN(ap_rst_n), + .S_AXI_AWADDR(s_axi_cfg_awaddr), + .S_AXI_AWPROT(s_axi_cfg_awprot), + .S_AXI_AWVALID(s_axi_cfg_awvalid), + .S_AXI_AWREADY(s_axi_cfg_awready), + .S_AXI_WDATA(s_axi_cfg_wdata), + .S_AXI_WSTRB(s_axi_cfg_wstrb), + .S_AXI_WVALID(s_axi_cfg_wvalid), + .S_AXI_WREADY(s_axi_cfg_wready), + .S_AXI_BRESP(s_axi_cfg_bresp), + .S_AXI_BVALID(s_axi_cfg_bvalid), + .S_AXI_BREADY(s_axi_cfg_bready), + .S_AXI_ARADDR(s_axi_cfg_araddr), + .S_AXI_ARPROT(s_axi_cfg_arprot), + .S_AXI_ARVALID(s_axi_cfg_arvalid), + .S_AXI_ARREADY(s_axi_cfg_arready), + .S_AXI_RDATA(s_axi_cfg_rdata), + .S_AXI_RRESP(s_axi_cfg_rresp), + .S_AXI_RVALID(s_axi_cfg_rvalid), + .S_AXI_RREADY(s_axi_cfg_rready), + + .cfg_reg0(cfg_valid), + .cfg_reg1(cfg_cntr_simd), + .cfg_reg2(cfg_cntr_kw), + .cfg_reg3(cfg_cntr_kh), + .cfg_reg4(cfg_cntr_w), + .cfg_reg5(cfg_cntr_h), + .cfg_reg6(cfg_incr_head_simd), + .cfg_reg7(cfg_incr_head_kw), + .cfg_reg8(cfg_incr_head_kh), + .cfg_reg9(cfg_incr_head_w), + .cfg_reg10(cfg_incr_head_h), + .cfg_reg11(cfg_incr_tail_w), + .cfg_reg12(cfg_incr_tail_h), + .cfg_reg13(cfg_incr_tail_last), + .cfg_reg14(cfg_last_read), + .cfg_reg15(cfg_last_write) +); + +$TOP_MODULE_NAME$_impl +#( + .BIT_WIDTH(BIT_WIDTH), + .SIMD(SIMD), + .MMV_IN(MMV_IN), + .MMV_OUT(MMV_OUT), + .CNTR_BITWIDTH(CNTR_BITWIDTH), + .INCR_BITWIDTH(INCR_BITWIDTH) +) +impl +( + .ap_clk(ap_clk), + .ap_rst_n(ap_rst_n), + .in0_V_V_TDATA(in0_V_TDATA), + .in0_V_V_TVALID(in0_V_TVALID), + .in0_V_V_TREADY(in0_V_TREADY), + .out_V_V_TDATA(out_V_TDATA), + .out_V_V_TVALID(out_V_TVALID), + 
.out_V_V_TREADY(out_V_TREADY), + + .cfg_valid(cfg_valid), + .cfg_cntr_simd(cfg_cntr_simd), + .cfg_cntr_kw(cfg_cntr_kw), + .cfg_cntr_kh(cfg_cntr_kh), + .cfg_cntr_w(cfg_cntr_w), + .cfg_cntr_h(cfg_cntr_h), + .cfg_incr_head_simd(cfg_incr_head_simd), + .cfg_incr_head_kw(cfg_incr_head_kw), + .cfg_incr_head_kh(cfg_incr_head_kh), + .cfg_incr_head_w(cfg_incr_head_w), + .cfg_incr_head_h(cfg_incr_head_h), + .cfg_incr_tail_w(cfg_incr_tail_w), + .cfg_incr_tail_h(cfg_incr_tail_h), + .cfg_incr_tail_last(cfg_incr_tail_last), + .cfg_last_read(cfg_last_read), + .cfg_last_write(cfg_last_write) +); + +endmodule //TOP_MODULE_NAME diff --git a/src/finn/custom_op/fpgadataflow/convolutioninputgenerator_rtl.py b/src/finn/custom_op/fpgadataflow/convolutioninputgenerator_rtl.py index 399b36e15021af6f449df3e9ba2acdc699a27647..6b6180707581663d2146de1ecd4a9556e325a04b 100755 --- a/src/finn/custom_op/fpgadataflow/convolutioninputgenerator_rtl.py +++ b/src/finn/custom_op/fpgadataflow/convolutioninputgenerator_rtl.py @@ -81,6 +81,9 @@ class ConvolutionInputGenerator_rtl(HLSCustomOp): "inputDataType": ("s", True, ""), "outputDataType": ("s", True, ""), "depthwise": ("i", False, 0, {0, 1}), + # Enable reprogrammable implementation to change FM dimensions, + # stride, or dilation during runtime + "dynamic_mode": ("i", False, 0, {0, 1}), # FPGA resource type for ConvolutionInputGenerator input buffer # auto -- let Vivado decide # block -- use BRAM @@ -457,9 +460,11 @@ class ConvolutionInputGenerator_rtl(HLSCustomOp): def prepare_codegen_default(self): # Default implementation style for MMV_out = 1: addressable cyclic buffer # Computing incremental addressing scheme directly.. 
- template_path = ( - os.environ["FINN_ROOT"] + "/finn-rtllib/swg/swg_template_default.sv" - ) + if self.get_nodeattr("dynamic_mode"): + template_select = "/finn-rtllib/swg/swg_template_default_dynamic.sv" + else: + template_select = "/finn-rtllib/swg/swg_template_default.sv" + template_path = os.environ["FINN_ROOT"] + template_select code_gen_dict = {} ifm_ch = self.get_nodeattr("IFMChannels") @@ -590,11 +595,24 @@ class ConvolutionInputGenerator_rtl(HLSCustomOp): code_gen_dict["$INNERMOST_STATE$"] = ["STATE_LOOP_SIMD"] loop_simd_iterations -= 1 # -1 because state is initial state - code_gen_dict["$LOOP_H_ITERATIONS$"] = [str(loop_h_iterations - 1)] - code_gen_dict["$LOOP_W_ITERATIONS$"] = [str(loop_w_iterations - 1)] - code_gen_dict["$LOOP_KH_ITERATIONS$"] = [str(loop_kh_iterations - 1)] - code_gen_dict["$LOOP_KW_ITERATIONS$"] = [str(loop_kw_iterations - 1)] - code_gen_dict["$LOOP_SIMD_ITERATIONS$"] = [str(loop_simd_iterations - 1)] + code_gen_dict["$LOOP_H_ITERATIONS$"] = [str(loop_h_iterations - 2)] + code_gen_dict["$LOOP_W_ITERATIONS$"] = [str(loop_w_iterations - 2)] + code_gen_dict["$LOOP_KH_ITERATIONS$"] = [str(loop_kh_iterations - 2)] + code_gen_dict["$LOOP_KW_ITERATIONS$"] = [str(loop_kw_iterations - 2)] + code_gen_dict["$LOOP_SIMD_ITERATIONS$"] = [str(loop_simd_iterations - 2)] + + cntr_bitwidth = math.ceil( + math.log2( + max( + loop_h_iterations - 2 + 1, + loop_w_iterations - 2 + 1, + loop_kh_iterations - 2 + 1, + loop_kw_iterations - 2 + 1, + loop_simd_iterations - 2 + 1, + ) + ) + ) + code_gen_dict["$CNTR_BITWIDTH$"] = [str(cntr_bitwidth)] incr_bitwidth = 1 + math.ceil( math.log2( @@ -626,6 +644,11 @@ class ConvolutionInputGenerator_rtl(HLSCustomOp): abs(addr_incr_end_row), ) ] + code_gen_dict["$INCR_HEAD_SIMD$"] = [str(addr_incr_end_simd)] + code_gen_dict["$INCR_HEAD_KW$"] = [str(addr_incr_end_window_elem)] + code_gen_dict["$INCR_HEAD_KH$"] = [str(addr_incr_end_window_row)] + code_gen_dict["$INCR_HEAD_W$"] = [str(addr_incr_end_window)] + 
code_gen_dict["$INCR_HEAD_H$"] = [str(addr_incr_end_row)] code_gen_dict["$ELEM_PER_WINDOW$"] = [str(elem_per_window)] code_gen_dict["$SIMD$"] = [str(simd)] @@ -710,15 +733,22 @@ class ConvolutionInputGenerator_rtl(HLSCustomOp): code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen") with open(template_path, "r") as f: template = f.read() + if self.get_nodeattr("dynamic_mode"): + template_select = "/finn-rtllib/swg/swg_template_wrapper_dynamic.v" + else: + template_select = "/finn-rtllib/swg/swg_template_wrapper.v" + with open(os.environ["FINN_ROOT"] + template_select, "r") as f: + template_wrapper = f.read() with open( - os.environ["FINN_ROOT"] + "/finn-rtllib/swg/swg_template_wrapper.v", "r" + os.environ["FINN_ROOT"] + "/finn-rtllib/swg/swg_template_axilite.v", "r" ) as f: - template_wrapper = f.read() + template_axilite = f.read() for key in code_gen_dict: # transform list into long string separated by '\n' code_gen_line = "\n".join(code_gen_dict[key]) template = template.replace(key, code_gen_line) template_wrapper = template_wrapper.replace(key, code_gen_line) + template_axilite = template_axilite.replace(key, code_gen_line) with open( os.path.join( code_gen_dir, self.get_nodeattr("gen_top_module") + "_impl.sv" @@ -734,6 +764,16 @@ class ConvolutionInputGenerator_rtl(HLSCustomOp): ) as f: f.write(template_wrapper) + # AXI-Lite reg. 
file component is only needed for dynamic mode + if self.get_nodeattr("dynamic_mode"): + with open( + os.path.join( + code_gen_dir, self.get_nodeattr("gen_top_module") + "_axilite.v" + ), + "w", + ) as f: + f.write(template_axilite) + # set ipgen_path and ip_path so that HLS-Synth transformation # and stich_ip transformation do not complain self.set_nodeattr("ipgen_path", code_gen_dir) @@ -754,6 +794,8 @@ class ConvolutionInputGenerator_rtl(HLSCustomOp): self.get_nodeattr("gen_top_module") + "_wrapper.v", self.get_nodeattr("gen_top_module") + "_impl.sv", ] + if self.get_nodeattr("dynamic_mode"): + verilog_files.append(self.get_nodeattr("gen_top_module") + "_axilite.v") # build the Verilator emu library sim = PyVerilator.build( @@ -771,25 +813,102 @@ class ConvolutionInputGenerator_rtl(HLSCustomOp): """Constructs and returns the TCL for node instantiation in Vivado IPI.""" code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen") - cmd = [ - "add_files -norecurse %s" - % ( - os.path.join( - code_gen_dir, self.get_nodeattr("gen_top_module") + "_wrapper.v" - ) - ), - "add_files -norecurse %s" - % ( - os.path.join( - code_gen_dir, self.get_nodeattr("gen_top_module") + "_impl.sv" - ) - ), - "create_bd_cell -type module -reference %s %s" - % (self.get_nodeattr("gen_top_module"), self.onnx_node.name), + sourcefiles = [ + self.get_nodeattr("gen_top_module") + "_wrapper.v", + self.get_nodeattr("gen_top_module") + "_impl.sv", ] + if self.get_nodeattr("dynamic_mode"): + sourcefiles += [self.get_nodeattr("gen_top_module") + "_axilite.v"] + + sourcefiles = [os.path.join(code_gen_dir, f) for f in sourcefiles] + + cmd = [] + for f in sourcefiles: + cmd += ["add_files -norecurse %s" % (f)] + cmd += [ + "create_bd_cell -type module -reference %s %s" + % (self.get_nodeattr("gen_top_module"), self.onnx_node.name) + ] return cmd + def get_verilog_top_module_intf_names(self): + # Overload default HLSCustomOp implementation to add axilite control IF + """Return a dict of names of input and 
output interfaces. + The keys reflect the protocols each interface implements: + 'clk', 'rst', 'm_axis', 's_axis', 'aximm', 'axilite'. + Values are lists of tuples (axis, aximm) or names (axilite): + 'axis' tuples correspond to the list of node inputs in order, + each tuple is (interface_name, interface_width_bits). + axilite always assumed to be 32 bits and is not tuple (name only). + Each block must have at most one aximm and one axilite.""" + intf_names = {} + intf_names["clk"] = ["ap_clk"] + intf_names["rst"] = ["ap_rst_n"] + sname = self.hls_sname() + intf_names["s_axis"] = [("in0_" + sname, self.get_instream_width_padded())] + intf_names["m_axis"] = [("out_" + sname, self.get_outstream_width_padded())] + intf_names["aximm"] = [] + if self.get_nodeattr("dynamic_mode"): + intf_names["axilite"] = ["s_axi_cfg"] + else: + intf_names["axilite"] = [] + return intf_names + + def get_dynamic_config(self, ifm_dim, stride=None, dilation=None): + """Returns a configuration dict to re-configure FM dimension during + runtime. Stride and dilation can also be changed. Certain restrictions + apply (e.g. 
component must be synthesized for largest buffer size).""" + # NOTE: For better driver integration, this functionality could be packaged + # as a standalone function in the future + + k = self.get_nodeattr("ConvKernelDim") + if stride is None: + stride = self.get_nodeattr("Stride") + if dilation is None: + dilation = self.get_nodeattr("Dilation") + + k_h, k_w = k + stride_h, stride_w = stride + dilation_h, dilation_w = dilation + ifm_dim_h, ifm_dim_w = ifm_dim + ofm_dim_h = compute_conv_output_dim(ifm_dim_h, k_h, stride_h, 0, dilation_h) + ofm_dim_w = compute_conv_output_dim(ifm_dim_w, k_w, stride_w, 0, dilation_w) + ofm_dim = [ofm_dim_h, ofm_dim_w] + + # update attributes and perform sanity check + original_buffer_depth = self.get_buffer_depth() + self.set_nodeattr("IFMDim", ifm_dim) + self.set_nodeattr("OFMDim", ofm_dim) + self.set_nodeattr("Stride", stride) + self.set_nodeattr("Dilation", dilation) + assert ( + self.get_buffer_depth() <= original_buffer_depth + ), """Error: requested + dynamic configuration does not fit in generated buffer implementation.""" + + # (re-)call codegen and extract new values + # each setting is mapped to an axi-lite register address + template_path, code_gen_dict = self.prepare_codegen_default() + config = { + "cfg_cntr_simd": (1 * 4, int(code_gen_dict["$LOOP_SIMD_ITERATIONS$"][0])), + "cfg_cntr_kw": (2 * 4, int(code_gen_dict["$LOOP_KW_ITERATIONS$"][0])), + "cfg_cntr_kh": (3 * 4, int(code_gen_dict["$LOOP_KH_ITERATIONS$"][0])), + "cfg_cntr_w": (4 * 4, int(code_gen_dict["$LOOP_W_ITERATIONS$"][0])), + "cfg_cntr_h": (5 * 4, int(code_gen_dict["$LOOP_H_ITERATIONS$"][0])), + "cfg_incr_head_simd": (6 * 4, int(code_gen_dict["$INCR_HEAD_SIMD$"][0])), + "cfg_incr_head_kw": (7 * 4, int(code_gen_dict["$INCR_HEAD_KW$"][0])), + "cfg_incr_head_kh": (8 * 4, int(code_gen_dict["$INCR_HEAD_KH$"][0])), + "cfg_incr_head_w": (9 * 4, int(code_gen_dict["$INCR_HEAD_W$"][0])), + "cfg_incr_head_h": (10 * 4, int(code_gen_dict["$INCR_HEAD_H$"][0])), + 
"cfg_incr_tail_w": (11 * 4, int(code_gen_dict["$TAIL_INCR_W$"][0])), + "cfg_incr_tail_h": (12 * 4, int(code_gen_dict["$TAIL_INCR_H$"][0])), + "cfg_incr_tail_last": (13 * 4, int(code_gen_dict["$TAIL_INCR_LAST$"][0])), + "cfg_last_read": (14 * 4, int(code_gen_dict["$LAST_READ_ELEM$"][0])), + "cfg_last_write": (15 * 4, int(code_gen_dict["$LAST_WRITE_ELEM$"][0])), + } + return config + def code_generation_ipgen(self, model, fpgapart, clk): """Normally: Generates C++ code and tcl script for IP generation. Here: Generates (System-)Verilog code for IP generation.""" diff --git a/tests/fpgadataflow/test_fpgadataflow_convinputgenerator_rtl_dynamic.py b/tests/fpgadataflow/test_fpgadataflow_convinputgenerator_rtl_dynamic.py new file mode 100644 index 0000000000000000000000000000000000000000..f2d51d9ea60e393e2c146cc8bb161a50d8a4d961 --- /dev/null +++ b/tests/fpgadataflow/test_fpgadataflow_convinputgenerator_rtl_dynamic.py @@ -0,0 +1,320 @@ +# Copyright (c) 2022, Xilinx +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of FINN nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import pytest + +from onnx import TensorProto, helper +from pyverilator.util.axi_utils import axilite_write, reset_rtlsim +from qonnx.core.datatype import DataType +from qonnx.core.modelwrapper import ModelWrapper +from qonnx.custom_op.general.im2col import compute_conv_output_dim +from qonnx.custom_op.registry import getCustomOp +from qonnx.transformation.general import GiveUniqueNodeNames +from qonnx.util.basic import gen_finn_dt_tensor + +import finn.core.onnx_exec as oxe +from finn.core.rtlsim_exec import rtlsim_exec +from finn.transformation.fpgadataflow.create_stitched_ip import CreateStitchedIP +from finn.transformation.fpgadataflow.hlssynth_ip import HLSSynthIP +from finn.transformation.fpgadataflow.insert_fifo import InsertFIFO +from finn.transformation.fpgadataflow.prepare_ip import PrepareIP + + +def make_single_im2col_modelwrapper(k, ifm_ch, ifm_dim, ofm_dim, stride, dilation, idt): + k_h, k_w = k + ifm_dim_h, ifm_dim_w = ifm_dim + stride_h, stride_w = stride + dilation_h, dilation_w = dilation + ofm_dim_h, ofm_dim_w = ofm_dim + + odt = idt + inp = helper.make_tensor_value_info( + "inp", TensorProto.FLOAT, [1, ifm_dim_h, ifm_dim_w, ifm_ch] + ) + outp = helper.make_tensor_value_info( + "outp", TensorProto.FLOAT, [1, ofm_dim_h, ofm_dim_w, k_h * k_w * ifm_ch] + ) + + im2col_node = helper.make_node( + "Im2Col", + ["inp"], + ["outp"], + domain="finn.custom_op.general", + stride=[stride_h, stride_w], + kernel_size=[k_h, k_w], + 
input_shape=str((1, ifm_dim_h, ifm_dim_w, ifm_ch)), + dilations=[dilation_h, dilation_w], + pad_amount=[0, 0, 0, 0], + pad_value=0, + ) + graph = helper.make_graph( + nodes=[im2col_node], name="im2col_graph", inputs=[inp], outputs=[outp] + ) + + model = helper.make_model(graph, producer_name="im2col-model") + model = ModelWrapper(model) + + model.set_tensor_datatype("inp", idt) + model.set_tensor_datatype("outp", odt) + + return model + + +def make_single_slidingwindow_modelwrapper( + k, ifm_ch, ifm_dim, ofm_dim, simd, m, parallel_window, stride, dilation, idt, dw=0 +): + k_h, k_w = k + ifm_dim_h, ifm_dim_w = ifm_dim + stride_h, stride_w = stride + dilation_h, dilation_w = dilation + ofm_dim_h, ofm_dim_w = ofm_dim + + odt = idt + inp = helper.make_tensor_value_info( + "inp", TensorProto.FLOAT, [1, ifm_dim_h, ifm_dim_w, ifm_ch] + ) + outp = helper.make_tensor_value_info( + "outp", TensorProto.FLOAT, [1, ofm_dim_h, ofm_dim_w, k_h * k_w * ifm_ch] + ) + + SlidingWindow_node = helper.make_node( + "ConvolutionInputGenerator_rtl", + ["inp"], + ["outp"], + domain="finn.custom_op.fpgadataflow", + backend="fpgadataflow", + ConvKernelDim=[k_h, k_w], + IFMChannels=ifm_ch, + IFMDim=[ifm_dim_h, ifm_dim_w], + OFMDim=[ofm_dim_h, ofm_dim_w], + SIMD=simd, + M=m, + parallel_window=parallel_window, + Stride=[stride_h, stride_w], + Dilation=[dilation_h, dilation_w], + inputDataType=idt.name, + outputDataType=odt.name, + depthwise=dw, + dynamic_mode=1, + ) + graph = helper.make_graph( + nodes=[SlidingWindow_node], + name="slidingwindow_graph", + inputs=[inp], + outputs=[outp], + ) + + model = helper.make_model(graph, producer_name="slidingwindow-model") + model = ModelWrapper(model) + + model.set_tensor_datatype("inp", idt) + model.set_tensor_datatype("outp", odt) + + return model + + +def prepare_inputs(input_tensor): + return {"inp": input_tensor} + + +# input datatype +@pytest.mark.parametrize("idt", [DataType["UINT4"]]) +# kernel size +@pytest.mark.parametrize("k", [[3, 3]]) +# 
input dimension +@pytest.mark.parametrize("ifm_dim_series", [[[32, 32], [16, 16], [8, 8]]]) +# input channels +@pytest.mark.parametrize("ifm_ch", [6]) +# Stride +@pytest.mark.parametrize("stride", [[1, 1]]) +# Dilation +@pytest.mark.parametrize("dilation", [[1, 1]]) +# depthwise +@pytest.mark.parametrize("dw", [0, 1]) +# input channel parallelism ("SIMD") +@pytest.mark.parametrize("simd", [2, 6]) +# parallel_window enable (MMV_out = M*K) +@pytest.mark.parametrize("parallel_window", [0]) +# in/out MMV ("M") +@pytest.mark.parametrize("m", [1]) +@pytest.mark.slow +@pytest.mark.vivado +def test_fpgadataflow_slidingwindow_rtl_dynamic( + idt, k, ifm_dim_series, ifm_ch, stride, dilation, dw, simd, m, parallel_window +): + # Begin test by generating RTL SWG normally for the first FM of the series. + # The following FM dimensions must be equal or smaller than the initial + # dimensions (in terms of required buffer depth). + ifm_dim = ifm_dim_series[0] + + k_h, k_w = k + ifm_dim_h, ifm_dim_w = ifm_dim + stride_h, stride_w = stride + dilation_h, dilation_w = dilation + ofm_dim_h = compute_conv_output_dim(ifm_dim_h, k_h, stride_h, 0, dilation_h) + ofm_dim_w = compute_conv_output_dim(ifm_dim_w, k_w, stride_w, 0, dilation_w) + ofm_dim = [ofm_dim_h, ofm_dim_w] + kernel_width = (k_w - 1) * dilation_w + 1 # incl. dilation + kernel_height = (k_h - 1) * dilation_h + 1 # incl. 
dilation + + if simd > ifm_ch: + pytest.skip("SIMD cannot be larger than number of input channels") + if ifm_ch % simd != 0: + pytest.skip("SIMD must divide number of input channels") + if kernel_height > ifm_dim_h or stride_h > ifm_dim_h: + pytest.skip( + "Illegal convolution configuration: kernel or stride > FM dimension" + ) + if kernel_width > ifm_dim_w or stride_w > ifm_dim_w: + pytest.skip( + "Illegal convolution configuration: kernel or stride > FM dimension" + ) + if (k_h == 1 and (stride_h != 1 or dilation_h != 1)) or ( + k_w == 1 and (stride_w != 1 or dilation_w != 1) + ): + pytest.skip( + """Illegal convolution configuration: + stride or dilation defined for unitary kernel dim""" + ) + if k_h == 1 and k_w == 1 and simd != ifm_ch: + pytest.skip("1x1 Kernel only supported in parallel mode (SIMD=C)") + if parallel_window and simd != ifm_ch: + pytest.skip("Parallel window requires SIMD=C") + + model = make_single_slidingwindow_modelwrapper( + k=k, + ifm_ch=ifm_ch, + ifm_dim=ifm_dim, + ofm_dim=ofm_dim, + simd=simd, + m=m, + parallel_window=parallel_window, + stride=stride, + dilation=dilation, + idt=idt, + dw=dw, + ) + + # Simulate using stitched-ip-rtlsim so we can use existing infrastructure + # that supports hook functions to re-program configuration before rtlsim + model = model.transform(InsertFIFO(True)) # required for proper simulation + model = model.transform(GiveUniqueNodeNames()) + model = model.transform(PrepareIP("xc7z020clg400-1", 5)) + model = model.transform(HLSSynthIP()) + model = model.transform(CreateStitchedIP("xc7z020clg400-1", 5)) + model.set_metadata_prop("exec_mode", "rtlsim") + + # Helper function that delivers the hook to program the SWG via AXI-Lite + def config_hook(config): + if config is None: + return None + + def write_swg_config(sim): + axi_name = "s_axi_cfg_0_" + # 1. 
Write config registers to the SWG, dict defines (addr, value) tuples + for config_entry in config.values(): + axilite_write(sim, config_entry[0], config_entry[1], basename=axi_name) + # 2. Set cfg_valid flag (>= 1 cycle) + axilite_write(sim, 0, 1, basename=axi_name) + # 3. Reset component (>= 1 cycle) + reset_rtlsim(sim) + + return write_swg_config + + # Helper function to update tensor dimensions manually because shape inference + # does not work on FINN nodes (they assume well-defined tensor shapes). + def update_tensor_dim(model, tensor_name, new_hw): + shape = model.get_tensor_shape(tensor_name) + shape[1] = new_hw[0] + shape[2] = new_hw[1] + model.set_tensor_shape(tensor_name, shape) + + # Simulate 1 FM for each dimension in the series + for i, ifm_dim in enumerate(ifm_dim_series): + ifm_dim_h, ifm_dim_w = ifm_dim + ofm_dim_h = compute_conv_output_dim(ifm_dim_h, k_h, stride_h, 0, dilation_h) + ofm_dim_w = compute_conv_output_dim(ifm_dim_w, k_w, stride_w, 0, dilation_w) + ofm_dim = [ofm_dim_h, ofm_dim_w] + + config = None + if i > 0: # skip re-programming for initial FM dimension + # Necessary update of node and tensor attributes to make rtlsim work: + swg_node = model.get_nodes_by_op_type("ConvolutionInputGenerator_rtl")[0] + swg_inst = getCustomOp(swg_node) + update_tensor_dim(model, swg_node.input[0], ifm_dim) + update_tensor_dim(model, swg_node.output[0], ofm_dim) + + # Generate config, also overwrites IFMDim/OFMDim attributes: + config = swg_inst.get_dynamic_config(ifm_dim) + + # Also update FIFO nodes and corresponding tensors + fifo_node = model.get_nodes_by_op_type("StreamingFIFO")[0] + fifo_inst = getCustomOp(fifo_node) + shape = fifo_inst.get_nodeattr("folded_shape") + shape[1] = ifm_dim_h + shape[2] = ifm_dim_w + fifo_inst.set_nodeattr("folded_shape", shape) + update_tensor_dim(model, fifo_node.input[0], ifm_dim) + + fifo_node = model.get_nodes_by_op_type("StreamingFIFO")[1] + fifo_inst = getCustomOp(fifo_node) + shape = 
fifo_inst.get_nodeattr("folded_shape") + shape[1] = ofm_dim_h + shape[2] = ofm_dim_w + fifo_inst.set_nodeattr("folded_shape", shape) + update_tensor_dim(model, fifo_node.output[0], ofm_dim) + + # Run rtlsim on stitched-ip + x = gen_finn_dt_tensor(idt, (1, ifm_dim_h, ifm_dim_w, ifm_ch)) + context = prepare_inputs(x) + rtlsim_exec(model, context, pre_hook=config_hook(config)) + y_produced = context["outp"] + + # Generate golden result + golden = make_single_im2col_modelwrapper( + k=k, + ifm_ch=ifm_ch, + ifm_dim=ifm_dim, + ofm_dim=ofm_dim, + stride=stride, + dilation=dilation, + idt=idt, + ) + input_dict = prepare_inputs(x) + y_expected = oxe.execute_onnx(golden, input_dict)["outp"] + + # Check result + if dw == 0: + assert (y_produced == y_expected).all() + else: + y_expected = y_expected.reshape( + 1, ofm_dim_h, ofm_dim_w, k_h * k_w, ifm_ch // simd, simd + ) + y_expected = y_expected.transpose(0, 1, 2, 4, 3, 5) + y_expected = y_expected.reshape(1, ofm_dim_h, ofm_dim_w, ifm_ch * k_h * k_w) + assert (y_produced == y_expected).all()