diff --git a/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py b/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py index b3a54ffbeebb79bf737da4ad60babc21250da8c7..1f3d40e929e29d16790a491bbfd7a4a5033f866f 100644 --- a/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py +++ b/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py @@ -64,30 +64,28 @@ class InferConvInpGen(Transformation): warnings.warn("Input is not int. Can't infer ConvInpGen") continue i2c_inst = getCustomOp(n) - stride = i2c_inst.get_nodeattr("stride") - k_attr = i2c_inst.get_nodeattr("kernel_size") - k_h = k_attr[0] - k_w = k_attr[1] + stride_h, stride_w = i2c_inst.get_nodeattr("stride") + k_h, k_w = i2c_inst.get_nodeattr("kernel_size") pad_attr = i2c_inst.get_nodeattr("pad_amount") pad_h = pad_attr[0] + pad_attr[2] pad_w = pad_attr[1] + pad_attr[3] + dilation_h, dilation_w = i2c_inst.get_nodeattr("dilations") # temporary checks until non-square conv support is finalized - assert pad_h == pad_w, "Non-square images not yet supported." - assert k_h == k_w, "Non-square kernels not yet supported." - k = k_h - pad = pad_attr[0] pad_val = i2c_inst.get_nodeattr("pad_value") depthwise = i2c_inst.get_nodeattr("depthwise") ifm_ch = i2c_in_shape[-1] - ifm_dim = i2c_in_shape[1] - ofm_dim = i2c_out_shape[1] + ifm_dim_h = i2c_in_shape[1] + ifm_dim_w = i2c_in_shape[2] + ofm_dim_h = i2c_out_shape[1] + ofm_dim_w = i2c_out_shape[2] # default params for ConvolutionInputGenerator ConvInpGen_node_idx = node_ind ConvInpGen_input = i2c_input - ConvInpGen_idim = ifm_dim + ConvInpGen_idim_h = ifm_dim_h + ConvInpGen_idim_w = ifm_dim_w - if pad > 0: + if pad_h > 0 or pad_w > 0: # if padding enabled, ensure pad_val supported by DataType # assert dt.allowed(pad_val),"""FMPadding_Batch DataType # must support pad_val""" @@ -95,12 +93,13 @@ class InferConvInpGen(Transformation): pad_val == 0 ), "FMPadding_Batch doesn't currently support pad_val!= 0" - odim_padding = ifm_dim + 2 * pad + odim_padding_h = ifm_dim_h + pad_h + odim_padding_w = ifm_dim_w + pad_w padding_out = helper.make_tensor_value_info( model.make_new_valueinfo_name(), TensorProto.FLOAT, - (1, odim_padding, odim_padding, ifm_ch), + (1, odim_padding_h, odim_padding_w, ifm_ch), ) graph.value_info.append(padding_out) padding_out = padding_out.name @@ -108,7 +107,8 @@ class InferConvInpGen(Transformation): ConvInpGen_node_idx += 1 ConvInpGen_input = padding_out - ConvInpGen_idim = odim_padding + ConvInpGen_idim_h = odim_padding_h + ConvInpGen_idim_w = odim_padding_w padding_node = helper.make_node( "FMPadding_Batch", @@ -116,15 +116,31 @@ class InferConvInpGen(Transformation): [padding_out], domain="finn.custom_op.fpgadataflow", backend="fpgadataflow", - ImgDim=ifm_dim, - Padding=2 * pad, + ImgDim=[ifm_dim_h, ifm_dim_w], + Padding=pad_attr, NumChannels=ifm_ch, inputDataType=dt.name, SIMD=ifm_ch, ) graph.node.insert(node_ind, padding_node) - if stride > 1 and k == 1: + # Ensure that only supported HLS nodes are inserted + is_square_image = ConvInpGen_idim_h == ConvInpGen_idim_w + is_square_kernel = k_h == k_w + is_kernel_pointwise = k_h == 1 and k_w == 1 + is_equal_stride = stride_h == stride_w + is_1d_convolution = (k_h == 1 and k_w > 1 and ifm_dim_h == 1) or ( + k_h > 1 and k_w == 1 and ifm_dim_w == 1 + ) + + if (stride_h > 1 or stride_w > 1) and is_kernel_pointwise: + assert ( + is_square_image + ), "DownSampler currently only supports square input images." + assert is_equal_stride, """DownSampler currently only supports equal stride value + along different axes.""" + ConvInpGen_idim = ConvInpGen_idim_h + stride = stride_h # create DownSampler node ConvInpGen_node = helper.make_node( "DownSampler", @@ -141,22 +157,58 @@ class InferConvInpGen(Transformation): graph.node.insert(ConvInpGen_node_idx, ConvInpGen_node) else: # create equivalent ConvolutionInputGenerator node - ConvInpGen_node = helper.make_node( - "ConvolutionInputGenerator", - [ConvInpGen_input], - [i2c_output], - domain="finn.custom_op.fpgadataflow", - backend="fpgadataflow", - ConvKernelDim=k, - IFMChannels=ifm_ch, - IFMDim=ConvInpGen_idim, - OFMDim=ofm_dim, - SIMD=ifm_ch, - Stride=stride, - inputDataType=dt.name, - outputDataType=dt.name, - depthwise=depthwise, - ) + if ( + is_square_image and is_square_kernel + ): # square images and square kernels + assert is_equal_stride, """Non-equal strides along different axes is not supported + for (non-)square convolutions""" + assert ( + dilation_h == 1 and dilation_w == 1 + ), """Dilation value != 1 is not supported + for square convolutions""" + ConvInpGen_node = helper.make_node( + "ConvolutionInputGenerator", + [ConvInpGen_input], + [i2c_output], + domain="finn.custom_op.fpgadataflow", + backend="fpgadataflow", + ConvKernelDim=[k_h, k_w], + IFMChannels=ifm_ch, + IFMDim=[ConvInpGen_idim_h, ConvInpGen_idim_w], + OFMDim=[ofm_dim_h, ofm_dim_w], + SIMD=ifm_ch, + Stride=[stride_h, stride_w], + Dilation=[dilation_h, dilation_w], + inputDataType=dt.name, + outputDataType=dt.name, + depthwise=depthwise, + ) + else: # non-square images and/or kernels + assert ( + is_1d_convolution + ), "ConvultionInputGenerator1D works only for 1D convolutions" + if dilation_h > 1 or dilation_w > 1: + assert ( + stride_h == 1 and stride_w == 1 + ), """Stride value of greater than 1 is not supported for convolutions + with dilation value greater than 1""" + ConvInpGen_node = helper.make_node( + "ConvolutionInputGenerator1D", + [ConvInpGen_input], + [i2c_output], + domain="finn.custom_op.fpgadataflow", + backend="fpgadataflow", + ConvKernelDim=[k_h, k_w], + IFMChannels=ifm_ch, + IFMDim=[ConvInpGen_idim_h, ConvInpGen_idim_w], + OFMDim=[ofm_dim_h, ofm_dim_w], + SIMD=ifm_ch, + Stride=[stride_h, stride_w], + Dilation=[dilation_h, dilation_w], + inputDataType=dt.name, + outputDataType=dt.name, + depthwise=depthwise, + ) graph.node.insert(ConvInpGen_node_idx, ConvInpGen_node) # remove old nodes graph.node.remove(n) @@ -684,7 +736,7 @@ class InferVVAU(Transformation): ): sparsity = model.get_tensor_sparsity(n.input[1]) try: - k = sparsity["dw"]["kernel_shape"] + k_h, k_w = sparsity["dw"]["kernel_shape"] except KeyError: raise Exception( """Sparsity doesn't indicate that MatMul @@ -702,25 +754,25 @@ class InferVVAU(Transformation): mm_output = n.output[0] W = model.get_initializer(mm_weight) # infer dense weight tensor from sparse weight matrix - # kernel size k which was extracted above and the value of + # kernel size (k_h, k_w) which was extracted above and the value of # the channels is used. - # the weight matrix has a shape of (k * k * Channels, Channels) + # the weight matrix has a shape of (k_h * k_w * Channels, Channels) # we need to reverse the creation of the sparse weight matrix - # to achieve a weight tensor of shape (Channels, 1, k, k) + # to achieve a weight tensor of shape (Channels, 1, k_h, k_w) channels = int(W.shape[1]) - # transpose to achieve a shape of (k * k * Channels, Channels) + # transpose to achieve a shape of (k_h * k_w * Channels, Channels) W = W.T - # reshape to (Channels, k, k, Channels) to transpose afterwards - # to (Channels, Channels, k, k) - W = W.reshape(channels, k, k, channels) + # reshape to (Channels, k_h, k_w, Channels) to transpose afterwards + # to (Channels, Channels, k_h, k_w) + W = W.reshape(channels, k_h, k_w, channels) W = W.transpose(0, 3, 1, 2) # now we can extract the values using a for loop over the channels # and fill a zero numpy array in the correct shape - w_tensor = np.zeros((channels, 1, k, k)) + w_tensor = np.zeros((channels, 1, k_h, k_w)) for ch in range(channels): w_tensor[ch][0] = W[ch][ch] model.set_initializer(mm_weight, w_tensor) - model.set_tensor_shape(mm_weight, (channels, 1, k, k)) + model.set_tensor_shape(mm_weight, (channels, 1, k_h, k_w)) # create node with pe=channels as default pe = channels assert ( @@ -762,9 +814,9 @@ class InferVVAU(Transformation): backend="fpgadataflow", resType="lut", PE=pe, - Dim=mm_in_shape[1], + Dim=[mm_in_shape[1], mm_in_shape[2]], Channels=channels, - Kernel=k, + Kernel=[k_h, k_w], inputDataType=idt.name, weightDataType=wdt.name, outputDataType=odt.name, @@ -790,9 +842,9 @@ class InferVVAU(Transformation): backend="fpgadataflow", resType="lut", PE=pe, - Dim=mm_in_shape[1], + Dim=[mm_in_shape[1], mm_in_shape[2]], Channels=channels, - Kernel=k, + Kernel=[k_h, k_w], inputDataType=idt.name, weightDataType=wdt.name, outputDataType=odt.name, @@ -1345,7 +1397,11 @@ class InferGlobalAccPoolLayer(Transformation): ) model.graph.value_info.append(mul_value) model.set_initializer(mul_value.name, np.array(1 / (vecs[1] * vecs[2]))) - new_mul = helper.make_node("Mul", [pool_out, mul_value.name], [result],) + new_mul = helper.make_node( + "Mul", + [pool_out, mul_value.name], + [result], + ) graph.node.insert(insert_point, new_pool) graph.node.insert(insert_point + 1, new_mul) node_ind += 1 diff --git a/tests/fpgadataflow/test_convert_to_hls_1d_conv_layer.py b/tests/fpgadataflow/test_convert_to_hls_1d_conv_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..dfdb21fa72cbbeeb503f7ecc447b659ef7934fb9 --- /dev/null +++ b/tests/fpgadataflow/test_convert_to_hls_1d_conv_layer.py @@ -0,0 +1,189 @@ +# Copyright (c) 2020, Xilinx +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of FINN nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from onnx import TensorProto, helper +import numpy as np +import pytest + +from finn.core.datatype import DataType +from finn.transformation.infer_shapes import InferShapes +from finn.transformation.infer_datatypes import InferDataTypes +from finn.transformation.general import GiveUniqueNodeNames +from finn.transformation.lower_convs_to_matmul import LowerConvsToMatMul + +from finn.transformation.fpgadataflow.prepare_ip import PrepareIP +from finn.transformation.fpgadataflow.prepare_rtlsim import PrepareRTLSim +from finn.transformation.fpgadataflow.hlssynth_ip import HLSSynthIP +import finn.core.onnx_exec as oxe +from finn.core.modelwrapper import ModelWrapper +from finn.util.basic import gen_finn_dt_tensor +import finn.transformation.fpgadataflow.convert_to_hls_layers as to_hls + +from finn.transformation.fpgadataflow.prepare_cppsim import PrepareCppSim +from finn.transformation.fpgadataflow.compile_cppsim import CompileCppSim +from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode +from finn.custom_op.general.im2col import compute_conv_output_dim +from finn.custom_op.registry import getCustomOp +from finn.analysis.fpgadataflow.exp_cycles_per_layer import exp_cycles_per_layer + + +# conv_config: +# [pad_h_begin, pad_w_begin, pad_h_end, pad_w_end] +# [kernel_size_h, kernel_size_w] +# [stride_h, stride_w] +# [dilation_h, dilation_w] +@pytest.mark.parametrize( + "conv_config", + [ + [[0, 0, 0, 0], [4, 1], [1, 1], [1, 1]], + [[1, 0, 1, 0], [4, 1], [1, 1], [1, 1]], + [[1, 0, 1, 0], [4, 1], [2, 1], [1, 1]], + # [[1, 0, 1, 0], [4, 1], [1, 1], [2, 1]] + ], +) +@pytest.mark.parametrize("depthwise", [False, True]) +@pytest.mark.parametrize("exec_mode", ["cppsim", "rtlsim"]) +@pytest.mark.slow +@pytest.mark.vivado +def test_convert_to_hls_1d_conv_layer(conv_config, depthwise, exec_mode): + pad, kernel_size, stride, dilation = conv_config + np.random.seed(0) + idt = DataType.UINT4 + + in_feature_dim_h, in_feature_dim_w = [10, 1] + in_chn = 16 + + k_h, k_w = kernel_size + stride_h, stride_w = stride + dilation_h, dilation_w = dilation + pad_h = pad[0] + pad[2] + pad_w = pad[1] + pad[3] + + if depthwise is True: + group = out_chn = in_chn + conv_param_shape = [out_chn, 1, k_h, k_w] + else: + group = 1 + out_chn = 20 + conv_param_shape = [out_chn, in_chn, k_h, k_w] + + out_feature_dim_h = compute_conv_output_dim( + in_feature_dim_h, k_h, stride_h, pad_h, dilation_h + ) + out_feature_dim_w = compute_conv_output_dim( + in_feature_dim_w, k_w, stride_w, pad_w, dilation_w + ) + + input_shape = [1, in_chn, in_feature_dim_h, in_feature_dim_w] + output_shape = [1, out_chn, out_feature_dim_h, out_feature_dim_w] + + conv_weight_dt = DataType.UINT4 + + conv_config = {} + conv_config["dilations"] = [dilation_h, dilation_w] + conv_config["group"] = group + conv_config["kernel_shape"] = [k_h, k_w] + conv_config["pads"] = pad + conv_config["strides"] = [stride_h, stride_w] + + top_in = helper.make_tensor_value_info("top_in", TensorProto.FLOAT, input_shape) + top_out = helper.make_tensor_value_info("top_out", TensorProto.FLOAT, output_shape) + value_info = [ + helper.make_tensor_value_info("p1", TensorProto.FLOAT, conv_param_shape) + ] + + modelproto = helper.make_model( + helper.make_graph( + name="conv_test", + inputs=[top_in], + outputs=[top_out], + value_info=value_info, + nodes=[ + helper.make_node("Conv", ["top_in", "p1"], ["top_out"], **conv_config) + ], + ) + ) + + model = ModelWrapper(modelproto) + model.set_tensor_datatype("top_in", idt) + model.set_tensor_datatype("top_out", idt) + model.set_tensor_datatype("p1", conv_weight_dt) + model.set_initializer("p1", gen_finn_dt_tensor(conv_weight_dt, conv_param_shape)) + + model = model.transform(InferShapes()) + model = model.transform(InferDataTypes()) + + new_model = model.transform(LowerConvsToMatMul()) + new_model = new_model.transform(to_hls.InferConvInpGen()) + if depthwise is True: + new_model = new_model.transform(to_hls.InferVVAU()) + else: + new_model = new_model.transform(to_hls.InferQuantizedStreamingFCLayer()) + fc_node = new_model.get_nodes_by_op_type("StreamingFCLayer_Batch")[0] + fc_inst = getCustomOp(fc_node) + mw = fc_inst.get_nodeattr("MW") + mh = fc_inst.get_nodeattr("MH") + pe_cands = list(filter(lambda x: mh % x == 0, range(2, mh + 1))) + simd_cands = list(filter(lambda x: mw % x == 0, range(2, mw + 1))) + fc_inst.set_nodeattr("PE", pe_cands[0]) + fc_inst.set_nodeattr("SIMD", simd_cands[0]) + + new_model = new_model.transform(GiveUniqueNodeNames()) + new_model = new_model.transform(InferShapes()) + new_model = new_model.transform(InferDataTypes()) + + if exec_mode == "cppsim": + new_model = new_model.transform(PrepareCppSim()) + new_model = new_model.transform(CompileCppSim()) + new_model = new_model.transform(SetExecMode("cppsim")) + elif exec_mode == "rtlsim": + new_model = new_model.transform(SetExecMode("rtlsim")) + new_model = new_model.transform(GiveUniqueNodeNames()) + new_model = new_model.transform(PrepareIP("xc7z020clg400-1", 5)) + new_model = new_model.transform(HLSSynthIP()) + new_model = new_model.transform(PrepareRTLSim()) + else: + raise Exception("Unknown exec_mode") + + x = gen_finn_dt_tensor(idt, input_shape) + inp_dict = {model.graph.input[0].name: x} + assert oxe.compare_execution(model, new_model, inp_dict) + + if pad_h == 1 and pad_w == 1: + padding_node = new_model.get_nodes_by_op_type("FMPadding_Batch")[0] + padding_inst = getCustomOp(padding_node) + assert padding_inst.get_nodeattr("SIMD") == in_chn + + if depthwise is True and exec_mode == "rtlsim": + node = new_model.get_nodes_by_op_type("Vector_Vector_Activate_Batch")[0] + inst = getCustomOp(node) + cycles_rtlsim = inst.get_nodeattr("cycles_rtlsim") + exp_cycles_dict = new_model.analysis(exp_cycles_per_layer) + exp_cycles = exp_cycles_dict[node.name] + assert np.isclose(exp_cycles, cycles_rtlsim, atol=11) + assert exp_cycles != 0 diff --git a/tests/fpgadataflow/test_depthwise_convolution.py b/tests/fpgadataflow/test_depthwise_convolution.py index c406d78158c52226fea881c48bc178139653fea5..3efeacb6e6875c6defa799eb7154e02ce880e16a 100644 --- a/tests/fpgadataflow/test_depthwise_convolution.py +++ b/tests/fpgadataflow/test_depthwise_convolution.py @@ -98,7 +98,7 @@ def set_up_reference_model(act, idt, wdt, k, ifm_dim, ifm_ch, stride, padding): inputs=["inp"], outputs=["im2col_out"], kernel_size=[k, k], - stride=stride, + stride=[stride, stride], pad_amount=[padding, padding, padding, padding], input_shape="(1, {}, {}, {})".format(ifm_dim, ifm_dim, ifm_ch), depthwise=1, @@ -142,7 +142,7 @@ def set_up_reference_model(act, idt, wdt, k, ifm_dim, ifm_ch, stride, padding): W_matrix = W_matrix.reshape(ofm_ch, ifm_ch * k * k) model.set_initializer("W_sparse", W_matrix.T) - sparsity = {"dw": {"kernel_shape": k}} + sparsity = {"dw": {"kernel_shape": [k, k]}} model.set_tensor_sparsity("W_sparse", sparsity) if act is not None: