Skip to content
Snippets Groups Projects
Commit f46e2d0a authored by Felix Jentzsch's avatar Felix Jentzsch
Browse files

Restructure, basic resource estimation

parent b5dccf0f
No related branches found
No related tags found
No related merge requests found
...@@ -3,13 +3,15 @@ ...@@ -3,13 +3,15 @@
module $TOP_MODULE_NAME$_controller module $TOP_MODULE_NAME$_controller
( (
CLK, CLK,
cycle, RST,
advance,
cmd_read, cmd_read,
cmd_write cmd_write
); );
input CLK; input CLK;
input [31:0] cycle; //todo: minimize width or switch to single bit flag input RST;
input advance;
output cmd_read; output cmd_read;
output cmd_write; output cmd_write;
...@@ -39,10 +41,6 @@ integer counter_loop_inter; ...@@ -39,10 +41,6 @@ integer counter_loop_inter;
assign cmd_read = READ_CMD_MAP[state_next]; //read command indicates read in *upcoming* cycle, due to how schedule is constructed assign cmd_read = READ_CMD_MAP[state_next]; //read command indicates read in *upcoming* cycle, due to how schedule is constructed
assign cmd_write = WRITE_CMD_MAP[state]; assign cmd_write = WRITE_CMD_MAP[state];
reg cycle_last;
wire cycle_advance;
assign cycle_advance = !(cycle == cycle_last);
//combinational next state logic //combinational next state logic
always @ (state, counter_current, counter_loop_main, counter_loop_inter) begin always @ (state, counter_current, counter_loop_main, counter_loop_inter) begin
state_next = state; //default state_next = state; //default
...@@ -67,7 +65,7 @@ always @ (state, counter_current, counter_loop_main, counter_loop_inter) begin ...@@ -67,7 +65,7 @@ always @ (state, counter_current, counter_loop_main, counter_loop_inter) begin
if (LOOP_END_1_COUNTER != 0) if (LOOP_END_1_COUNTER != 0)
state_next = STATE_END_1; state_next = STATE_END_1;
else else
state_next = STATE_START; state_next = STATE_LOOP_MAIN_2; //wait in current state until reset
end end
end end
end end
...@@ -91,49 +89,46 @@ always @ (state, counter_current, counter_loop_main, counter_loop_inter) begin ...@@ -91,49 +89,46 @@ always @ (state, counter_current, counter_loop_main, counter_loop_inter) begin
if (LOOP_END_2_COUNTER != 0) if (LOOP_END_2_COUNTER != 0)
state_next = STATE_END_2; state_next = STATE_END_2;
else else
state_next = STATE_START; state_next = STATE_END_1; //wait in current state until reset
end end
end end
STATE_END_2: STATE_END_2:
if (counter_current == LOOP_END_2_COUNTER-1) if (counter_current == LOOP_END_2_COUNTER-1)
state_next = STATE_START; state_next = STATE_END_2; //wait in current state until reset
endcase endcase
end end
//sequential logic //sequential logic
always @ (posedge CLK) begin always @ (posedge CLK) begin
if (cycle == 0) begin if (RST) begin
counter_current <= 0; counter_current <= -1;
counter_loop_main <= 0; counter_loop_main <= 0;
counter_loop_inter <= 0; counter_loop_inter <= 0;
cycle_last <= 0;
state <= STATE_START; state <= STATE_START;
end else begin end else begin
cycle_last <= cycle; if (advance) begin
state <= state_next;
if (cycle_advance) begin
counter_current <= counter_current+1; counter_current <= counter_current+1;
end state <= state_next;
if (state != state_next) begin if (state != state_next) begin
counter_current <= 0; counter_current <= 0;
//count up main loop upon re-entering this loop (not on first enter from start) //count up main loop upon re-entering this loop (not on first enter from start)
if ((state_next == STATE_LOOP_MAIN_1) && (state != STATE_START)) begin if ((state_next == STATE_LOOP_MAIN_1) && (state != STATE_START)) begin
if (counter_loop_main == LOOP_MAIN_COUNTER-1) begin if (counter_loop_main == LOOP_MAIN_COUNTER-1) begin
counter_loop_main <= 0; counter_loop_main <= 0;
end else begin end else begin
counter_loop_main <= counter_loop_main+1; counter_loop_main <= counter_loop_main+1;
end
end end
end
if (state_next == STATE_LOOP_INTER_1) begin if (state_next == STATE_LOOP_INTER_1) begin
if (counter_loop_inter == LOOP_INTER_COUNTER) begin //no -1 because this counter marks the currently active iteration, not finished iterations if (counter_loop_inter == LOOP_INTER_COUNTER) begin //no -1 because this counter marks the currently active iteration, not finished iterations
counter_loop_inter <= 0; counter_loop_inter <= 0;
end else begin end else begin
counter_loop_inter <= counter_loop_inter+1; counter_loop_inter <= counter_loop_inter+1;
end
end end
end end
end end
...@@ -169,8 +164,8 @@ output [WIDTH*DEPTH-1:0] data_out; ...@@ -169,8 +164,8 @@ output [WIDTH*DEPTH-1:0] data_out;
// File: shift_registers_1.v // File: shift_registers_1.v
// //
//module shift_registers_1 (clk, clken, SI, SO); //module shift_registers_1 (clk, clken, SI, SO);
//parameter WIDTH = 32; //parameter WIDTH = 32;
//input clk, clken, SI; //input clk, clken, SI;
//output SO; //output SO;
//reg [WIDTH-1:0] shreg; //reg [WIDTH-1:0] shreg;
// //
...@@ -181,7 +176,7 @@ output [WIDTH*DEPTH-1:0] data_out; ...@@ -181,7 +176,7 @@ output [WIDTH*DEPTH-1:0] data_out;
// begin // begin
// for (i = 0; i < WIDTH-1; i = i+1) // for (i = 0; i < WIDTH-1; i = i+1)
// shreg[i+1] <= shreg[i]; // shreg[i+1] <= shreg[i];
// shreg[0] <= SI; // shreg[0] <= SI;
// end // end
//end //end
//assign SO = shreg[WIDTH-1]; //assign SO = shreg[WIDTH-1];
...@@ -227,7 +222,7 @@ integer addr_w, addr_r; //todo: minimize width (as reg), make r addr depend on w ...@@ -227,7 +222,7 @@ integer addr_w, addr_r; //todo: minimize width (as reg), make r addr depend on w
$RAM_STYLE$ reg [WIDTH-1:0] ram [DEPTH-1:0]; $RAM_STYLE$ reg [WIDTH-1:0] ram [DEPTH-1:0];
always @(posedge CLK) begin always @(posedge CLK) begin
if (RST == 1'b0) begin if (RST == 1'b0) begin
addr_w <= 0; addr_w <= 0;
addr_r <= 1; addr_r <= 1;
...@@ -349,11 +344,15 @@ wire read_cmd; ...@@ -349,11 +344,15 @@ wire read_cmd;
wire write_cmd; wire write_cmd;
reg write_done; //keep track if W of current cycle was already completed, but we still wait on a R in the same cycle reg write_done; //keep track if W of current cycle was already completed, but we still wait on a R in the same cycle
wire controller_reset;
wire controller_advance;
$TOP_MODULE_NAME$_controller $TOP_MODULE_NAME$_controller
controller_inst controller_inst
( (
.CLK(ap_clk), .CLK(ap_clk),
.cycle(cycle), .RST(controller_reset),
.advance(controller_advance),
.cmd_read(read_cmd), .cmd_read(read_cmd),
.cmd_write(write_cmd) .cmd_write(write_cmd)
); );
...@@ -379,6 +378,9 @@ assign advance = read_ok || (!read_cmd && write_ok) || (!read_c ...@@ -379,6 +378,9 @@ assign advance = read_ok || (!read_cmd && write_ok) || (!read_c
//todo: if mmv_out < k: might not shift and/or write for multiple read_cmd cycles //todo: if mmv_out < k: might not shift and/or write for multiple read_cmd cycles
assign window_buffer_shift_enable = advance; assign window_buffer_shift_enable = advance;
assign controller_reset = !ap_rst_n || ((cycle == CYCLES_TOTAL-1) && advance);
assign controller_advance = advance;
//assign I/O ports //assign I/O ports
assign window_buffer_in = in0_V_V_TDATA; assign window_buffer_in = in0_V_V_TDATA;
assign out_V_V_TDATA = window_buffer_out; assign out_V_V_TDATA = window_buffer_out;
......
...@@ -30,22 +30,21 @@ import pytest ...@@ -30,22 +30,21 @@ import pytest
import numpy as np import numpy as np
from onnx import TensorProto, helper from onnx import TensorProto, helper
import finn.core.onnx_exec as oxe
from finn.analysis.fpgadataflow.exp_cycles_per_layer import exp_cycles_per_layer
from qonnx.core.datatype import DataType from qonnx.core.datatype import DataType
from qonnx.core.modelwrapper import ModelWrapper from qonnx.core.modelwrapper import ModelWrapper
from qonnx.custom_op.general.im2col import compute_conv_output_dim from qonnx.custom_op.general.im2col import compute_conv_output_dim
from qonnx.custom_op.registry import getCustomOp from qonnx.custom_op.registry import getCustomOp
from qonnx.transformation.general import GiveUniqueNodeNames from qonnx.transformation.general import GiveUniqueNodeNames
from qonnx.util.basic import gen_finn_dt_tensor from qonnx.util.basic import gen_finn_dt_tensor
import finn.core.onnx_exec as oxe
from finn.analysis.fpgadataflow.exp_cycles_per_layer import exp_cycles_per_layer
from finn.transformation.fpgadataflow.prepare_ip import PrepareIP from finn.transformation.fpgadataflow.prepare_ip import PrepareIP
from finn.transformation.fpgadataflow.prepare_rtlsim import PrepareRTLSim from finn.transformation.fpgadataflow.prepare_rtlsim import PrepareRTLSim
from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode
def make_single_im2col_modelwrapper(
k, ifm_ch, ifm_dim, ofm_dim, stride, dilation, idt def make_single_im2col_modelwrapper(k, ifm_ch, ifm_dim, ofm_dim, stride, dilation, idt):
):
k_h, k_w = k k_h, k_w = k
ifm_dim_h, ifm_dim_w = ifm_dim ifm_dim_h, ifm_dim_w = ifm_dim
stride_h, stride_w = stride stride_h, stride_w = stride
...@@ -134,10 +133,10 @@ def make_single_slidingwindow_modelwrapper( ...@@ -134,10 +133,10 @@ def make_single_slidingwindow_modelwrapper(
model.set_tensor_datatype("inp", idt) model.set_tensor_datatype("inp", idt)
model.set_tensor_datatype("outp", odt) model.set_tensor_datatype("outp", odt)
#DEBUG # DEBUG
swg_node = model.get_nodes_by_op_type("ConvolutionInputGenerator_rtl")[0] # swg_node = model.get_nodes_by_op_type("ConvolutionInputGenerator_rtl")[0]
swg_inst = getCustomOp(swg_node) # swg_inst = getCustomOp(swg_node)
swg_inst.set_nodeattr("rtlsim_trace", "/workspace/finn/finn-rtllib/swg/swg_test_trace.vcd") # swg_inst.set_nodeattr("rtlsim_trace", "/home/felixj/WD/finn/finn-rtllib/swg/swg_test_trace.vcd")
return model return model
...@@ -159,39 +158,46 @@ def prepare_inputs(input_tensor): ...@@ -159,39 +158,46 @@ def prepare_inputs(input_tensor):
# ], # ],
# ) # )
# kernel size # kernel size
@pytest.mark.parametrize("k", [[1,1],[2,2],[3,3],[4,5],[1,3]]) @pytest.mark.parametrize("k", [[1, 1], [2, 2], [3, 3], [1, 2], [1, 3]])
# input dimension # input dimension
@pytest.mark.parametrize("ifm_dim", [[8,8],[13,13],[1,12]]) @pytest.mark.parametrize(
"ifm_dim", [[8, 8], [13, 13], [1, 11], [1, 12], [1, 13], [1, 14]]
)
# input channels # input channels
@pytest.mark.parametrize("ifm_ch", [6]) @pytest.mark.parametrize("ifm_ch", [6])
# Stride # Stride
@pytest.mark.parametrize("stride", [[1,1],[2,2],[3,4]]) @pytest.mark.parametrize("stride", [[1, 1], [2, 2], [1, 2]])
# Dilation # Dilation
@pytest.mark.parametrize("dilation", [[1,1],[2,2],[4,3]]) @pytest.mark.parametrize("dilation", [[1, 1], [2, 2], [1, 3]])
# depthwise # depthwise
@pytest.mark.parametrize("dw", [0,1]) @pytest.mark.parametrize("dw", [0, 1])
# input channel parallelism ("SIMD") # input channel parallelism ("SIMD")
@pytest.mark.parametrize("simd", [1,2,3,6]) @pytest.mark.parametrize("simd", [1, 2, 3, 6])
# parallel_window enable (MMV_out = M*K) # parallel_window enable (MMV_out = M*K)
@pytest.mark.parametrize("parallel_window", [0,1]) @pytest.mark.parametrize("parallel_window", [0, 1])
# in/out MMV ("M") # in/out MMV ("M")
@pytest.mark.parametrize("m", [1]) @pytest.mark.parametrize("m", [1])
# Flip dimensions # Flip dimensions
@pytest.mark.parametrize("flip", [False,True]) @pytest.mark.parametrize("flip", [False])
@pytest.mark.slow @pytest.mark.slow
@pytest.mark.vivado @pytest.mark.vivado
def test_fpgadataflow_slidingwindow_rtl( def test_fpgadataflow_slidingwindow_rtl(
idt, k, ifm_dim, ifm_ch, stride, dilation, dw, simd, m, parallel_window, flip idt, k, ifm_dim, ifm_ch, stride, dilation, dw, simd, m, parallel_window, flip
): ):
#ifm_dim = conv_config[0] # ifm_dim = conv_config[0]
#k = conv_config[1] # k = conv_config[1]
#stride = conv_config[2] # stride = conv_config[2]
#dilation= conv_config[3] # dilation= conv_config[3]
if flip: if flip:
if (ifm_dim[0]==ifm_dim[1] and k[0]==k[1] and stride[0]==stride[1] and dilation[0] == dilation[1]): if (
ifm_dim[0] == ifm_dim[1]
and k[0] == k[1]
and stride[0] == stride[1]
and dilation[0] == dilation[1]
):
pytest.skip("Dimension flip would have no effect") pytest.skip("Dimension flip would have no effect")
k = k[::-1] k = k[::-1]
ifm_dim = ifm_dim[::-1] ifm_dim = ifm_dim[::-1]
...@@ -203,21 +209,31 @@ def test_fpgadataflow_slidingwindow_rtl( ...@@ -203,21 +209,31 @@ def test_fpgadataflow_slidingwindow_rtl(
stride_h, stride_w = stride stride_h, stride_w = stride
dilation_h, dilation_w = dilation dilation_h, dilation_w = dilation
kernel_width = (k_w-1)*dilation_w+1 # incl. dilation kernel_width = (k_w - 1) * dilation_w + 1 # incl. dilation
kernel_height = (k_h-1)*dilation_h+1 # incl. dilation kernel_height = (k_h - 1) * dilation_h + 1 # incl. dilation
if simd > ifm_ch: if simd > ifm_ch:
pytest.skip("SIMD cannot be larger than number of input channels") pytest.skip("SIMD cannot be larger than number of input channels")
if ifm_ch % simd != 0: if ifm_ch % simd != 0:
pytest.skip("SIMD must divide number of input channels") pytest.skip("SIMD must divide number of input channels")
if kernel_height > ifm_dim_h or stride_h > ifm_dim_h: if kernel_height > ifm_dim_h or stride_h > ifm_dim_h:
pytest.skip("Illegal convolution configuration: kernel or stride > FM dimension") pytest.skip(
"Illegal convolution configuration: kernel or stride > FM dimension"
)
if kernel_width > ifm_dim_w or stride_w > ifm_dim_w: if kernel_width > ifm_dim_w or stride_w > ifm_dim_w:
pytest.skip("Illegal convolution configuration: kernel or stride > FM dimension") pytest.skip(
if (k_h==1 and (stride_h!=1 or dilation_h!=1)) or (k_w==1 and (stride_w!=1 or dilation_w!=1)): "Illegal convolution configuration: kernel or stride > FM dimension"
pytest.skip("Illegal convolution configuration: stride or dilation defined for unitary kernel dim") )
if k_h==1 and k_w==1 and simd != ifm_ch: if (k_h == 1 and (stride_h != 1 or dilation_h != 1)) or (
k_w == 1 and (stride_w != 1 or dilation_w != 1)
):
pytest.skip(
"Illegal convolution configuration: stride or dilation defined for unitary kernel dim"
)
if k_h == 1 and k_w == 1 and simd != ifm_ch:
pytest.skip("1x1 Kernel only supported in parallel mode (SIMD=C)") pytest.skip("1x1 Kernel only supported in parallel mode (SIMD=C)")
if parallel_window and simd != ifm_ch:
pytest.skip("Parallel window requires SIMD=C")
ofm_dim_h = compute_conv_output_dim(ifm_dim_h, k_h, stride_h, 0, dilation_h) ofm_dim_h = compute_conv_output_dim(ifm_dim_h, k_h, stride_h, 0, dilation_h)
ofm_dim_w = compute_conv_output_dim(ifm_dim_w, k_w, stride_w, 0, dilation_w) ofm_dim_w = compute_conv_output_dim(ifm_dim_w, k_w, stride_w, 0, dilation_w)
...@@ -258,7 +274,7 @@ def test_fpgadataflow_slidingwindow_rtl( ...@@ -258,7 +274,7 @@ def test_fpgadataflow_slidingwindow_rtl(
) )
y_expected = oxe.execute_onnx(golden, input_dict)["outp"] y_expected = oxe.execute_onnx(golden, input_dict)["outp"]
#DEBUG # DEBUG
print("-------expected:") print("-------expected:")
print(y_expected) print(y_expected)
print("--------produced:") print("--------produced:")
...@@ -267,7 +283,7 @@ def test_fpgadataflow_slidingwindow_rtl( ...@@ -267,7 +283,7 @@ def test_fpgadataflow_slidingwindow_rtl(
node = model.get_nodes_by_op_type("ConvolutionInputGenerator_rtl")[0] node = model.get_nodes_by_op_type("ConvolutionInputGenerator_rtl")[0]
inst = getCustomOp(node) inst = getCustomOp(node)
cycles_rtlsim = inst.get_nodeattr("cycles_rtlsim") cycles_rtlsim = inst.get_nodeattr("cycles_rtlsim")
print("RTLSIM cycles: %d"%cycles_rtlsim) print("RTLSIM cycles: %d" % cycles_rtlsim)
if dw == 0: if dw == 0:
assert (y_produced == y_expected).all() assert (y_produced == y_expected).all()
...@@ -279,6 +295,7 @@ def test_fpgadataflow_slidingwindow_rtl( ...@@ -279,6 +295,7 @@ def test_fpgadataflow_slidingwindow_rtl(
y_expected = y_expected.reshape(1, ofm_dim_h, ofm_dim_w, ifm_ch * k_h * k_w) y_expected = y_expected.reshape(1, ofm_dim_h, ofm_dim_w, ifm_ch * k_h * k_w)
assert (y_produced == y_expected).all() assert (y_produced == y_expected).all()
# exp_cycles_dict = model.analysis(exp_cycles_per_layer) # exp_cycles_dict = model.analysis(exp_cycles_per_layer)
# exp_cycles = exp_cycles_dict[node.name] # exp_cycles = exp_cycles_dict[node.name]
# assert np.isclose(exp_cycles, cycles_rtlsim, atol=10) # assert np.isclose(exp_cycles, cycles_rtlsim, atol=10)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment