`timescale 1 ns / 1 ps

// AXI4-Lite slave exposing 16 read/write configuration registers.
//
// Every register can be written (honouring the byte write strobes) and read
// back through the S_AXI interface. Register N is additionally driven out,
// unmodified, on the cfg_regN port. The write path accepts a new address/data
// pair only after the previous write response has been taken (aw_en), so
// there is never more than one write in flight.
module $TOP_MODULE_NAME$_axilite #
(
    // Users to add parameters here

    // User parameters ends
    // Do not modify the parameters beyond this line

    // Width of S_AXI data bus
    parameter integer C_S_AXI_DATA_WIDTH = 32,
    // Width of S_AXI address bus
    parameter integer C_S_AXI_ADDR_WIDTH = 6
)
(
    // Users to add ports here
    output wire [C_S_AXI_DATA_WIDTH-1:0] cfg_reg0,
    output wire [C_S_AXI_DATA_WIDTH-1:0] cfg_reg1,
    output wire [C_S_AXI_DATA_WIDTH-1:0] cfg_reg2,
    output wire [C_S_AXI_DATA_WIDTH-1:0] cfg_reg3,
    output wire [C_S_AXI_DATA_WIDTH-1:0] cfg_reg4,
    output wire [C_S_AXI_DATA_WIDTH-1:0] cfg_reg5,
    output wire [C_S_AXI_DATA_WIDTH-1:0] cfg_reg6,
    output wire [C_S_AXI_DATA_WIDTH-1:0] cfg_reg7,
    output wire [C_S_AXI_DATA_WIDTH-1:0] cfg_reg8,
    output wire [C_S_AXI_DATA_WIDTH-1:0] cfg_reg9,
    output wire [C_S_AXI_DATA_WIDTH-1:0] cfg_reg10,
    output wire [C_S_AXI_DATA_WIDTH-1:0] cfg_reg11,
    output wire [C_S_AXI_DATA_WIDTH-1:0] cfg_reg12,
    output wire [C_S_AXI_DATA_WIDTH-1:0] cfg_reg13,
    output wire [C_S_AXI_DATA_WIDTH-1:0] cfg_reg14,
    output wire [C_S_AXI_DATA_WIDTH-1:0] cfg_reg15,

    // User ports ends
    // Do not modify the ports beyond this line

    // Global Clock Signal
    input wire S_AXI_ACLK,
    // Global Reset Signal. This Signal is Active LOW
    input wire S_AXI_ARESETN,
    // Write address (issued by master, accepted by slave)
    input wire [C_S_AXI_ADDR_WIDTH-1 : 0] S_AXI_AWADDR,
    // Write channel protection type (not used by this slave)
    input wire [2 : 0] S_AXI_AWPROT,
    // Write address valid: master is presenting a valid write address
    input wire S_AXI_AWVALID,
    // Write address ready: slave accepts the write address
    output wire S_AXI_AWREADY,
    // Write data (issued by master, accepted by slave)
    input wire [C_S_AXI_DATA_WIDTH-1 : 0] S_AXI_WDATA,
    // Write strobes: one bit per byte lane of S_AXI_WDATA
    input wire [(C_S_AXI_DATA_WIDTH/8)-1 : 0] S_AXI_WSTRB,
    // Write valid: write data and strobes are valid
    input wire S_AXI_WVALID,
    // Write ready: slave accepts the write data
    output wire S_AXI_WREADY,
    // Write response (always 2'b00, 'OKAY')
    output wire [1 : 0] S_AXI_BRESP,
    // Write response valid
    output wire S_AXI_BVALID,
    // Response ready: master accepts the write response
    input wire S_AXI_BREADY,
    // Read address (issued by master, accepted by slave)
    input wire [C_S_AXI_ADDR_WIDTH-1 : 0] S_AXI_ARADDR,
    // Read channel protection type (not used by this slave)
    input wire [2 : 0] S_AXI_ARPROT,
    // Read address valid: master is presenting a valid read address
    input wire S_AXI_ARVALID,
    // Read address ready: slave accepts the read address
    output wire S_AXI_ARREADY,
    // Read data (issued by slave)
    output wire [C_S_AXI_DATA_WIDTH-1 : 0] S_AXI_RDATA,
    // Read response (always 2'b00, 'OKAY')
    output wire [1 : 0] S_AXI_RRESP,
    // Read valid: read data and response are valid
    output wire S_AXI_RVALID,
    // Read ready: master accepts the read data and response
    input wire S_AXI_RREADY
);

    // Word-address field inside the byte address:
    //   ADDR_LSB = 2 for 32-bit data (n downto 2)
    //   ADDR_LSB = 3 for 64-bit data (n downto 3)
    localparam integer ADDR_LSB          = (C_S_AXI_DATA_WIDTH/32) + 1;
    localparam integer OPT_MEM_ADDR_BITS = 3;
    // The register select is OPT_MEM_ADDR_BITS+1 bits wide -> 16 registers
    localparam integer NUM_REGS          = 1 << (OPT_MEM_ADDR_BITS + 1);
    localparam integer NUM_BYTES         = C_S_AXI_DATA_WIDTH/8;

    // AXI4-Lite channel state
    reg [C_S_AXI_ADDR_WIDTH-1 : 0] axi_awaddr;
    reg                            axi_awready;
    reg                            axi_wready;
    reg [1 : 0]                    axi_bresp;
    reg                            axi_bvalid;
    reg [C_S_AXI_ADDR_WIDTH-1 : 0] axi_araddr;
    reg                            axi_arready;
    reg [C_S_AXI_DATA_WIDTH-1 : 0] axi_rdata;
    reg [1 : 0]                    axi_rresp;
    reg                            axi_rvalid;
    // Cleared when a write is accepted, set again once its response is taken
    reg                            aw_en;

    // Register file; element N backs cfg_regN
    reg [C_S_AXI_DATA_WIDTH-1:0] slv_reg [0:NUM_REGS-1];

    wire slv_reg_rden;
    wire slv_reg_wren;
    // Register selects decoded from the latched write/read addresses
    wire [OPT_MEM_ADDR_BITS:0] wr_sel = axi_awaddr[ADDR_LSB+OPT_MEM_ADDR_BITS:ADDR_LSB];
    wire [OPT_MEM_ADDR_BITS:0] rd_sel = axi_araddr[ADDR_LSB+OPT_MEM_ADDR_BITS:ADDR_LSB];
    // Combinational read mux (a continuous assignment, so no blocking vs.
    // non-blocking question arises)
    wire [C_S_AXI_DATA_WIDTH-1:0] reg_data_out = slv_reg[rd_sel];

    integer byte_index;
    integer reg_index;

    // I/O connections
    assign S_AXI_AWREADY = axi_awready;
    assign S_AXI_WREADY  = axi_wready;
    assign S_AXI_BRESP   = axi_bresp;
    assign S_AXI_BVALID  = axi_bvalid;
    assign S_AXI_ARREADY = axi_arready;
    assign S_AXI_RDATA   = axi_rdata;
    assign S_AXI_RRESP   = axi_rresp;
    assign S_AXI_RVALID  = axi_rvalid;

    // axi_awready generation.
    // Asserted for one clock cycle when both S_AXI_AWVALID and S_AXI_WVALID
    // are asserted and no earlier write is still awaiting its response.
    always @( posedge S_AXI_ACLK )
    begin
        if ( S_AXI_ARESETN == 1'b0 )
        begin
            axi_awready <= 1'b0;
            aw_en       <= 1'b1;
        end
        else
        begin
            if (~axi_awready && S_AXI_AWVALID && S_AXI_WVALID && aw_en)
            begin
                // Address and data are both on the bus: accept the address
                axi_awready <= 1'b1;
                aw_en       <= 1'b0;
            end
            else if (S_AXI_BREADY && axi_bvalid)
            begin
                // Response handed over: allow the next write
                aw_en       <= 1'b1;
                axi_awready <= 1'b0;
            end
            else
            begin
                axi_awready <= 1'b0;
            end
        end
    end

    // axi_awaddr latching.
    // The write address is captured in the cycle in which axi_awready is
    // about to be asserted.
    always @( posedge S_AXI_ACLK )
    begin
        if ( S_AXI_ARESETN == 1'b0 )
        begin
            axi_awaddr <= 0;
        end
        else
        begin
            if (~axi_awready && S_AXI_AWVALID && S_AXI_WVALID && aw_en)
            begin
                axi_awaddr <= S_AXI_AWADDR;
            end
        end
    end

    // axi_wready generation.
    // Asserted for one clock cycle when both S_AXI_AWVALID and S_AXI_WVALID
    // are asserted and a new write may be accepted.
    always @( posedge S_AXI_ACLK )
    begin
        if ( S_AXI_ARESETN == 1'b0 )
        begin
            axi_wready <= 1'b0;
        end
        else
        begin
            if (~axi_wready && S_AXI_WVALID && S_AXI_AWVALID && aw_en )
            begin
                axi_wready <= 1'b1;
            end
            else
            begin
                axi_wready <= 1'b0;
            end
        end
    end

    // Register write logic.
    // A write is committed when axi_awready, S_AXI_AWVALID, axi_wready and
    // S_AXI_WVALID are all asserted. Only byte lanes whose write strobe is
    // set are updated. All registers are cleared by the active-low reset.
    assign slv_reg_wren = axi_wready && S_AXI_WVALID && axi_awready && S_AXI_AWVALID;

    always @( posedge S_AXI_ACLK )
    begin
        if ( S_AXI_ARESETN == 1'b0 )
        begin
            for ( reg_index = 0; reg_index < NUM_REGS; reg_index = reg_index+1 )
                slv_reg[reg_index] <= 0;
        end
        else if (slv_reg_wren)
        begin
            for ( byte_index = 0; byte_index < NUM_BYTES; byte_index = byte_index+1 )
                if ( S_AXI_WSTRB[byte_index] == 1 )
                    slv_reg[wr_sel][(byte_index*8) +: 8] <= S_AXI_WDATA[(byte_index*8) +: 8];
        end
    end

    // Write response generation.
    // axi_bvalid is raised in the cycle after the write was committed and
    // stays high until the master acknowledges it with S_AXI_BREADY.
    always @( posedge S_AXI_ACLK )
    begin
        if ( S_AXI_ARESETN == 1'b0 )
        begin
            axi_bvalid <= 0;
            axi_bresp  <= 2'b0;
        end
        else
        begin
            if (axi_awready && S_AXI_AWVALID && ~axi_bvalid && axi_wready && S_AXI_WVALID)
            begin
                axi_bvalid <= 1'b1;
                axi_bresp  <= 2'b0; // 'OKAY' response
            end
            else
            begin
                // S_AXI_BREADY may be held high permanently by the master
                if (S_AXI_BREADY && axi_bvalid)
                begin
                    axi_bvalid <= 1'b0;
                end
            end
        end
    end

    // axi_arready generation and read address latching.
    // axi_arready is asserted for one clock cycle when S_AXI_ARVALID is
    // asserted; the read address is latched in the same cycle.
    always @( posedge S_AXI_ACLK )
    begin
        if ( S_AXI_ARESETN == 1'b0 )
        begin
            axi_arready <= 1'b0;
            axi_araddr  <= 0; // unsized literal: correct for any C_S_AXI_ADDR_WIDTH
        end
        else
        begin
            if (~axi_arready && S_AXI_ARVALID)
            begin
                axi_arready <= 1'b1;
                axi_araddr  <= S_AXI_ARADDR;
            end
            else
            begin
                axi_arready <= 1'b0;
            end
        end
    end

    // axi_rvalid generation.
    // Raised once the read address has been accepted and held until the
    // master takes the data with S_AXI_RREADY.
    always @( posedge S_AXI_ACLK )
    begin
        if ( S_AXI_ARESETN == 1'b0 )
        begin
            axi_rvalid <= 0;
            axi_rresp  <= 0;
        end
        else
        begin
            if (axi_arready && S_AXI_ARVALID && ~axi_rvalid)
            begin
                axi_rvalid <= 1'b1;
                axi_rresp  <= 2'b0; // 'OKAY' response
            end
            else if (axi_rvalid && S_AXI_RREADY)
            begin
                axi_rvalid <= 1'b0;
            end
        end
    end

    // Read data register.
    // The selected register is captured into axi_rdata in the cycle in which
    // the read address handshake completes.
    assign slv_reg_rden = axi_arready & S_AXI_ARVALID & ~axi_rvalid;

    always @( posedge S_AXI_ACLK )
    begin
        if ( S_AXI_ARESETN == 1'b0 )
        begin
            axi_rdata <= 0;
        end
        else
        begin
            if (slv_reg_rden)
            begin
                axi_rdata <= reg_data_out;
            end
        end
    end

    // Add user logic here
    assign cfg_reg0  = slv_reg[0];
    assign cfg_reg1  = slv_reg[1];
    assign cfg_reg2  = slv_reg[2];
    assign cfg_reg3  = slv_reg[3];
    assign cfg_reg4  = slv_reg[4];
    assign cfg_reg5  = slv_reg[5];
    assign cfg_reg6  = slv_reg[6];
    assign cfg_reg7  = slv_reg[7];
    assign cfg_reg8  = slv_reg[8];
    assign cfg_reg9  = slv_reg[9];
    assign cfg_reg10 = slv_reg[10];
    assign cfg_reg11 = slv_reg[11];
    assign cfg_reg12 = slv_reg[12];
    assign cfg_reg13 = slv_reg[13];
    assign cfg_reg14 = slv_reg[14];
    assign cfg_reg15 = slv_reg[15];
    // User logic ends

endmodule
b/finn-rtllib/swg/swg_template_default.sv @@ -36,7 +36,6 @@ module $TOP_MODULE_NAME$_controller #( int unsigned LOOP_SIMD_ITERATIONS = $LOOP_SIMD_ITERATIONS$, int unsigned INCR_BITWIDTH = $INCR_BITWIDTH$, - bit [INCR_BITWIDTH-1:0] ADDR_INCREMENT_MAP[6] = $ADDR_INCREMENT_MAP$, bit IS_DEPTHWISE = $IS_DEPTHWISE$ )( @@ -60,26 +59,31 @@ module $TOP_MODULE_NAME$_controller #( state_e State = $INNERMOST_STATE$; state_e state_next; - logic signed [$clog2(LOOP_H_ITERATIONS +2)+1-1:0] Counter_loop_h = LOOP_H_ITERATIONS-1; - logic signed [$clog2(LOOP_W_ITERATIONS +2)+1-1:0] Counter_loop_w = LOOP_W_ITERATIONS-1; - logic signed [$clog2(LOOP_KH_ITERATIONS +2)+1-1:0] Counter_loop_kh = LOOP_KH_ITERATIONS-1; - logic signed [$clog2(LOOP_KW_ITERATIONS +2)+1-1:0] Counter_loop_kw = LOOP_KW_ITERATIONS-1; - logic signed [$clog2(LOOP_SIMD_ITERATIONS+2)+1-1:0] Counter_loop_simd = LOOP_SIMD_ITERATIONS-1; - - assign addr_incr = ADDR_INCREMENT_MAP[State]; + logic signed [$clog2(LOOP_H_ITERATIONS +2)+1-1:0] Counter_loop_h = LOOP_H_ITERATIONS; + logic signed [$clog2(LOOP_W_ITERATIONS +2)+1-1:0] Counter_loop_w = LOOP_W_ITERATIONS; + logic signed [$clog2(LOOP_KH_ITERATIONS +2)+1-1:0] Counter_loop_kh = LOOP_KH_ITERATIONS; + logic signed [$clog2(LOOP_KW_ITERATIONS +2)+1-1:0] Counter_loop_kw = LOOP_KW_ITERATIONS; + logic signed [$clog2(LOOP_SIMD_ITERATIONS+2)+1-1:0] Counter_loop_simd = LOOP_SIMD_ITERATIONS; + + // combinational logic for addr_incr generation + always_comb begin : blkHead + unique case (State) + 0 : addr_incr = 0; + 1 : addr_incr = $HEAD_INCR_SIMD$; + 2 : addr_incr = $HEAD_INCR_KW$; + 3 : addr_incr = $HEAD_INCR_KH$; + 4 : addr_incr = $HEAD_INCR_W$; + 5 : addr_incr = $HEAD_INCR_H$; + endcase + end // combinational logic for tail_incr generation uwire tail_incr_inner_condition = IS_DEPTHWISE? 
(Counter_loop_kh >= 0) : 0; - always_comb begin : blkTail - if (tail_incr_inner_condition) - tail_incr = 1; - else if (Counter_loop_w >= 0) - tail_incr = $TAIL_INCR_W$; - else if (Counter_loop_h >= 0) - tail_incr = $TAIL_INCR_H$; - else - tail_incr = $TAIL_INCR_LAST$; - end + assign tail_incr = + tail_incr_inner_condition? 1 : + Counter_loop_w >= 0? $TAIL_INCR_W$ : + Counter_loop_h >= 0? $TAIL_INCR_H$ : + /* else */ $TAIL_INCR_LAST$; // combinational next state logic always_comb begin : blkState @@ -101,29 +105,29 @@ module $TOP_MODULE_NAME$_controller #( always_ff @ (posedge clk) begin if(!rst_n) begin State <= $INNERMOST_STATE$; - Counter_loop_h <= LOOP_H_ITERATIONS-1; - Counter_loop_w <= LOOP_W_ITERATIONS-1; - Counter_loop_kh <= LOOP_KH_ITERATIONS-1; - Counter_loop_kw <= LOOP_KW_ITERATIONS-1; - Counter_loop_simd <= LOOP_SIMD_ITERATIONS-1; + Counter_loop_h <= LOOP_H_ITERATIONS; + Counter_loop_w <= LOOP_W_ITERATIONS; + Counter_loop_kh <= LOOP_KH_ITERATIONS; + Counter_loop_kw <= LOOP_KW_ITERATIONS; + Counter_loop_simd <= LOOP_SIMD_ITERATIONS; end else if(advance) begin State <= state_next; if (State == $INNERMOST_STATE$) begin if(Counter_loop_simd >= 0) Counter_loop_simd <= Counter_loop_simd-1; else begin - Counter_loop_simd <= LOOP_SIMD_ITERATIONS-1; + Counter_loop_simd <= LOOP_SIMD_ITERATIONS; if(Counter_loop_kw >= 0) Counter_loop_kw <= Counter_loop_kw-1; else begin - Counter_loop_kw <= LOOP_KW_ITERATIONS-1; + Counter_loop_kw <= LOOP_KW_ITERATIONS; if(Counter_loop_kh >= 0) Counter_loop_kh <= Counter_loop_kh-1; else begin - Counter_loop_kh <= LOOP_KH_ITERATIONS-1; + Counter_loop_kh <= LOOP_KH_ITERATIONS; if(Counter_loop_w >= 0) Counter_loop_w <= Counter_loop_w-1; else begin - Counter_loop_w <= LOOP_W_ITERATIONS-1; + Counter_loop_w <= LOOP_W_ITERATIONS; if(Counter_loop_h >= 0) Counter_loop_h <= Counter_loop_h-1; - else Counter_loop_h <= LOOP_H_ITERATIONS-1; + else Counter_loop_h <= LOOP_H_ITERATIONS; end end end @@ -139,7 +143,6 @@ module 
$TOP_MODULE_NAME$_cyclic_buffer_addressable #( int unsigned DEPTH )( input logic clk, - input logic rst_n, input logic write_enable, input logic [$clog2(DEPTH)-1:0] write_addr, @@ -182,7 +185,7 @@ module $TOP_MODULE_NAME$_impl #( input logic out_V_V_TREADY, output logic [BIT_WIDTH * SIMD * MMV_OUT-1:0] out_V_V_TDATA ); - // derived Constants + // derived constants localparam int unsigned BUF_IN_WIDTH = BIT_WIDTH * SIMD * MMV_IN; localparam int unsigned BUF_OUT_ELEM_WIDTH = BIT_WIDTH * SIMD; localparam int unsigned BUF_OUT_WIDTH = BIT_WIDTH * SIMD * MMV_OUT; @@ -199,7 +202,6 @@ module $TOP_MODULE_NAME$_impl #( .DEPTH(BUF_ELEM_TOTAL) ) window_buffer_inst ( .clk(ap_clk), - .rst_n(ap_rst_n), .write_enable(window_buffer_write_enable), .write_addr(window_buffer_write_addr), @@ -234,6 +236,15 @@ module $TOP_MODULE_NAME$_impl #( logic [$clog2(BUF_ELEM_TOTAL)-1:0] Window_buffer_write_addr_reg = 0; // Control signals/registers + logic Write_cmd = 0; + logic Writing_done = 0; + uwire write_ok = Write_cmd && out_V_V_TREADY; + uwire write_blocked = Write_cmd && !out_V_V_TREADY; + + logic Fetching_done = 0; + uwire fetch_cmd = !($signed(Current_elem) > Newest_buffered_elem) && !write_blocked && !Fetching_done; + + uwire reading_done = Newest_buffered_elem == LAST_READ_ELEM; uwire read_cmd = !reading_done && ( // if there is still an input element left to read Fetching_done || ( // if fetching is done (e.g. 
// Loop-nest controller with run-time configurable bounds and increments.
//
// Walks a five-level loop nest (SIMD, KW, KH, W, H). On every `advance` it
// outputs the address increment belonging to the loop level that was just
// completed (`addr_incr`) and the matching tail increment (`tail_incr`).
// Loop bounds and increments are held in configuration registers that start
// out with the template values and are overwritten whenever `cfg_valid` is
// high.
module $TOP_MODULE_NAME$_controller #(
    int unsigned  CNTR_BITWIDTH,
    int unsigned  INCR_BITWIDTH,

    bit IS_DEPTHWISE = $IS_DEPTHWISE$
)(
    input   logic  clk,
    input   logic  rst_n,

    input   logic  advance,
    output  logic [INCR_BITWIDTH-1:0]  addr_incr,
    output  logic [INCR_BITWIDTH-1:0]  tail_incr,

    input   logic                      cfg_valid,
    input   logic [CNTR_BITWIDTH-1:0]  cfg_cntr_simd,
    input   logic [CNTR_BITWIDTH-1:0]  cfg_cntr_kw,
    input   logic [CNTR_BITWIDTH-1:0]  cfg_cntr_kh,
    input   logic [CNTR_BITWIDTH-1:0]  cfg_cntr_w,
    input   logic [CNTR_BITWIDTH-1:0]  cfg_cntr_h,
    input   logic [INCR_BITWIDTH-1:0]  cfg_incr_head_simd,
    input   logic [INCR_BITWIDTH-1:0]  cfg_incr_head_kw,
    input   logic [INCR_BITWIDTH-1:0]  cfg_incr_head_kh,
    input   logic [INCR_BITWIDTH-1:0]  cfg_incr_head_w,
    input   logic [INCR_BITWIDTH-1:0]  cfg_incr_head_h,
    input   logic [INCR_BITWIDTH-1:0]  cfg_incr_tail_w,
    input   logic [INCR_BITWIDTH-1:0]  cfg_incr_tail_h,
    input   logic [INCR_BITWIDTH-1:0]  cfg_incr_tail_last
);

    // (dynamic) configuration registers
    // The counter reload values are signed: the loop counters below are
    // signed and terminate by going negative, so a reload value must be
    // sign-extended (not zero-extended) if a counter is wider than
    // CNTR_BITWIDTH.
    logic signed [CNTR_BITWIDTH-1:0]  Cfg_cntr_simd      = $LOOP_SIMD_ITERATIONS$;
    logic signed [CNTR_BITWIDTH-1:0]  Cfg_cntr_kw        = $LOOP_KW_ITERATIONS$;
    logic signed [CNTR_BITWIDTH-1:0]  Cfg_cntr_kh        = $LOOP_KH_ITERATIONS$;
    logic signed [CNTR_BITWIDTH-1:0]  Cfg_cntr_w         = $LOOP_W_ITERATIONS$;
    logic signed [CNTR_BITWIDTH-1:0]  Cfg_cntr_h         = $LOOP_H_ITERATIONS$;
    logic [INCR_BITWIDTH-1:0]  Cfg_incr_head_simd = $HEAD_INCR_SIMD$;
    logic [INCR_BITWIDTH-1:0]  Cfg_incr_head_kw   = $HEAD_INCR_KW$;
    logic [INCR_BITWIDTH-1:0]  Cfg_incr_head_kh   = $HEAD_INCR_KH$;
    logic [INCR_BITWIDTH-1:0]  Cfg_incr_head_w    = $HEAD_INCR_W$;
    logic [INCR_BITWIDTH-1:0]  Cfg_incr_head_h    = $HEAD_INCR_H$;
    logic [INCR_BITWIDTH-1:0]  Cfg_incr_tail_w    = $TAIL_INCR_W$;
    logic [INCR_BITWIDTH-1:0]  Cfg_incr_tail_h    = $TAIL_INCR_H$;
    logic [INCR_BITWIDTH-1:0]  Cfg_incr_tail_last = $TAIL_INCR_LAST$;

    // configuration reset/set logic
    // Note: not affected by rst_n; the registers keep their last value.
    always_ff @ (posedge clk) begin
        if(cfg_valid) begin
            Cfg_cntr_simd      <= cfg_cntr_simd;
            Cfg_cntr_kw        <= cfg_cntr_kw;
            Cfg_cntr_kh        <= cfg_cntr_kh;
            Cfg_cntr_w         <= cfg_cntr_w;
            Cfg_cntr_h         <= cfg_cntr_h;
            Cfg_incr_head_simd <= cfg_incr_head_simd;
            Cfg_incr_head_kw   <= cfg_incr_head_kw;
            Cfg_incr_head_kh   <= cfg_incr_head_kh;
            Cfg_incr_head_w    <= cfg_incr_head_w;
            Cfg_incr_head_h    <= cfg_incr_head_h;
            Cfg_incr_tail_w    <= cfg_incr_tail_w;
            Cfg_incr_tail_h    <= cfg_incr_tail_h;
            Cfg_incr_tail_last <= cfg_incr_tail_last;
        end
    end

    // state and counters
    typedef enum logic [2:0] {
        STATE_START,
        STATE_LOOP_SIMD,
        STATE_LOOP_KW,
        STATE_LOOP_KH,
        STATE_LOOP_W,
        STATE_LOOP_H
    }  state_e;
    state_e  State = $INNERMOST_STATE$;
    state_e  state_next;

    // Counter widths are derived from the template (initial) iteration counts.
    logic signed [$clog2($LOOP_H_ITERATIONS$   +2)+1-1:0]  Counter_loop_h    = $LOOP_H_ITERATIONS$;
    logic signed [$clog2($LOOP_W_ITERATIONS$   +2)+1-1:0]  Counter_loop_w    = $LOOP_W_ITERATIONS$;
    logic signed [$clog2($LOOP_KH_ITERATIONS$  +2)+1-1:0]  Counter_loop_kh   = $LOOP_KH_ITERATIONS$;
    logic signed [$clog2($LOOP_KW_ITERATIONS$  +2)+1-1:0]  Counter_loop_kw   = $LOOP_KW_ITERATIONS$;
    logic signed [$clog2($LOOP_SIMD_ITERATIONS$+2)+1-1:0]  Counter_loop_simd = $LOOP_SIMD_ITERATIONS$;

    // combinational logic for addr_incr generation
    always_comb begin : blkHead
        unique case (State)
            STATE_START     : addr_incr = 0;
            STATE_LOOP_SIMD : addr_incr = Cfg_incr_head_simd;
            STATE_LOOP_KW   : addr_incr = Cfg_incr_head_kw;
            STATE_LOOP_KH   : addr_incr = Cfg_incr_head_kh;
            STATE_LOOP_W    : addr_incr = Cfg_incr_head_w;
            STATE_LOOP_H    : addr_incr = Cfg_incr_head_h;
            // the two unused 3-bit encodings: keep the output fully assigned
            default         : addr_incr = 0;
        endcase
    end : blkHead

    // combinational logic for tail_incr generation
    uwire  tail_incr_inner_condition = IS_DEPTHWISE? (Counter_loop_kh >= 0) : 0;
    assign tail_incr =
        tail_incr_inner_condition? 1 :
        Counter_loop_w >= 0?       Cfg_incr_tail_w :
        Counter_loop_h >= 0?       Cfg_incr_tail_h :
        /* else */                 Cfg_incr_tail_last;

    // combinational next state logic
    // Any outer-loop state falls straight back to the innermost state. From
    // the innermost state, once its counter has gone negative, move to the
    // innermost loop level that still has iterations left.
    always_comb begin : blkState
        state_next = State;
        if(State != $INNERMOST_STATE$)  state_next = $INNERMOST_STATE$;
        else begin
            if(Counter_loop_simd < 0) begin
                state_next =
                    (Counter_loop_kw >= 0)? STATE_LOOP_KW :
                    (Counter_loop_kh >= 0)? STATE_LOOP_KH :
                    (Counter_loop_w  >= 0)? STATE_LOOP_W :
                    (Counter_loop_h  >= 0)? STATE_LOOP_H :
                    /* else */              STATE_START;
            end
        end
    end : blkState

    // sequential logic
    // Counters count down and are reloaded from the configuration registers
    // after they have gone negative; a reload cascades into the next outer
    // counter.
    always_ff @ (posedge clk) begin
        if(!rst_n) begin
            State <= $INNERMOST_STATE$;
            Counter_loop_h    <= Cfg_cntr_h;
            Counter_loop_w    <= Cfg_cntr_w;
            Counter_loop_kh   <= Cfg_cntr_kh;
            Counter_loop_kw   <= Cfg_cntr_kw;
            Counter_loop_simd <= Cfg_cntr_simd;
        end
        else if(advance) begin
            State <= state_next;
            if (State == $INNERMOST_STATE$) begin
                if(Counter_loop_simd >= 0)  Counter_loop_simd <= Counter_loop_simd-1;
                else begin
                    Counter_loop_simd <= Cfg_cntr_simd;
                    if(Counter_loop_kw >= 0)  Counter_loop_kw <= Counter_loop_kw-1;
                    else begin
                        Counter_loop_kw <= Cfg_cntr_kw;
                        if(Counter_loop_kh >= 0)  Counter_loop_kh <= Counter_loop_kh-1;
                        else begin
                            Counter_loop_kh <= Cfg_cntr_kh;
                            if(Counter_loop_w >= 0)  Counter_loop_w <= Counter_loop_w-1;
                            else begin
                                Counter_loop_w <= Cfg_cntr_w;
                                if(Counter_loop_h >= 0)  Counter_loop_h <= Counter_loop_h-1;
                                else  Counter_loop_h <= Cfg_cntr_h;
                            end
                        end
                    end
                end
            end
        end
    end

endmodule :  $TOP_MODULE_NAME$_controller

// Simple dual-port memory with one synchronous write port and one
// synchronous, registered read port. Both ports take absolute addresses.
// A read of an address that is written in the same cycle returns the
// previously stored word, since the read samples Ram before the
// non-blocking write takes effect.
module $TOP_MODULE_NAME$_cyclic_buffer_addressable #(
    int unsigned  WIDTH,
    int unsigned  DEPTH
)(
    input   logic  clk,

    input   logic  write_enable,
    input   logic [$clog2(DEPTH)-1:0] write_addr,
    input   logic [WIDTH-1:0]  data_in,

    input   logic  read_enable,
    input   logic [$clog2(DEPTH)-1:0]  read_addr, // absolute (!) read address of cyclic buffer
    output  logic [WIDTH-1:0]  data_out
);

    $RAM_STYLE$ logic [WIDTH-1:0]  Ram[DEPTH];
    // Output register; holds its value while read_enable is low
    logic [WIDTH-1:0]  Out = 'x;
    always_ff @(posedge clk) begin
        if (read_enable)  Out <= Ram[read_addr];
        if (write_enable) Ram[write_addr] <= data_in;
    end
    assign  data_out = Out;

endmodule : $TOP_MODULE_NAME$_cyclic_buffer_addressable

module $TOP_MODULE_NAME$_impl #(
    int  BIT_WIDTH,
    int  SIMD,
    int  MMV_IN,
    int  MMV_OUT,
    int unsigned  CNTR_BITWIDTH,
    int unsigned  INCR_BITWIDTH,

    int  LAST_READ_ELEM = $LAST_READ_ELEM$,
    int  LAST_WRITE_ELEM = $LAST_WRITE_ELEM$,
    int  BUF_ELEM_TOTAL = $BUF_ELEM_TOTAL$,
    int  ELEM_PER_WINDOW = $ELEM_PER_WINDOW$
)(
    input   logic  ap_clk,
    input   logic  ap_rst_n,

    input   logic  in0_V_V_TVALID,
    output  logic  in0_V_V_TREADY,
    input   logic [BIT_WIDTH * SIMD * MMV_IN-1:0]  in0_V_V_TDATA,

    output  logic  out_V_V_TVALID,
    input   logic  out_V_V_TREADY,
    output  logic [BIT_WIDTH * SIMD * MMV_OUT-1:0]  out_V_V_TDATA,

    input   logic                      cfg_valid,
    input   logic [CNTR_BITWIDTH-1:0]  cfg_cntr_simd,
    input   logic [CNTR_BITWIDTH-1:0]  cfg_cntr_kw,
    input   logic [CNTR_BITWIDTH-1:0]  cfg_cntr_kh,
    input   logic [CNTR_BITWIDTH-1:0]  cfg_cntr_w,
    input   logic [CNTR_BITWIDTH-1:0]  cfg_cntr_h,
    input   logic [INCR_BITWIDTH-1:0]  cfg_incr_head_simd,
    input   logic [INCR_BITWIDTH-1:0]  cfg_incr_head_kw,
    input   logic [INCR_BITWIDTH-1:0]  cfg_incr_head_kh,
    input   logic [INCR_BITWIDTH-1:0]  cfg_incr_head_w,
    input   logic [INCR_BITWIDTH-1:0]  cfg_incr_head_h,
    input   logic [INCR_BITWIDTH-1:0]  cfg_incr_tail_w,
    input   logic [INCR_BITWIDTH-1:0]  cfg_incr_tail_h,
    input   logic [INCR_BITWIDTH-1:0]  cfg_incr_tail_last,
    input   logic [31:0]               cfg_last_read,
    input   logic [31:0]               cfg_last_write
);
    // derived constants
    localparam int unsigned  BUF_IN_WIDTH = BIT_WIDTH * SIMD * MMV_IN;
    localparam int unsigned  BUF_OUT_ELEM_WIDTH = BIT_WIDTH * SIMD;
    localparam int unsigned  BUF_OUT_WIDTH = BIT_WIDTH * SIMD * MMV_OUT;

    // (dynamic) configuration
registers + logic [31:0] Cfg_last_read = LAST_READ_ELEM; + logic [31:0] Cfg_last_write = LAST_WRITE_ELEM; + + // configuration reset/set logic + always_ff @ (posedge ap_clk) begin + if(cfg_valid) begin + Cfg_last_read <= cfg_last_read; + Cfg_last_write <= cfg_last_write; + end + end + + // main buffer instantiation + uwire [BUF_IN_WIDTH -1:0] window_buffer_in; + uwire [BUF_OUT_WIDTH-1:0] window_buffer_out; + uwire window_buffer_write_enable; + uwire window_buffer_read_enable; + uwire [$clog2(BUF_ELEM_TOTAL)-1:0] window_buffer_write_addr; + uwire [$clog2(BUF_ELEM_TOTAL)-1:0] window_buffer_read_addr; + $TOP_MODULE_NAME$_cyclic_buffer_addressable #( + .WIDTH(BUF_IN_WIDTH), + .DEPTH(BUF_ELEM_TOTAL) + ) window_buffer_inst ( + .clk(ap_clk), + + .write_enable(window_buffer_write_enable), + .write_addr(window_buffer_write_addr), + .data_in(window_buffer_in), + + .read_enable(window_buffer_read_enable), + .read_addr(window_buffer_read_addr), + .data_out(window_buffer_out) + ); + + //controller instantiation + uwire advance_controller; + uwire signed [INCR_BITWIDTH-1:0] addr_incr; + uwire [INCR_BITWIDTH-1:0] tail_incr; + $TOP_MODULE_NAME$_controller #( + .CNTR_BITWIDTH(CNTR_BITWIDTH), + .INCR_BITWIDTH(INCR_BITWIDTH) + ) controller_inst ( + .clk(ap_clk), + .rst_n(ap_rst_n), + .advance(advance_controller), + .addr_incr(addr_incr), + .tail_incr(tail_incr), + + .cfg_valid(cfg_valid), + .cfg_cntr_simd(cfg_cntr_simd), + .cfg_cntr_kw(cfg_cntr_kw), + .cfg_cntr_kh(cfg_cntr_kh), + .cfg_cntr_w(cfg_cntr_w), + .cfg_cntr_h(cfg_cntr_h), + .cfg_incr_head_simd(cfg_incr_head_simd), + .cfg_incr_head_kw(cfg_incr_head_kw), + .cfg_incr_head_kh(cfg_incr_head_kh), + .cfg_incr_head_w(cfg_incr_head_w), + .cfg_incr_head_h(cfg_incr_head_h), + .cfg_incr_tail_w(cfg_incr_tail_w), + .cfg_incr_tail_h(cfg_incr_tail_h), + .cfg_incr_tail_last(cfg_incr_tail_last) + ); + + // Counters/address registers + // Add a sign bit even to (most) unsigned counters and Window_buffer_read_addr_reg, + // so we can use 
automatic sign extension and simplify calculations w/ signed increment. + // Alternatively, we could manually sign-extend and shave off a bit here or there. + logic signed [$clog2(LAST_READ_ELEM+1)+1-1:0] Newest_buffered_elem = -1; + logic [$clog2(LAST_READ_ELEM+1)+1-1:0] Current_elem = 0; + logic [$clog2(LAST_READ_ELEM+1)+1-1:0] First_elem_next_window = 0; + logic [$clog2(ELEM_PER_WINDOW) -1:0] Position_in_window = 0; + logic [$clog2(BUF_ELEM_TOTAL)+1 -1:0] Window_buffer_read_addr_reg = 0; + logic [$clog2(BUF_ELEM_TOTAL)-1:0] Window_buffer_write_addr_reg = 0; + + // Control signals/registers + logic Write_cmd = 0; + logic Writing_done = 0; + uwire write_ok = Write_cmd && out_V_V_TREADY; + uwire write_blocked = Write_cmd && !out_V_V_TREADY; + + logic Fetching_done = 0; + uwire fetch_cmd = !($signed(Current_elem) > Newest_buffered_elem) && !write_blocked && !Fetching_done; + + uwire reading_done = Newest_buffered_elem == Cfg_last_read; + uwire read_cmd = + !reading_done && ( // if there is still an input element left to read + Fetching_done || ( // if fetching is done (e.g. 
for skipped rows at FM end due to stride) + $signed(((Newest_buffered_elem - (BUF_ELEM_TOTAL - 1)))) < $signed(First_elem_next_window) && + $signed(((Newest_buffered_elem - (BUF_ELEM_TOTAL - 1)))) < $signed(Current_elem) + ) // (over-)write to buffer if oldest buffered element will no longer be needed + ); + uwire read_ok = read_cmd && in0_V_V_TVALID; + + //assign buffer control + assign window_buffer_write_addr = Window_buffer_write_addr_reg; + assign window_buffer_read_addr = Window_buffer_read_addr_reg; + assign window_buffer_write_enable = read_ok; + assign window_buffer_read_enable = fetch_cmd; + assign advance_controller = fetch_cmd; + + //assign I/O ports + assign window_buffer_in = in0_V_V_TDATA; + assign out_V_V_TDATA = window_buffer_out; + assign in0_V_V_TREADY = ap_rst_n && read_ok; //only asserted if data is available and we can store it (allowed) + assign out_V_V_TVALID = ap_rst_n && Write_cmd; //only asserted if we have data available and it has not been read yet (don't wait for READY from sink) + + //main process for advancing counters + always_ff @(posedge ap_clk) begin + if(!ap_rst_n) begin + Newest_buffered_elem <= -1; + Current_elem <= 0; + First_elem_next_window <= 0; + Position_in_window <= 0; + Window_buffer_read_addr_reg <= 0; + Window_buffer_write_addr_reg <= 0; + Fetching_done <= 0; + Write_cmd <= 0; + Writing_done <= 0; + end + else begin + if (read_ok) begin + Window_buffer_write_addr_reg <= (Window_buffer_write_addr_reg == BUF_ELEM_TOTAL-1)? 
0 : Window_buffer_write_addr_reg + 1; + Newest_buffered_elem <= Newest_buffered_elem+1; + + if (Newest_buffered_elem == Cfg_last_read-1) begin + Window_buffer_write_addr_reg <= 0; + end + //check if this is the last read cycle (reading_done will be true afterwards) + if ((Newest_buffered_elem == Cfg_last_read-1) && Writing_done) begin + //start processing of next FM if writing is done already (possible due to unused input elements at the tail end) + //todo: allow for read overlapping between feature maps (i.e., reading first elements from next FM while still writing last window of current FM) + Newest_buffered_elem <= -1; + Current_elem <= 0; + Window_buffer_read_addr_reg <= 0; + First_elem_next_window <= 0; + Writing_done <= 0; + Fetching_done <= 0; + end + end + + if (fetch_cmd) begin + //count up to track which element index is about to be read from the buffer, and where it is located within the buffer + //use increment value calculated by controller + + // absolute buffer address wrap-around + automatic logic signed [$clog2(BUF_ELEM_TOTAL)+1:0] ra = $signed(Window_buffer_read_addr_reg) + $signed(addr_incr); + automatic logic signed [$clog2(BUF_ELEM_TOTAL+1):0] ra_correct = + (ra >= BUF_ELEM_TOTAL)? -BUF_ELEM_TOTAL : + (ra < 0)? BUF_ELEM_TOTAL : 0; + Window_buffer_read_addr_reg <= ra + ra_correct; + + //keep track where we are within a window + Position_in_window <= (Position_in_window != ELEM_PER_WINDOW - 1)? 
Position_in_window+1 : 0; + + //update first element of next window to allow buffer overwrite up until that point + if (Position_in_window == 0) + First_elem_next_window <= First_elem_next_window + tail_incr; + + //check if the element fetched in this cycle is the last one to be written out (Fetching_done will be true afterwards) + if (Current_elem == Cfg_last_write) + Fetching_done <= 1; + else + Current_elem <= $signed(Current_elem) + addr_incr; + + // determine if prefetched data will be outstanding in the next cycle + // if we fetch in this cycle -> yes + // if we do not fetch nor write -> do not change + // if we do not fetch but write successfully -> clear outstanding data + Write_cmd <= fetch_cmd; + end + + if (write_ok) + Write_cmd <= fetch_cmd; + + if (write_ok && Fetching_done) begin + //check if this is the last write cycle (Writing_done will be true afterwards) + if (reading_done || (read_ok && (Newest_buffered_elem == Cfg_last_read - 1))) begin + //start processing of next FM if reading is done already, or completes in the same cycle + Newest_buffered_elem <= -1; + Current_elem <= 0; + Window_buffer_read_addr_reg <= 0; + First_elem_next_window <= 0; + Fetching_done <= 0; + end else + Writing_done <= 1; + end + end + end + +endmodule : $TOP_MODULE_NAME$_impl diff --git a/finn-rtllib/swg/swg_template_wrapper_dynamic.v b/finn-rtllib/swg/swg_template_wrapper_dynamic.v new file mode 100644 index 0000000000000000000000000000000000000000..ca870ace11edcf097645bc12b0486ffbb83b0ea4 --- /dev/null +++ b/finn-rtllib/swg/swg_template_wrapper_dynamic.v @@ -0,0 +1,154 @@ +`timescale 1 ns / 1 ps + +module $TOP_MODULE_NAME$ #( + // top-level parameters (set via code-generation) + parameter BIT_WIDTH = $BIT_WIDTH$, + parameter SIMD = $SIMD$, + parameter MMV_IN = $MMV_IN$, + parameter MMV_OUT = $MMV_OUT$, + + parameter CNTR_BITWIDTH = $CNTR_BITWIDTH$, + parameter INCR_BITWIDTH = $INCR_BITWIDTH$, + + // derived constants + parameter BUF_IN_WIDTH = BIT_WIDTH * SIMD * MMV_IN, + parameter BUF_OUT_WIDTH = BIT_WIDTH * 
SIMD * MMV_OUT, + + parameter integer C_s_axilite_DATA_WIDTH = 32, + parameter integer C_s_axilite_ADDR_WIDTH = 6 +) +( + (* X_INTERFACE_PARAMETER = "ASSOCIATED_BUSIF in0_V:out_V:s_axilite" *) + input ap_clk, + (* X_INTERFACE_PARAMETER = "ASSOCIATED_BUSIF in0_V:out_V:s_axilite" *) + input ap_rst_n, + input [BUF_IN_WIDTH-1:0] in0_V_TDATA, + input in0_V_TVALID, + output in0_V_TREADY, + output [BUF_OUT_WIDTH-1:0] out_V_TDATA, + output out_V_TVALID, + input out_V_TREADY, + + // Ports of Axi Slave Bus Interface s_axilite + input [C_s_axilite_ADDR_WIDTH-1 : 0] s_axilite_awaddr, + input [2 : 0] s_axilite_awprot, + input s_axilite_awvalid, + output s_axilite_awready, + input [C_s_axilite_DATA_WIDTH-1 : 0] s_axilite_wdata, + input [(C_s_axilite_DATA_WIDTH/8)-1 : 0] s_axilite_wstrb, + input s_axilite_wvalid, + output s_axilite_wready, + output [1 : 0] s_axilite_bresp, + output s_axilite_bvalid, + input s_axilite_bready, + input [C_s_axilite_ADDR_WIDTH-1 : 0] s_axilite_araddr, + input [2 : 0] s_axilite_arprot, + input s_axilite_arvalid, + output s_axilite_arready, + output [C_s_axilite_DATA_WIDTH-1 : 0] s_axilite_rdata, + output [1 : 0] s_axilite_rresp, + output s_axilite_rvalid, + input s_axilite_rready +); + +wire cfg_valid; +wire [CNTR_BITWIDTH-1:0] cfg_cntr_simd; +wire [CNTR_BITWIDTH-1:0] cfg_cntr_kw; +wire [CNTR_BITWIDTH-1:0] cfg_cntr_kh; +wire [CNTR_BITWIDTH-1:0] cfg_cntr_w; +wire [CNTR_BITWIDTH-1:0] cfg_cntr_h; +wire [INCR_BITWIDTH-1:0] cfg_incr_head_simd; +wire [INCR_BITWIDTH-1:0] cfg_incr_head_kw; +wire [INCR_BITWIDTH-1:0] cfg_incr_head_kh; +wire [INCR_BITWIDTH-1:0] cfg_incr_head_w; +wire [INCR_BITWIDTH-1:0] cfg_incr_head_h; +wire [INCR_BITWIDTH-1:0] cfg_incr_tail_w; +wire [INCR_BITWIDTH-1:0] cfg_incr_tail_h; +wire [INCR_BITWIDTH-1:0] cfg_incr_tail_last; +wire [31:0] cfg_last_read; +wire [31:0] cfg_last_write; + +// Instantiation of Axi Bus Interface s_axilite +$TOP_MODULE_NAME$_axilite # ( + .C_S_AXI_DATA_WIDTH(C_s_axilite_DATA_WIDTH), + 
.C_S_AXI_ADDR_WIDTH(C_s_axilite_ADDR_WIDTH) +) axilite_cfg_inst ( + .S_AXI_ACLK(ap_clk), + .S_AXI_ARESETN(ap_rst_n), + .S_AXI_AWADDR(s_axilite_awaddr), + .S_AXI_AWPROT(s_axilite_awprot), + .S_AXI_AWVALID(s_axilite_awvalid), + .S_AXI_AWREADY(s_axilite_awready), + .S_AXI_WDATA(s_axilite_wdata), + .S_AXI_WSTRB(s_axilite_wstrb), + .S_AXI_WVALID(s_axilite_wvalid), + .S_AXI_WREADY(s_axilite_wready), + .S_AXI_BRESP(s_axilite_bresp), + .S_AXI_BVALID(s_axilite_bvalid), + .S_AXI_BREADY(s_axilite_bready), + .S_AXI_ARADDR(s_axilite_araddr), + .S_AXI_ARPROT(s_axilite_arprot), + .S_AXI_ARVALID(s_axilite_arvalid), + .S_AXI_ARREADY(s_axilite_arready), + .S_AXI_RDATA(s_axilite_rdata), + .S_AXI_RRESP(s_axilite_rresp), + .S_AXI_RVALID(s_axilite_rvalid), + .S_AXI_RREADY(s_axilite_rready), + + .cfg_reg0(cfg_valid), + .cfg_reg1(cfg_cntr_simd), + .cfg_reg2(cfg_cntr_kw), + .cfg_reg3(cfg_cntr_kh), + .cfg_reg4(cfg_cntr_w), + .cfg_reg5(cfg_cntr_h), + .cfg_reg6(cfg_incr_head_simd), + .cfg_reg7(cfg_incr_head_kw), + .cfg_reg8(cfg_incr_head_kh), + .cfg_reg9(cfg_incr_head_w), + .cfg_reg10(cfg_incr_head_h), + .cfg_reg11(cfg_incr_tail_w), + .cfg_reg12(cfg_incr_tail_h), + .cfg_reg13(cfg_incr_tail_last), + .cfg_reg14(cfg_last_read), + .cfg_reg15(cfg_last_write) +); + +$TOP_MODULE_NAME$_impl +#( + .BIT_WIDTH(BIT_WIDTH), + .SIMD(SIMD), + .MMV_IN(MMV_IN), + .MMV_OUT(MMV_OUT), + .CNTR_BITWIDTH(CNTR_BITWIDTH), + .INCR_BITWIDTH(INCR_BITWIDTH) +) +impl +( + .ap_clk(ap_clk), + .ap_rst_n(ap_rst_n), + .in0_V_V_TDATA(in0_V_TDATA), + .in0_V_V_TVALID(in0_V_TVALID), + .in0_V_V_TREADY(in0_V_TREADY), + .out_V_V_TDATA(out_V_TDATA), + .out_V_V_TVALID(out_V_TVALID), + .out_V_V_TREADY(out_V_TREADY), + + .cfg_valid(cfg_valid), + .cfg_cntr_simd(cfg_cntr_simd), + .cfg_cntr_kw(cfg_cntr_kw), + .cfg_cntr_kh(cfg_cntr_kh), + .cfg_cntr_w(cfg_cntr_w), + .cfg_cntr_h(cfg_cntr_h), + .cfg_incr_head_simd(cfg_incr_head_simd), + .cfg_incr_head_kw(cfg_incr_head_kw), + .cfg_incr_head_kh(cfg_incr_head_kh), + 
.cfg_incr_head_w(cfg_incr_head_w), + .cfg_incr_head_h(cfg_incr_head_h), + .cfg_incr_tail_w(cfg_incr_tail_w), + .cfg_incr_tail_h(cfg_incr_tail_h), + .cfg_incr_tail_last(cfg_incr_tail_last), + .cfg_last_read(cfg_last_read), + .cfg_last_write(cfg_last_write) +); + +endmodule //TOP_MODULE_NAME diff --git a/src/finn/custom_op/fpgadataflow/convolutioninputgenerator_rtl.py b/src/finn/custom_op/fpgadataflow/convolutioninputgenerator_rtl.py index 5424050a8ed0a353894721d5bba28c1d45e62771..1afd23d3a1709a8929a03c21a6eba0a5a8cd6ba6 100755 --- a/src/finn/custom_op/fpgadataflow/convolutioninputgenerator_rtl.py +++ b/src/finn/custom_op/fpgadataflow/convolutioninputgenerator_rtl.py @@ -29,7 +29,6 @@ import math import numpy as np import os -from math import copysign from qonnx.core.datatype import DataType from qonnx.custom_op.general import im2col from qonnx.custom_op.general.im2col import compute_conv_output_dim @@ -81,6 +80,9 @@ class ConvolutionInputGenerator_rtl(HLSCustomOp): "inputDataType": ("s", True, ""), "outputDataType": ("s", True, ""), "depthwise": ("i", False, 0, {0, 1}), + # Enable reprogrammable implementation to change FM dimensions, + # stride, or dilation during runtime + "dynamic_mode": ("i", False, 0, {0, 1}), # FPGA resource type for ConvolutionInputGenerator input buffer # auto -- let Vivado decide # block -- use BRAM @@ -457,9 +459,11 @@ class ConvolutionInputGenerator_rtl(HLSCustomOp): def prepare_codegen_default(self): # Default implementation style for MMV_out = 1: addressable cyclic buffer # Computing incremental addressing scheme directly.. 
- template_path = ( - os.environ["FINN_ROOT"] + "/finn-rtllib/swg/swg_template_default.sv" - ) + if self.get_nodeattr("dynamic_mode"): + template_select = "/finn-rtllib/swg/swg_template_default_dynamic.sv" + else: + template_select = "/finn-rtllib/swg/swg_template_default.sv" + template_path = os.environ["FINN_ROOT"] + template_select code_gen_dict = {} ifm_ch = self.get_nodeattr("IFMChannels") @@ -569,10 +573,6 @@ class ConvolutionInputGenerator_rtl(HLSCustomOp): tail_incr_last_window = buffer_min_size - 1 code_gen_dict["$IS_DEPTHWISE$"] = ["0"] - code_gen_dict["$TAIL_INCR_W$"] = [str(tail_incr_w)] - code_gen_dict["$TAIL_INCR_H$"] = [str(tail_incr_h)] - code_gen_dict["$TAIL_INCR_LAST$"] = [str(tail_incr_last_window)] - # support SIMD = IFMChannels and k_w = 1 cases # for k = [k_h, k_w] = [1, k_w], no adjustment is needed # for k = [k_h, k_w] = [1, 1], do not use this impl. style (mmv_out=K=1) @@ -590,11 +590,23 @@ class ConvolutionInputGenerator_rtl(HLSCustomOp): code_gen_dict["$INNERMOST_STATE$"] = ["STATE_LOOP_SIMD"] loop_simd_iterations -= 1 # -1 because state is initial state - code_gen_dict["$LOOP_H_ITERATIONS$"] = [str(loop_h_iterations - 1)] - code_gen_dict["$LOOP_W_ITERATIONS$"] = [str(loop_w_iterations - 1)] - code_gen_dict["$LOOP_KH_ITERATIONS$"] = [str(loop_kh_iterations - 1)] - code_gen_dict["$LOOP_KW_ITERATIONS$"] = [str(loop_kw_iterations - 1)] - code_gen_dict["$LOOP_SIMD_ITERATIONS$"] = [str(loop_simd_iterations - 1)] + cntr_bitwidth = math.ceil( + math.log2( + max( + loop_h_iterations - 2 + 1, + loop_w_iterations - 2 + 1, + loop_kh_iterations - 2 + 1, + loop_kw_iterations - 2 + 1, + loop_simd_iterations - 2 + 1, + ) + ) + ) + code_gen_dict["$CNTR_BITWIDTH$"] = [str(cntr_bitwidth)] + code_gen_dict["$LOOP_H_ITERATIONS$"] = [str(loop_h_iterations - 2)] + code_gen_dict["$LOOP_W_ITERATIONS$"] = [str(loop_w_iterations - 2)] + code_gen_dict["$LOOP_KH_ITERATIONS$"] = [str(loop_kh_iterations - 2)] + code_gen_dict["$LOOP_KW_ITERATIONS$"] = 
[str(loop_kw_iterations - 2)] + code_gen_dict["$LOOP_SIMD_ITERATIONS$"] = [str(loop_simd_iterations - 2)] incr_bitwidth = 1 + math.ceil( math.log2( @@ -611,21 +623,14 @@ class ConvolutionInputGenerator_rtl(HLSCustomOp): ) ) code_gen_dict["$INCR_BITWIDTH$"] = [str(incr_bitwidth)] - code_gen_dict["$ADDR_INCREMENT_MAP$"] = [ - "'{{ {}'d0, {}'d{}, {}'d{}, {}'d{}, {}'d{}, {}'d{}}}".format( - incr_bitwidth, - int(copysign(incr_bitwidth, addr_incr_end_simd)), - abs(addr_incr_end_simd), - int(copysign(incr_bitwidth, addr_incr_end_window_elem)), - abs(addr_incr_end_window_elem), - int(copysign(incr_bitwidth, addr_incr_end_window_row)), - abs(addr_incr_end_window_row), - int(copysign(incr_bitwidth, addr_incr_end_window)), - abs(addr_incr_end_window), - int(copysign(incr_bitwidth, addr_incr_end_row)), - abs(addr_incr_end_row), - ) - ] + code_gen_dict["$HEAD_INCR_SIMD$"] = [str(addr_incr_end_simd)] + code_gen_dict["$HEAD_INCR_KW$"] = [str(addr_incr_end_window_elem)] + code_gen_dict["$HEAD_INCR_KH$"] = [str(addr_incr_end_window_row)] + code_gen_dict["$HEAD_INCR_W$"] = [str(addr_incr_end_window)] + code_gen_dict["$HEAD_INCR_H$"] = [str(addr_incr_end_row)] + code_gen_dict["$TAIL_INCR_W$"] = [str(tail_incr_w)] + code_gen_dict["$TAIL_INCR_H$"] = [str(tail_incr_h)] + code_gen_dict["$TAIL_INCR_LAST$"] = [str(tail_incr_last_window)] code_gen_dict["$ELEM_PER_WINDOW$"] = [str(elem_per_window)] code_gen_dict["$SIMD$"] = [str(simd)] @@ -710,15 +715,22 @@ class ConvolutionInputGenerator_rtl(HLSCustomOp): code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen") with open(template_path, "r") as f: template = f.read() + if self.get_nodeattr("dynamic_mode"): + template_select = "/finn-rtllib/swg/swg_template_wrapper_dynamic.v" + else: + template_select = "/finn-rtllib/swg/swg_template_wrapper.v" + with open(os.environ["FINN_ROOT"] + template_select, "r") as f: + template_wrapper = f.read() with open( - os.environ["FINN_ROOT"] + "/finn-rtllib/swg/swg_template_wrapper.v", "r" + 
os.environ["FINN_ROOT"] + "/finn-rtllib/swg/swg_template_axilite.v", "r" ) as f: - template_wrapper = f.read() + template_axilite = f.read() for key in code_gen_dict: # transform list into long string separated by '\n' code_gen_line = "\n".join(code_gen_dict[key]) template = template.replace(key, code_gen_line) template_wrapper = template_wrapper.replace(key, code_gen_line) + template_axilite = template_axilite.replace(key, code_gen_line) with open( os.path.join( code_gen_dir, self.get_nodeattr("gen_top_module") + "_impl.sv" @@ -734,6 +746,16 @@ class ConvolutionInputGenerator_rtl(HLSCustomOp): ) as f: f.write(template_wrapper) + # AXI-Lite reg. file component is only needed for dynamic mode + if self.get_nodeattr("dynamic_mode"): + with open( + os.path.join( + code_gen_dir, self.get_nodeattr("gen_top_module") + "_axilite.v" + ), + "w", + ) as f: + f.write(template_axilite) + # set ipgen_path and ip_path so that HLS-Synth transformation # and stich_ip transformation do not complain self.set_nodeattr("ipgen_path", code_gen_dir) @@ -754,6 +776,8 @@ class ConvolutionInputGenerator_rtl(HLSCustomOp): self.get_nodeattr("gen_top_module") + "_wrapper.v", self.get_nodeattr("gen_top_module") + "_impl.sv", ] + if self.get_nodeattr("dynamic_mode"): + verilog_files.append(self.get_nodeattr("gen_top_module") + "_axilite.v") # build the Verilator emu library sim = PyVerilator.build( @@ -771,25 +795,97 @@ class ConvolutionInputGenerator_rtl(HLSCustomOp): """Constructs and returns the TCL for node instantiation in Vivado IPI.""" code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen") - cmd = [ - "add_files -norecurse %s" - % ( - os.path.join( - code_gen_dir, self.get_nodeattr("gen_top_module") + "_wrapper.v" - ) - ), - "add_files -norecurse %s" - % ( - os.path.join( - code_gen_dir, self.get_nodeattr("gen_top_module") + "_impl.sv" - ) - ), - "create_bd_cell -type module -reference %s %s" - % (self.get_nodeattr("gen_top_module"), self.onnx_node.name), + sourcefiles = [ + 
self.get_nodeattr("gen_top_module") + "_wrapper.v", + self.get_nodeattr("gen_top_module") + "_impl.sv", ] + if self.get_nodeattr("dynamic_mode"): + sourcefiles += [self.get_nodeattr("gen_top_module") + "_axilite.v"] + + sourcefiles = [os.path.join(code_gen_dir, f) for f in sourcefiles] + + cmd = [] + for f in sourcefiles: + cmd += ["add_files -norecurse %s" % (f)] + cmd += [ + "create_bd_cell -type module -reference %s %s" + % (self.get_nodeattr("gen_top_module"), self.onnx_node.name) + ] return cmd + def get_verilog_top_module_intf_names(self): + # Overload default HLSCustomOp implementation to add axilite control IF + """Return a dict of names of input and output interfaces. + The keys reflect the protocols each interface implements: + 'clk', 'rst', 'm_axis', 's_axis', 'aximm', 'axilite'. + Values are lists of tuples (axis, aximm) or names (axilite): + 'axis' tuples correspond to the list of node inputs in order, + each tuple is (interface_name, interface_width_bits). + axilite always assumed to be 32 bits and is not tuple (name only). + Each block must have at most one aximm and one axilite.""" + intf_names = super().get_verilog_top_module_intf_names() + if self.get_nodeattr("dynamic_mode"): + intf_names["axilite"] = ["s_axilite"] + return intf_names + + def get_dynamic_config(self, ifm_dim=None, stride=None, dilation=None): + """Returns a configuration dict to re-configure FM dimension during + runtime. Stride and dilation can also be changed. Certain restrictions + apply (e.g. 
component must be synthesized for largest buffer size).""" + # NOTE: For better driver integration, this functionality could be packaged + # as a standalone function in the future + + if ifm_dim is None: + ifm_dim = self.get_nodeattr("IFMDim") + k = self.get_nodeattr("ConvKernelDim") + if stride is None: + stride = self.get_nodeattr("Stride") + if dilation is None: + dilation = self.get_nodeattr("Dilation") + + k_h, k_w = k + stride_h, stride_w = stride + dilation_h, dilation_w = dilation + ifm_dim_h, ifm_dim_w = ifm_dim + ofm_dim_h = compute_conv_output_dim(ifm_dim_h, k_h, stride_h, 0, dilation_h) + ofm_dim_w = compute_conv_output_dim(ifm_dim_w, k_w, stride_w, 0, dilation_w) + ofm_dim = [ofm_dim_h, ofm_dim_w] + + # update attributes and perform sanity check (side effect: IFMDim, OFMDim, Stride and Dilation of this node are overwritten before the check, so they stay modified even if the assertion below fails) + original_buffer_depth = self.get_buffer_depth() + self.set_nodeattr("IFMDim", ifm_dim) + self.set_nodeattr("OFMDim", ofm_dim) + self.set_nodeattr("Stride", stride) + self.set_nodeattr("Dilation", dilation) + assert ( + self.get_buffer_depth() <= original_buffer_depth + ), """Error: requested + dynamic configuration does not fit in generated buffer implementation.""" + + # (re-)call codegen and extract new values + # each setting is mapped to an axi-lite register address: entries are (byte address, value) tuples with byte address = register index * 4 + template_path, code_gen_dict = self.prepare_codegen_default() + config = { + "cfg_wren": (0 * 4, 1), + "cfg_cntr_simd": (1 * 4, int(code_gen_dict["$LOOP_SIMD_ITERATIONS$"][0])), + "cfg_cntr_kw": (2 * 4, int(code_gen_dict["$LOOP_KW_ITERATIONS$"][0])), + "cfg_cntr_kh": (3 * 4, int(code_gen_dict["$LOOP_KH_ITERATIONS$"][0])), + "cfg_cntr_w": (4 * 4, int(code_gen_dict["$LOOP_W_ITERATIONS$"][0])), + "cfg_cntr_h": (5 * 4, int(code_gen_dict["$LOOP_H_ITERATIONS$"][0])), + "cfg_incr_head_simd": (6 * 4, int(code_gen_dict["$HEAD_INCR_SIMD$"][0])), + "cfg_incr_head_kw": (7 * 4, int(code_gen_dict["$HEAD_INCR_KW$"][0])), + "cfg_incr_head_kh": (8 * 4, int(code_gen_dict["$HEAD_INCR_KH$"][0])), + "cfg_incr_head_w": (9 * 4, 
int(code_gen_dict["$HEAD_INCR_W$"][0])), + "cfg_incr_head_h": (10 * 4, int(code_gen_dict["$HEAD_INCR_H$"][0])), + "cfg_incr_tail_w": (11 * 4, int(code_gen_dict["$TAIL_INCR_W$"][0])), + "cfg_incr_tail_h": (12 * 4, int(code_gen_dict["$TAIL_INCR_H$"][0])), + "cfg_incr_tail_last": (13 * 4, int(code_gen_dict["$TAIL_INCR_LAST$"][0])), + "cfg_last_read": (14 * 4, int(code_gen_dict["$LAST_READ_ELEM$"][0])), + "cfg_last_write": (15 * 4, int(code_gen_dict["$LAST_WRITE_ELEM$"][0])), + } + return config + def code_generation_ipgen(self, model, fpgapart, clk): """Normally: Generates C++ code and tcl script for IP generation. Here: Generates (System-)Verilog code for IP generation.""" diff --git a/src/finn/transformation/streamline/absorb.py b/src/finn/transformation/streamline/absorb.py index 3af34eba8eb709099474426b665f295f21e0ce40..73df52f890d227137ea076804d161206e66653dc 100644 --- a/src/finn/transformation/streamline/absorb.py +++ b/src/finn/transformation/streamline/absorb.py @@ -492,6 +492,8 @@ class AbsorbConsecutiveTransposes(Transformation): if node.op_type == "Transpose": next_nodes = model.find_consumers(node.output[0]) perms1 = list(get_by_name(node.attribute, "perm").ints) + if len(next_nodes) == 0: + continue # check if all nodes after fork are opposite transposes all_opposite_transposes = True for next_node in next_nodes: diff --git a/tests/fpgadataflow/test_fpgadataflow_convinputgenerator_rtl_dynamic.py b/tests/fpgadataflow/test_fpgadataflow_convinputgenerator_rtl_dynamic.py new file mode 100644 index 0000000000000000000000000000000000000000..cd20b305a15073fb6499727e09a4acac94ad5c89 --- /dev/null +++ b/tests/fpgadataflow/test_fpgadataflow_convinputgenerator_rtl_dynamic.py @@ -0,0 +1,495 @@ +# Copyright (c) 2022, Advanced Micro Devices, Inc. +# All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of FINN nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import pytest + +import copy +import numpy as np +import onnx.parser as oprs +import os +from onnx import TensorProto, helper +from pyverilator.util.axi_utils import axilite_write, reset_rtlsim +from qonnx.core.datatype import DataType +from qonnx.core.modelwrapper import ModelWrapper +from qonnx.custom_op.general.im2col import compute_conv_output_dim +from qonnx.custom_op.registry import getCustomOp +from qonnx.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames +from qonnx.transformation.infer_datatypes import InferDataTypes +from qonnx.transformation.infer_shapes import InferShapes +from qonnx.transformation.lower_convs_to_matmul import LowerConvsToMatMul +from qonnx.util.basic import gen_finn_dt_tensor, get_by_name + +import finn.core.onnx_exec as oxe +import finn.transformation.fpgadataflow.convert_to_hls_layers as to_hls +import finn.transformation.streamline.absorb as absorb +from finn.core.onnx_exec import execute_onnx +from finn.core.rtlsim_exec import rtlsim_exec +from finn.transformation.fpgadataflow.create_dataflow_partition import ( + CreateDataflowPartition, +) +from finn.transformation.fpgadataflow.create_stitched_ip import CreateStitchedIP +from finn.transformation.fpgadataflow.hlssynth_ip import HLSSynthIP +from finn.transformation.fpgadataflow.insert_fifo import InsertFIFO +from finn.transformation.fpgadataflow.prepare_ip import PrepareIP +from finn.util.basic import pyverilate_get_liveness_threshold_cycles + + +def create_conv_model(idim, ifm, k, stride, ofm, idt, wdt): + np.random.seed(0) + ishp = (1, ifm, idim, idim) + int_dim = compute_conv_output_dim(idim, k, stride) + odim = compute_conv_output_dim(int_dim, k, stride) + oshp = (1, ofm, odim, odim) + wshp = (ofm, ifm, k, k) + wshp_1 = (ofm, ofm, k, k) + ishp_str = str(list(ishp)) + oshp_str = str(list(oshp)) + wshp_str = str(list(wshp)) + wshp_1_str = str(list(wshp_1)) + kshp_str = str([k, k]) + pad_str = str([0, 0, 0, 0]) + stride_str = str([stride, stride]) + 
dil_str = str([1, 1]) + + input = f""" + < + ir_version: 7, + opset_import: ["" : 9] + > + agraph (float{ishp_str} in0) => (float{oshp_str} out0) + < + float{wshp_str} param_c0_weight, + float{wshp_1_str} param_c1_weight + > + {{ + conv0 = Conv< + dilations={dil_str},group=1,kernel_shape={kshp_str},pads={pad_str}, + strides={stride_str} + >(in0, param_c0_weight) + out0 = Conv< + dilations={dil_str},group=1,kernel_shape={kshp_str},pads={pad_str}, + strides={stride_str} + >(conv0, param_c1_weight) + }} + """ + model = oprs.parse_model(input) + model = ModelWrapper(model) + model = model.transform(InferShapes()) + model = model.transform(InferDataTypes()) + model.set_tensor_datatype("in0", idt) + model.set_tensor_datatype("param_c0_weight", wdt) + model.set_tensor_datatype("param_c1_weight", wdt) + model.set_initializer("param_c0_weight", gen_finn_dt_tensor(wdt, wshp)) + model.set_initializer("param_c1_weight", gen_finn_dt_tensor(wdt, wshp_1)) + return model + + +def update_conv_model_dims(model, idim_new): + cnode = model.get_nodes_by_op_type("Conv")[0] + k, _ = get_by_name(cnode.attribute, "kernel_shape").ints + stride, _ = get_by_name(cnode.attribute, "strides").ints + ishp = model.get_tensor_shape("in0") + n, ci, _, _ = ishp + n, co, _, _ = model.get_tensor_shape("out0") + int_dim = compute_conv_output_dim(idim_new, k, stride) + odim = compute_conv_output_dim(int_dim, k, stride) + model.set_tensor_shape("in0", (n, ci, idim_new, idim_new)) + model.set_tensor_shape("out0", (n, co, odim, odim)) + # remove all existing shapes + del model.graph.value_info[:] + model = model.transform(InferShapes()) + model = model.transform(InferDataTypes()) + return model + + +# Helper function to update tensor dimensions manually because shape inference +# does not work on FINN nodes (they assume well-defined tensor shapes). 
+def update_tensor_dim(model, tensor_name, new_hw): + shape = model.get_tensor_shape(tensor_name) + shape[1] = new_hw[0] + shape[2] = new_hw[1] + model.set_tensor_shape(tensor_name, shape) + + +# Helper function that delivers the hook to program the SWG via AXI-Lite. Returns None when configs is None; otherwise returns a function taking the simulation object. configs is iterated as (axi_name, config) pairs, where config is a dict whose values are indexed as [0] = address and [1] = value. +def config_hook(configs): + if configs is None: + return None + + def write_swg_config(sim): + for axi_name, config in configs: + # 1. Write config registers to the SWG, dict defines (addr, value) tuples; every entry of the dict is written, in dict iteration order + for config_entry in config.values(): + axilite_write(sim, config_entry[0], config_entry[1], basename=axi_name) + # 2. Set cfg_valid flag (>= 1 cycle) by writing 1 to address 0 -- NOTE(review): a dict entry that itself targets address 0 is already written in step 1, before the remaining registers; confirm this ordering is intended + axilite_write(sim, 0, 1, basename=axi_name) + # 3. Reset component (>= 1 cycle) + reset_rtlsim(sim) + + return write_swg_config + + +@pytest.mark.slow +@pytest.mark.vivado +@pytest.mark.fpgadataflow +def test_fpgadataflow_conv_dynamic(): + idims = [32, 16] + ifm = 4 + k = 4 + stride = 1 + ofm = 8 + idt = DataType["UINT8"] + wdt = DataType["INT2"] + exp_cfgs = [] + largest_model = None + for idim in idims: + ishp = (1, ifm, idim, idim) + np.random.seed(0) + inp = gen_finn_dt_tensor(idt, ishp) + model = create_conv_model(idim, ifm, k, stride, ofm, idt, wdt) + _, _, int_dim, _ = model.get_tensor_shape("conv0") + _, _, odim, _ = model.get_tensor_shape("out0") + if idim == max(idims): + # use largest model for hardware conversion + largest_model = copy.deepcopy(model) + golden = execute_onnx(model, {"in0": inp})["out0"] + exp_cfg = (idim, int_dim, odim, inp, golden) + exp_cfgs.append(exp_cfg) + + # convert to hardware and prepare simulation + model = largest_model.transform(LowerConvsToMatMul()) + model = model.transform(to_hls.InferConvInpGen(use_rtl_variant=True)) + model = model.transform( + to_hls.InferQuantizedMatrixVectorActivation(mem_mode="decoupled") + ) + model = model.transform(absorb.AbsorbConsecutiveTransposes()) + parent_model = model.transform(CreateDataflowPartition()) + sdp_inst = getCustomOp( + 
parent_model.get_nodes_by_op_type("StreamingDataflowPartition")[0] + ) + model = ModelWrapper(sdp_inst.get_nodeattr("model")) + for swg_node in model.get_nodes_by_op_type("ConvolutionInputGenerator_rtl"): + getCustomOp(swg_node).set_nodeattr("SIMD", 1) + getCustomOp(swg_node).set_nodeattr("dynamic_mode", 1) + getCustomOp(swg_node).set_nodeattr("inFIFODepths", [16]) + getCustomOp(swg_node).set_nodeattr("outFIFODepths", [16]) + model = model.transform(InsertFIFO()) + model = model.transform(GiveUniqueNodeNames()) + model = model.transform(GiveReadableTensorNames()) + model = model.transform(PrepareIP("xc7z020clg400-1", 5)) + model = model.transform(HLSSynthIP()) + model = model.transform(CreateStitchedIP("xc7z020clg400-1", 5)) + model.set_metadata_prop("exec_mode", "rtlsim") + + # loop through experiment configurations + for exp_cfg in exp_cfgs: + idim, int_dim, odim, inp, golden = exp_cfg + # get config for the new dimensions + swg_nodes = model.get_nodes_by_op_type("ConvolutionInputGenerator_rtl") + swg0 = getCustomOp(swg_nodes[0]) + update_tensor_dim(model, swg0.onnx_node.input[0], (idim, idim)) + update_tensor_dim(model, swg0.onnx_node.output[0], (int_dim, int_dim)) + config0 = swg0.get_dynamic_config((idim, idim)) + swg1 = getCustomOp(swg_nodes[1]) + update_tensor_dim(model, swg1.onnx_node.input[0], (int_dim, int_dim)) + update_tensor_dim(model, swg1.onnx_node.output[0], (odim, odim)) + config1 = swg1.get_dynamic_config((int_dim, int_dim)) + configs = [("s_axilite_0_", config0), ("s_axilite_1_", config1)] + # adjust folded shapes for I/O FIFOs + # (since rtlsim_exec uses folded shape info to fold global i/o tensors) + first_node = getCustomOp(model.graph.node[0]) + first_node_shp = list(first_node.get_folded_input_shape()) + first_node_shp[1] = idim + first_node_shp[2] = idim + first_node.set_nodeattr("folded_shape", first_node_shp) + update_tensor_dim(model, first_node.onnx_node.input[0], (idim, idim)) + last_node = getCustomOp(model.graph.node[-1]) + 
last_node_shp = list(last_node.get_folded_output_shape()) + last_node_shp[1] = odim + last_node_shp[2] = odim + update_tensor_dim(model, last_node.onnx_node.output[0], (odim, odim)) + last_node.set_nodeattr("folded_shape", last_node_shp) + ctx = {"global_in": inp.transpose(0, 2, 3, 1)} + liveness_prev = pyverilate_get_liveness_threshold_cycles() + os.environ["LIVENESS_THRESHOLD"] = "100000" + rtlsim_exec(model, ctx, pre_hook=config_hook(configs)) + os.environ["LIVENESS_THRESHOLD"] = str(liveness_prev) + ret = ctx["global_out"].transpose(0, 3, 1, 2) + assert np.isclose(golden, ret).all() + + +def make_single_im2col_modelwrapper(k, ifm_ch, ifm_dim, ofm_dim, stride, dilation, idt): + k_h, k_w = k + ifm_dim_h, ifm_dim_w = ifm_dim + stride_h, stride_w = stride + dilation_h, dilation_w = dilation + ofm_dim_h, ofm_dim_w = ofm_dim + + odt = idt + inp = helper.make_tensor_value_info( + "inp", TensorProto.FLOAT, [1, ifm_dim_h, ifm_dim_w, ifm_ch] + ) + outp = helper.make_tensor_value_info( + "outp", TensorProto.FLOAT, [1, ofm_dim_h, ofm_dim_w, k_h * k_w * ifm_ch] + ) + + im2col_node = helper.make_node( + "Im2Col", + ["inp"], + ["outp"], + domain="finn.custom_op.general", + stride=[stride_h, stride_w], + kernel_size=[k_h, k_w], + input_shape=str((1, ifm_dim_h, ifm_dim_w, ifm_ch)), + dilations=[dilation_h, dilation_w], + pad_amount=[0, 0, 0, 0], + pad_value=0, + ) + graph = helper.make_graph( + nodes=[im2col_node], name="im2col_graph", inputs=[inp], outputs=[outp] + ) + + model = helper.make_model(graph, producer_name="im2col-model") + model = ModelWrapper(model) + + model.set_tensor_datatype("inp", idt) + model.set_tensor_datatype("outp", odt) + + return model + + +def make_single_slidingwindow_modelwrapper( + k, ifm_ch, ifm_dim, ofm_dim, simd, m, parallel_window, stride, dilation, idt, dw=0 +): + k_h, k_w = k + ifm_dim_h, ifm_dim_w = ifm_dim + stride_h, stride_w = stride + dilation_h, dilation_w = dilation + ofm_dim_h, ofm_dim_w = ofm_dim + + odt = idt + inp = 
helper.make_tensor_value_info( + "inp", TensorProto.FLOAT, [1, ifm_dim_h, ifm_dim_w, ifm_ch] + ) + outp = helper.make_tensor_value_info( + "outp", TensorProto.FLOAT, [1, ofm_dim_h, ofm_dim_w, k_h * k_w * ifm_ch] + ) + + SlidingWindow_node = helper.make_node( + "ConvolutionInputGenerator_rtl", + ["inp"], + ["outp"], + domain="finn.custom_op.fpgadataflow", + backend="fpgadataflow", + ConvKernelDim=[k_h, k_w], + IFMChannels=ifm_ch, + IFMDim=[ifm_dim_h, ifm_dim_w], + OFMDim=[ofm_dim_h, ofm_dim_w], + SIMD=simd, + M=m, + parallel_window=parallel_window, + Stride=[stride_h, stride_w], + Dilation=[dilation_h, dilation_w], + inputDataType=idt.name, + outputDataType=odt.name, + depthwise=dw, + dynamic_mode=1, + ) + graph = helper.make_graph( + nodes=[SlidingWindow_node], + name="slidingwindow_graph", + inputs=[inp], + outputs=[outp], + ) + + model = helper.make_model(graph, producer_name="slidingwindow-model") + model = ModelWrapper(model) + + model.set_tensor_datatype("inp", idt) + model.set_tensor_datatype("outp", odt) + + return model + + +def prepare_inputs(input_tensor): + return {"inp": input_tensor} + + +# input datatype +@pytest.mark.parametrize("idt", [DataType["UINT4"]]) +# kernel size +@pytest.mark.parametrize("k", [[3, 3]]) +# input dimension +@pytest.mark.parametrize("ifm_dim_series", [[[32, 32], [16, 16], [8, 8]]]) +# input channels +@pytest.mark.parametrize("ifm_ch", [6]) +# Stride +@pytest.mark.parametrize("stride", [[1, 1]]) +# Dilation +@pytest.mark.parametrize("dilation", [[1, 1]]) +# depthwise +@pytest.mark.parametrize("dw", [0, 1]) +# input channel parallelism ("SIMD") +@pytest.mark.parametrize("simd", [2, 6]) +# parallel_window enable (MMV_out = M*K) +@pytest.mark.parametrize("parallel_window", [0]) +# in/out MMV ("M") +@pytest.mark.parametrize("m", [1]) +@pytest.mark.slow +@pytest.mark.vivado +@pytest.mark.fpgadataflow +def test_fpgadataflow_slidingwindow_rtl_dynamic( + idt, k, ifm_dim_series, ifm_ch, stride, dilation, dw, simd, m, parallel_window +): 
+ # Begin test by generating RTL SWG normally for the first FM of the series. + # The following FM dimensions must be equal or smaller than the initial + # dimensions (in terms of required buffer depth). + ifm_dim = ifm_dim_series[0] + + k_h, k_w = k + ifm_dim_h, ifm_dim_w = ifm_dim + stride_h, stride_w = stride + dilation_h, dilation_w = dilation + ofm_dim_h = compute_conv_output_dim(ifm_dim_h, k_h, stride_h, 0, dilation_h) + ofm_dim_w = compute_conv_output_dim(ifm_dim_w, k_w, stride_w, 0, dilation_w) + ofm_dim = [ofm_dim_h, ofm_dim_w] + kernel_width = (k_w - 1) * dilation_w + 1 # incl. dilation + kernel_height = (k_h - 1) * dilation_h + 1 # incl. dilation + + if simd > ifm_ch: + pytest.skip("SIMD cannot be larger than number of input channels") + if ifm_ch % simd != 0: + pytest.skip("SIMD must divide number of input channels") + if kernel_height > ifm_dim_h or stride_h > ifm_dim_h: + pytest.skip( + "Illegal convolution configuration: kernel or stride > FM dimension" + ) + if kernel_width > ifm_dim_w or stride_w > ifm_dim_w: + pytest.skip( + "Illegal convolution configuration: kernel or stride > FM dimension" + ) + if (k_h == 1 and (stride_h != 1 or dilation_h != 1)) or ( + k_w == 1 and (stride_w != 1 or dilation_w != 1) + ): + pytest.skip( + """Illegal convolution configuration: + stride or dilation defined for unitary kernel dim""" + ) + if k_h == 1 and k_w == 1 and simd != ifm_ch: + pytest.skip("1x1 Kernel only supported in parallel mode (SIMD=C)") + if parallel_window and simd != ifm_ch: + pytest.skip("Parallel window requires SIMD=C") + + model = make_single_slidingwindow_modelwrapper( + k=k, + ifm_ch=ifm_ch, + ifm_dim=ifm_dim, + ofm_dim=ofm_dim, + simd=simd, + m=m, + parallel_window=parallel_window, + stride=stride, + dilation=dilation, + idt=idt, + dw=dw, + ) + + # Simulate using stitched-ip-rtlsim so we can use existing infrastructure + # that supports hook functions to re-program configuration before rtlsim + model = model.transform(InsertFIFO(True)) # 
required for proper simulation + model = model.transform(GiveUniqueNodeNames()) + model = model.transform(PrepareIP("xc7z020clg400-1", 5)) + model = model.transform(HLSSynthIP()) + model = model.transform(CreateStitchedIP("xc7z020clg400-1", 5)) + model.set_metadata_prop("exec_mode", "rtlsim") + + # Simulate 1 FM for each dimension in the series + for i, ifm_dim in enumerate(ifm_dim_series): + ifm_dim_h, ifm_dim_w = ifm_dim + ofm_dim_h = compute_conv_output_dim(ifm_dim_h, k_h, stride_h, 0, dilation_h) + ofm_dim_w = compute_conv_output_dim(ifm_dim_w, k_w, stride_w, 0, dilation_w) + ofm_dim = [ofm_dim_h, ofm_dim_w] + + configs = None + if i > 0: # skip re-programming for initial FM dimension + # Necessary update of node and tensor attributes to make rtlsim work: + swg_node = model.get_nodes_by_op_type("ConvolutionInputGenerator_rtl")[0] + swg_inst = getCustomOp(swg_node) + update_tensor_dim(model, swg_node.input[0], ifm_dim) + update_tensor_dim(model, swg_node.output[0], ofm_dim) + + # Generate config, also overwrites IFMDim/OFMDim attributes: + config = swg_inst.get_dynamic_config(ifm_dim) + configs = [("s_axilite_0_", config)] + + # Also update FIFO nodes and corresponding tensors + fifo_node = model.get_nodes_by_op_type("StreamingFIFO")[0] + fifo_inst = getCustomOp(fifo_node) + shape = fifo_inst.get_nodeattr("folded_shape") + shape[1] = ifm_dim_h + shape[2] = ifm_dim_w + fifo_inst.set_nodeattr("folded_shape", shape) + update_tensor_dim(model, fifo_node.input[0], ifm_dim) + + fifo_node = model.get_nodes_by_op_type("StreamingFIFO")[1] + fifo_inst = getCustomOp(fifo_node) + shape = fifo_inst.get_nodeattr("folded_shape") + shape[1] = ofm_dim_h + shape[2] = ofm_dim_w + fifo_inst.set_nodeattr("folded_shape", shape) + update_tensor_dim(model, fifo_node.output[0], ofm_dim) + + # Run rtlsim on stitched-ip + x = gen_finn_dt_tensor(idt, (1, ifm_dim_h, ifm_dim_w, ifm_ch)) + context = prepare_inputs(x) + rtlsim_exec(model, context, pre_hook=config_hook(configs)) + y_produced = 
context["outp"] + + # Generate golden result + golden = make_single_im2col_modelwrapper( + k=k, + ifm_ch=ifm_ch, + ifm_dim=ifm_dim, + ofm_dim=ofm_dim, + stride=stride, + dilation=dilation, + idt=idt, + ) + input_dict = prepare_inputs(x) + y_expected = oxe.execute_onnx(golden, input_dict)["outp"] + + # Check result + if dw == 0: + assert (y_produced == y_expected).all() + else: + y_expected = y_expected.reshape( + 1, ofm_dim_h, ofm_dim_w, k_h * k_w, ifm_ch // simd, simd + ) + y_expected = y_expected.transpose(0, 1, 2, 4, 3, 5) + y_expected = y_expected.reshape(1, ofm_dim_h, ofm_dim_w, ifm_ch * k_h * k_w) + assert (y_produced == y_expected).all()