diff --git a/src/finn/custom_op/fpgadataflow/__init__.py b/src/finn/custom_op/fpgadataflow/__init__.py index 068950b89ae543f5a37c28d83d87ecfa605eaab4..089d6385b414cc252f3c505e06fab438a3b96e92 100644 --- a/src/finn/custom_op/fpgadataflow/__init__.py +++ b/src/finn/custom_op/fpgadataflow/__init__.py @@ -29,6 +29,9 @@ from finn.custom_op.fpgadataflow.convolutioninputgenerator import ( ConvolutionInputGenerator, ) +from finn.custom_op.fpgadataflow.convolutioninputgenerator1d import ( + ConvolutionInputGenerator1D, +) from finn.custom_op.fpgadataflow.downsampler import DownSampler from finn.custom_op.fpgadataflow.streamingfclayer_batch import StreamingFCLayer_Batch from finn.custom_op.fpgadataflow.streamingmaxpool_batch import StreamingMaxPool_Batch @@ -58,6 +61,7 @@ custom_op["DownSampler"] = DownSampler custom_op["StreamingMaxPool_Batch"] = StreamingMaxPool_Batch custom_op["StreamingFCLayer_Batch"] = StreamingFCLayer_Batch custom_op["ConvolutionInputGenerator"] = ConvolutionInputGenerator +custom_op["ConvolutionInputGenerator1D"] = ConvolutionInputGenerator1D custom_op["TLastMarker"] = TLastMarker custom_op["StreamingDataWidthConverter_Batch"] = StreamingDataWidthConverter_Batch custom_op["StreamingFIFO"] = StreamingFIFO diff --git a/src/finn/custom_op/fpgadataflow/convolutioninputgenerator1d.py b/src/finn/custom_op/fpgadataflow/convolutioninputgenerator1d.py new file mode 100644 index 0000000000000000000000000000000000000000..a9d5c176e3b9175f84a0c62dc7f7d7b702502b44 --- /dev/null +++ b/src/finn/custom_op/fpgadataflow/convolutioninputgenerator1d.py @@ -0,0 +1,610 @@ +# Copyright (c) 2020, Xilinx +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of FINN nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import os + +import math +import numpy as np + +from finn.core.datatype import DataType +from finn.custom_op.fpgadataflow.hlscustomop import HLSCustomOp +from finn.custom_op.general.im2col import compute_conv_output_dim +from onnx import TensorProto, helper +from finn.util.data_packing import npy_to_rtlsim_input, rtlsim_output_to_npy + +# This operation should only be used for 1D convolutions. 
Either the +# IFMDim_H or IFMDim_W should be '1', which represents the so-called +# dummy-dimension + +# ONNX i/o tensor shape assumptions for ConvolutionInputGenerator1D: +# input 0 is the input tensor, shape NHWC = (1, IFMDim_H, IFMDim_W, IFMChannels) +# output 0 is the output tensor, shape NHWC: +# = (1, OFMDim_H, OFMDim_W, (ConvKernelDim_H*ConvKernelDim_W)*IFMChannels) + +# note: the actual data layout produced by the hlslib kernels is different +# for depthwise and non-depthwise ops. +# * non-depthwise SWG: (1, OFMDim_H, OFMDim_W, K_H, K_W, IFMChannels/SIMD, SIMD) +# * depthwise SWG: (1, OFMDim_H, OFMDim_W, IFMChannels/SIMD, K_H, K_W, SIMD) +# see test_fpgadataflow_slidingwindow.py for an example of how to transform +# between the two layouts + + +class ConvolutionInputGenerator1D(HLSCustomOp): + """Class that corresponds to one of the 1D finn-hlslib ConvolutionInputGenerator + (sliding window) function variants. Depending on the combination of + attributes (e.g. depthwise or not, whether dilation is 0) a different + variant will be picked for the actual HLS implementation.""" + + def __init__(self, onnx_node): + super().__init__(onnx_node) + + def get_nodeattr_types(self): + my_attrs = { + "ConvKernelDim": ("ints", True, []), # [H, W] = [Y, X] + "IFMChannels": ("i", True, 0), + "IFMDim": ("ints", True, []), # [H, W] = [Y, X] + "OFMDim": ("ints", True, []), # [H, W] = [Y, X] + "SIMD": ("i", True, 0), + "Stride": ("ints", True, []), # [H, W] = [Y, X] + "Dilation": ("ints", True, []), # [H, W] = [Y, X] + # FINN DataTypes for inputs, weights, outputs + "inputDataType": ("s", True, ""), + "outputDataType": ("s", True, ""), + "depthwise": ("i", False, 0, {0, 1}), + # FPGA resource type for ConvolutionInputGenerator input buffer + # auto -- let Vivado HLS decide + # block -- use BRAM + # distributed -- use LUTRAM + # ultra -- use URAM + "ram_style": ( + "s", + False, + "distributed", + {"auto", "block", "distributed", "ultra"}, + ), + } + my_attrs.update(super().get_nodeattr_types()) + return my_attrs + + def get_normal_input_shape(self): + ifm_dim_h, ifm_dim_w = self.get_nodeattr("IFMDim") + ifm_ch = self.get_nodeattr("IFMChannels") + ishape = (1, ifm_dim_h, ifm_dim_w, ifm_ch) + return ishape + + def get_folded_input_shape(self): + ifm_dim_h, ifm_dim_w = self.get_nodeattr("IFMDim") + ifm_ch = self.get_nodeattr("IFMChannels") + simd = self.get_nodeattr("SIMD") + assert ifm_ch % simd == 0, "SIMD must divide IFMChannels" + wf = int(ifm_ch / simd) + folded_ishape = (1, ifm_dim_h, ifm_dim_w, wf, simd) + return folded_ishape + + def get_normal_output_shape(self): + k_h, k_w = self.get_nodeattr("ConvKernelDim") + ifm_dim_h, ifm_dim_w = self.get_nodeattr("IFMDim") + ifm_ch = self.get_nodeattr("IFMChannels") + stride_h, stride_w = self.get_nodeattr("Stride") + dilation_h, dilation_w = self.get_nodeattr("Dilation") + pad = 0 + ofm_dim_h = compute_conv_output_dim(ifm_dim_h, k_h, stride_h, pad, dilation_h) + ofm_dim_w = compute_conv_output_dim(ifm_dim_w, k_w, stride_w, pad, dilation_w) + oshape = (1, ofm_dim_h, ofm_dim_w, k_h * k_w * ifm_ch) + return oshape + + def get_folded_output_shape(self): + k_h, k_w = self.get_nodeattr("ConvKernelDim") + ifm_dim_h, ifm_dim_w = self.get_nodeattr("IFMDim") + ifm_ch = self.get_nodeattr("IFMChannels") + stride_h, stride_w = self.get_nodeattr("Stride") + dilation_h, dilation_w = self.get_nodeattr("Dilation") + simd = self.get_nodeattr("SIMD") + pad = 0 + ofm_dim_h = compute_conv_output_dim(ifm_dim_h, k_h, stride_h, pad, dilation_h) + ofm_dim_w = 
compute_conv_output_dim(ifm_dim_w, k_w, stride_w, pad, dilation_w)
+        assert ifm_ch % simd == 0, "SIMD must divide IFMChannels"
+        wf = int((k_h * k_w * ifm_ch) // simd)
+        folded_oshape = (1, ofm_dim_h, ofm_dim_w, wf, simd)
+        return folded_oshape
+
+    def make_shape_compatible_op(self, model):
+        exp_ishape = self.get_normal_input_shape()
+        oshape = self.get_normal_output_shape()
+        ishape = tuple(model.get_tensor_shape(self.onnx_node.input[0]))
+        assert ishape == exp_ishape, "Unexpected input shape for ConvInpGen."
+        # implement tensor with correct shape
+        values = np.random.randn(*oshape).astype(np.float32)
+        return helper.make_node(
+            "Constant",
+            inputs=[],
+            outputs=[self.onnx_node.output[0]],
+            value=helper.make_tensor(
+                name="const_tensor",
+                data_type=TensorProto.FLOAT,
+                dims=values.shape,
+                vals=values.flatten().astype(float),
+            ),
+        )
+
+    def infer_node_datatype(self, model):
+        node = self.onnx_node
+        # data type stays the same
+        dtype = model.get_tensor_datatype(node.input[0])
+        model.set_tensor_datatype(node.output[0], dtype)
+
+    def verify_node(self):
+        pass
+
+    def get_input_datatype(self):
+        """Returns FINN DataType of input."""
+        return DataType[self.get_nodeattr("inputDataType")]
+
+    def get_output_datatype(self):
+        """Returns FINN DataType of output."""
+        return DataType[self.get_nodeattr("outputDataType")]
+
+    def get_instream_width(self):
+        """Returns stream width, input and output stream width are equal for
+        the sliding window function"""
+        ibits = self.get_input_datatype().bitwidth()
+        simd = self.get_nodeattr("SIMD")
+        ifm_ch = self.get_nodeattr("IFMChannels")
+        assert ifm_ch % simd == 0, "SIMD must divide IFMChannels"
+        in_width = simd * ibits
+        return in_width
+
+    def get_outstream_width(self):
+        """Returns stream width, input and output stream width are equal for
+        the sliding window function, so the function to determine the input
+        stream width can be reused."""
+        return self.get_instream_width()
+
+    def get_number_output_values(self):
+        folded_oshape = self.get_folded_output_shape()
+        num_output_elems = np.prod(folded_oshape[:-1])
+        return num_output_elems
+
+    def get_exp_cycles(self):
+        simd = self.get_nodeattr("SIMD")
+        ifm_ch = self.get_nodeattr("IFMChannels")
+        k = self.get_nodeattr("ConvKernelDim")
+        ifm_dim = self.get_nodeattr("IFMDim")
+        ofm_dim = self.get_nodeattr("OFMDim")
+        stride = self.get_nodeattr("Stride")
+        dilation = self.get_nodeattr("Dilation")
+
+        # see defines() for an explanation
+        if ifm_dim[1] == 1:
+            ifm_dim = ifm_dim[::-1]
+            ofm_dim = ofm_dim[::-1]
+            k = k[::-1]
+            stride = stride[::-1]
+            dilation = dilation[::-1]
+
+        ifm_dim_h, ifm_dim_w = ifm_dim
+        ofm_dim_h, ofm_dim_w = ofm_dim
+        k_h, k_w = k
+        stride_h, stride_w = stride
+        dilation_h, dilation_w = dilation
+
+        # since mmv != 1 is not supported yet, we set mmv for now to 1
+        mmv = 1
+        # see https://github.com/Xilinx/finn-hlslib/blob/master/slidingwindow.h
+        cycles_write_block = (ofm_dim_w * k_w * k_h * (ifm_ch / simd)) / mmv
+        cycles_read_block = stride_w * ifm_dim_w * (ifm_ch / simd)
+        max_cycles = max(cycles_write_block, cycles_read_block)
+        exp_cycles = (
+            ifm_dim_w * k_h * dilation_h * (ifm_ch / simd) + ofm_dim_h * max_cycles
+        )
+
+        return int(exp_cycles)
+
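+    # Rough analytical resource models (BRAM / LUT / URAM) for the
+    # sliding-window input buffer; the ram_style attribute determines which
+    # of the three estimates is relevant for a given configuration.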
self.get_nodeattr("ram_style") + if ram_style == "block" or ram_style == "auto": + ram_depth = ifm_dim * ifm_ch / simd + if ram_depth <= 512: + ram_width = 36 + elif ram_depth <= 1024: + ram_width = 18 + elif ram_depth <= 2048: + ram_width = 9 + elif ram_depth <= 4096: + ram_width = 4 + elif ram_depth <= 8192: + ram_width = 2 + else: + ram_width = 1 + return int( + (k + stride) + * ( + math.ceil(simd * self.get_input_datatype().bitwidth() / ram_width) + * math.ceil(ifm_dim * ifm_ch / simd / ram_depth) + ) + ) + else: + return 0 + + def lut_estimation(self): + # NOTE: not tested for correctness + simd = self.get_nodeattr("SIMD") + ifm_ch = self.get_nodeattr("IFMChannels") + ifm_dim = np.prod(self.get_nodeattr("IFMDim")) + k = np.prod(self.get_nodeattr("ConvKernelDim")) + stride = np.prod(self.get_nodeattr("Stride")) + ram_style = self.get_nodeattr("ram_style") + if ram_style == "distributed": + ram_luts = int( + (k + stride) + * ( + simd + * self.get_input_datatype().bitwidth() + * math.ceil(ifm_dim * ifm_ch / simd / 64) + ) + ) + else: + ram_luts = 0 + return 300 + ram_luts + + def uram_estimation(self): + # NOTE: not tested for correctness + simd = self.get_nodeattr("SIMD") + ifm_ch = self.get_nodeattr("IFMChannels") + ifm_dim = np.prod(self.get_nodeattr("IFMDim")) + k = np.prod(self.get_nodeattr("ConvKernelDim")) + stride = np.prod(self.get_nodeattr("Stride")) + ram_style = self.get_nodeattr("ram_style") + if ram_style == "ultra": + return int( + (k + stride) + * ( + math.ceil(simd * self.get_input_datatype().bitwidth() / 64) + * math.ceil(ifm_dim * ifm_ch / simd / 4096) + ) + ) + else: + return 0 + + def execute_node(self, context, graph): + mode = self.get_nodeattr("exec_mode") + node = self.onnx_node + exp_ishape = self.get_normal_input_shape() + exp_oshape = self.get_normal_output_shape() + folded_ishape = self.get_folded_input_shape() + folded_oshape = self.get_folded_output_shape() + + # TODO ensure codegen dir exists + if mode == "cppsim": + code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim") + elif mode == "rtlsim": + code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen") + else: + raise Exception( + """Invalid value for attribute exec_mode! 
+    def execute_node(self, context, graph):
+        mode = self.get_nodeattr("exec_mode")
+        node = self.onnx_node
+        exp_ishape = self.get_normal_input_shape()
+        exp_oshape = self.get_normal_output_shape()
+        folded_ishape = self.get_folded_input_shape()
+        folded_oshape = self.get_folded_output_shape()
+
+        # TODO ensure codegen dir exists
+        if mode == "cppsim":
+            code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim")
+        elif mode == "rtlsim":
+            code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen")
+        else:
+            raise Exception(
+                """Invalid value for attribute exec_mode! Is currently set to: {}
+                has to be set to one of the following values ("cppsim", "rtlsim")""".format(
+                    mode
+                )
+            )
+
+        inp = context[node.input[0]]
+        assert str(inp.dtype) == "float32", "Input datatype is not float32"
+        assert (
+            inp.shape == exp_ishape
+        ), """Input shape doesn't
+        match expected shape (1, ifm_dim_h, ifm_dim_w, ifm_ch)."""
+        if self.get_input_datatype() == DataType.BIPOLAR:
+            # store bipolar activations as binary
+            inp = (inp + 1) / 2
+            export_idt = DataType.BINARY
+        else:
+            export_idt = self.get_input_datatype()
+        # reshape input into folded form
+        inp = inp.reshape(folded_ishape)
+        # make copy before saving array
+        reshaped_input = inp.copy()
+        np.save(os.path.join(code_gen_dir, "input_0.npy"), reshaped_input)
+
+        if mode == "cppsim":
+            # execute the precompiled model
+            super().exec_precompiled_singlenode_model()
+            # load output npy file
+            super().npy_to_dynamic_output(context)
+            assert (
+                context[node.output[0]].shape == folded_oshape
+            ), "cppsim did not produce expected folded output shape"
+            context[node.output[0]] = context[node.output[0]].reshape(*exp_oshape)
+        elif mode == "rtlsim":
+            sim = self.get_rtlsim()
+            nbits = self.get_instream_width()
+            rtlsim_inp = npy_to_rtlsim_input(
+                "{}/input_0.npy".format(code_gen_dir), export_idt, nbits
+            )
+            super().reset_rtlsim(sim)
+            super().toggle_clk(sim)
+            rtlsim_output = self.rtlsim(sim, rtlsim_inp)
+            odt = export_idt
+            target_bits = odt.bitwidth()
+            packed_bits = self.get_outstream_width()
+            out_npy_path = "{}/output.npy".format(code_gen_dir)
+            out_shape = self.get_folded_output_shape()
+            rtlsim_output_to_npy(
+                rtlsim_output, out_npy_path, odt, out_shape, packed_bits, target_bits
+            )
+            # load and reshape output
+            output = np.load(out_npy_path)
+            output = np.asarray([output], dtype=np.float32).reshape(*exp_oshape)
+            context[node.output[0]] = output
+        else:
+            raise Exception(
+                """Invalid value for attribute exec_mode! Is currently set to: {}
+                has to be set to one of the following values ("cppsim", "rtlsim")""".format(
+                    mode
+                )
+            )
+        # binary -> bipolar if needed
+        if self.get_output_datatype() == DataType.BIPOLAR:
+            out = context[node.output[0]]
+            out = 2 * out - 1
+            context[node.output[0]] = out
+        assert (
+            context[node.output[0]].shape == exp_oshape
+        ), """Output
+        shape doesn't match expected shape (1, ofm_dim_h, ofm_dim_w, k_h*k_w*ifm_ch)."""
+
+    def global_includes(self):
+        self.code_gen_dict["$GLOBALS$"] = ['#include "slidingwindow.h"']
+
+    def defines(self, var):
+        numReps = 1
+        ifm_dim = self.get_nodeattr("IFMDim")
+        ifm_ch = self.get_nodeattr("IFMChannels")
+        ofm_dim = self.get_nodeattr("OFMDim")
+        k = self.get_nodeattr("ConvKernelDim")
+        stride = self.get_nodeattr("Stride")
+        dilation = self.get_nodeattr("Dilation")
+        simd = self.get_nodeattr("SIMD")
+        ifm_precision = self.get_input_datatype().bitwidth()
+
+        # For the kernel, presenting the input data of size D as
+        # [H, W] = [Y, X] = [1, D] or [D, 1]
+        # effectively gives the same result. Because the
+        # ConvolutionInputGenerator_NonSquare_Dilated(_dws) kernel currently only
+        # supports dilation>1 along the X-axis and the
+        # ConvolutionInputGenerator_NonSquare only works for stride>1 along the
+        # X-axis, we are working with the following assumption:
+        # the dummy ('1') dimension is the Y-dimension, i.e.
+ # images and kernels (and their attributes) of dimension + # [H, W] = [Y, X] = [D, 1] or [1, D] are always mapped to [1, D] + if ifm_dim[1] == 1: + ifm_dim = ifm_dim[::-1] + ofm_dim = ofm_dim[::-1] + k = k[::-1] + stride = stride[::-1] + dilation = dilation[::-1] + + ifm_dim_y, ifm_dim_x = ifm_dim + ofm_dim_y, ofm_dim_x = ofm_dim + k_y, k_x = k + dilation_y, dilation_x = dilation + # For a 1d convolution with stride=[S,1] or [1,S], the finn-hlslib function + # of ConvInpGen must be created with [stride_y, stride_x] = [S, S]. + # TODO: changes in finn-hlslib (slidingwindow.h) + stride_y = np.prod(stride) + stride_x = np.prod(stride) + + if dilation_x > 1: + assert ( + dilation_y == 1 + ), "Dilation value greater than 1 along y-axis is not yet supported" + self.code_gen_dict["$DEFINES$"] = [ + """ + #define ConvKernelDim1_x {}\n + #define ConvKernelDim1_y {}\n + #define IFMChannels1 {}\n + #define Input_precision1 {}\n + #define IFMDim1_x {}\n + #define IFMDim1_y {}\n + #define OFMDim1_x {}\n + #define OFMDim1_y {}\n + #define SIMD1 {}\n + #define Stride1_x {}\n + #define Stride1_y {}\n + #define Dilation1_x {}\n + #define Dilation1_y {}\n + #define numReps {} + """.format( + k_x, + k_y, + ifm_ch, + ifm_precision, + ifm_dim_x, + ifm_dim_y, + ofm_dim_x, + ofm_dim_y, + simd, + stride_x, + stride_y, + dilation_x, + dilation_y, + numReps, + ) + ] + else: + ofm_dim = self.get_nodeattr("OFMDim") + self.code_gen_dict["$DEFINES$"] = [ + """ + #define ConvKernelDim1_x {}\n + #define ConvKernelDim1_y {}\n + #define IFMChannels1 {}\n + #define Input_precision1 {}\n + #define IFMDim1_x {}\n + #define IFMDim1_y {}\n + #define OFMDim1_x {}\n + #define OFMDim1_y {}\n + #define SIMD1 {}\n + #define Stride1_x {}\n + #define Stride1_y {}\n + #define numReps {} + """.format( + k_x, + k_y, + ifm_ch, + ifm_precision, + ifm_dim_x, + ifm_dim_y, + ofm_dim_x, + ofm_dim_y, + simd, + stride_x, + stride_y, + numReps, + ) + ] + + def read_npy_data(self): + code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim") + dtype = self.get_input_datatype() + if dtype == DataType.BIPOLAR: + # use binary for bipolar storage + dtype = DataType.BINARY + elem_bits = dtype.bitwidth() + packed_bits = self.get_instream_width() + packed_hls_type = "ap_uint<%d>" % packed_bits + elem_hls_type = dtype.get_hls_datatype_str() + npy_type = "float" + npy_in = "%s/input_0.npy" % code_gen_dir + self.code_gen_dict["$READNPYDATA$"] = [] + self.code_gen_dict["$READNPYDATA$"].append( + 'npy2apintstream<%s, %s, %d, %s>("%s", in0);' + % (packed_hls_type, elem_hls_type, elem_bits, npy_type, npy_in) + ) + + def strm_decl(self): + self.code_gen_dict["$STREAMDECLARATIONS$"] = [] + self.code_gen_dict["$STREAMDECLARATIONS$"].append( + 'hls::stream<ap_uint<{}>> in0 ("in0");'.format(self.get_instream_width()) + ) + self.code_gen_dict["$STREAMDECLARATIONS$"].append( + 'hls::stream<ap_uint<{}>> out ("out");'.format(self.get_outstream_width()) + ) + + def docompute(self): + ram_style = self.get_nodeattr("ram_style") + map_to_hls_ram_style = { + "auto": "ap_resource_dflt()", + "block": "ap_resource_bram()", + "distributed": "ap_resource_lutram()", + "ultra": "ap_resource_uram()", + } + hls_ram_style = map_to_hls_ram_style[ram_style] + hls_call = "ConvolutionInputGenerator" + # check which ConvolutionInputGenerator is needed + dilation_h, dilation_w = self.get_nodeattr("Dilation") + + hls_call += "_NonSquare" + if dilation_h > 1 or dilation_w > 1: + hls_call += "_Dilated" + if self.get_nodeattr("depthwise") == 1: + hls_call += "_dws" + 
self.code_gen_dict["$DOCOMPUTE$"] = [ + """{}<ConvKernelDim1_x, ConvKernelDim1_y, IFMChannels1, Input_precision1, + IFMDim1_x, IFMDim1_y, OFMDim1_x, OFMDim1_y, SIMD1, Stride1_x, Stride1_y, + Dilation1_x, Dilation1_y> (in0, out, numReps, {});""".format( + hls_call, hls_ram_style + ) + ] + elif self.get_nodeattr("depthwise") == 1: + hls_call += "_dws" + self.code_gen_dict["$DOCOMPUTE$"] = [ + """{}<ConvKernelDim1_x, ConvKernelDim1_y, IFMChannels1, Input_precision1, + IFMDim1_x, IFMDim1_y, OFMDim1_x, OFMDim1_y, SIMD1, Stride1_x, Stride1_y> + (in0, out, numReps, {});""".format( + hls_call, hls_ram_style + ) + ] + else: + self.code_gen_dict["$DOCOMPUTE$"] = [ + """{}<ConvKernelDim1_x, ConvKernelDim1_y, IFMChannels1, Input_precision1, + IFMDim1_x, IFMDim1_y, OFMDim1_x, OFMDim1_y, SIMD1, Stride1_x, Stride1_y> + (in0, out, numReps, {});""".format( + hls_call, hls_ram_style + ) + ] + + def dataoutstrm(self): + code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim") + dtype = self.get_output_datatype() + if dtype == DataType.BIPOLAR: + # use binary for bipolar storage + dtype = DataType.BINARY + elem_bits = dtype.bitwidth() + packed_bits = self.get_outstream_width() + packed_hls_type = "ap_uint<%d>" % packed_bits + elem_hls_type = dtype.get_hls_datatype_str() + npy_type = "float" + npy_out = "%s/output.npy" % code_gen_dir + oshape = self.get_folded_output_shape() + oshape_cpp_str = str(oshape).replace("(", "{").replace(")", "}") + + self.code_gen_dict["$DATAOUTSTREAM$"] = [ + 'apintstream2npy<%s, %s, %d, %s>(out, %s, "%s");' + % ( + packed_hls_type, + elem_hls_type, + elem_bits, + npy_type, + oshape_cpp_str, + npy_out, + ) + ] + + def save_as_npy(self): + self.code_gen_dict["$SAVEASCNPY$"] = [] + + def blackboxfunction(self): + self.code_gen_dict["$BLACKBOXFUNCTION$"] = [ + """void {}(hls::stream<ap_uint<SIMD1*Input_precision1>> &in0, + hls::stream<ap_uint<SIMD1*Input_precision1>> &out)""".format( + self.onnx_node.name + ) + ] + + def pragmas(self): + self.code_gen_dict["$PRAGMAS$"] = ["#pragma HLS INTERFACE axis port=in0"] + self.code_gen_dict["$PRAGMAS$"].append("#pragma HLS INTERFACE axis port=out") + self.code_gen_dict["$PRAGMAS$"].append( + "#pragma HLS INTERFACE ap_ctrl_none port=return" + ) diff --git a/tests/fpgadataflow/test_fpgadataflow_convinputgenerator1d.py b/tests/fpgadataflow/test_fpgadataflow_convinputgenerator1d.py new file mode 100644 index 0000000000000000000000000000000000000000..6c83aab0d683cdb3888aca3c46bb339bd6330917 --- /dev/null +++ b/tests/fpgadataflow/test_fpgadataflow_convinputgenerator1d.py @@ -0,0 +1,256 @@ +# Copyright (c) 2020, Xilinx +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of FINN nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import pytest +import numpy as np + +from onnx import TensorProto, helper + +import finn.core.onnx_exec as oxe +from finn.core.datatype import DataType +from finn.core.modelwrapper import ModelWrapper +from finn.transformation.fpgadataflow.prepare_ip import PrepareIP +from finn.transformation.fpgadataflow.prepare_cppsim import PrepareCppSim +from finn.transformation.fpgadataflow.compile_cppsim import CompileCppSim +from finn.transformation.fpgadataflow.hlssynth_ip import HLSSynthIP +from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode +from finn.transformation.fpgadataflow.prepare_rtlsim import PrepareRTLSim +from finn.transformation.general import GiveUniqueNodeNames +from finn.util.basic import gen_finn_dt_tensor + +from finn.custom_op.registry import getCustomOp +from finn.analysis.fpgadataflow.exp_cycles_per_layer import exp_cycles_per_layer +from finn.custom_op.general.im2col import compute_conv_output_dim + + +def make_single_im2col_modelwrapper( + k, ifm_ch, ifm_dim, ofm_dim, simd, stride, dilation, idt +): + k_h, k_w = k + ifm_dim_h, ifm_dim_w = ifm_dim + stride_h, stride_w = stride + dilation_h, dilation_w = dilation + ofm_dim_h, ofm_dim_w = ofm_dim + + odt = idt + inp = helper.make_tensor_value_info( + "inp", TensorProto.FLOAT, [1, ifm_dim_h, ifm_dim_w, ifm_ch] + ) + outp = helper.make_tensor_value_info( + "outp", TensorProto.FLOAT, [1, ofm_dim_h, ofm_dim_w, k_h * k_w * ifm_ch] + ) + + im2col_node = helper.make_node( + "Im2Col", + ["inp"], + ["outp"], + domain="finn.custom_op.general", + stride=[stride_h, stride_w], + kernel_size=[k_h, k_w], + input_shape=str((1, ifm_dim_h, ifm_dim_w, ifm_ch)), + dilations=[dilation_h, dilation_w], + pad_amount=[0, 0, 0, 0], + pad_value=0, + ) + graph = helper.make_graph( + nodes=[im2col_node], name="im2col_graph", inputs=[inp], outputs=[outp] + ) + + model = helper.make_model(graph, producer_name="im2col-model") + model = ModelWrapper(model) + + model.set_tensor_datatype("inp", idt) + model.set_tensor_datatype("outp", odt) + + return model + + +def make_single_slidingwindow_modelwrapper( + k, ifm_ch, ifm_dim, ofm_dim, simd, stride, dilation, idt, dw=0 +): + k_h, k_w = k + ifm_dim_h, ifm_dim_w = ifm_dim + stride_h, stride_w = stride + dilation_h, dilation_w = dilation + ofm_dim_h, ofm_dim_w = ofm_dim + + odt = idt + inp = helper.make_tensor_value_info( + "inp", TensorProto.FLOAT, [1, ifm_dim_h, ifm_dim_w, ifm_ch] + ) + outp = helper.make_tensor_value_info( + "outp", TensorProto.FLOAT, [1, ofm_dim_h, ofm_dim_w, k_h * k_w * ifm_ch] + ) + + SlidingWindow_node = helper.make_node( + "ConvolutionInputGenerator1D", + ["inp"], + ["outp"], + domain="finn.custom_op.fpgadataflow", + backend="fpgadataflow", + ConvKernelDim=[k_h, k_w], + IFMChannels=ifm_ch, + 
IFMDim=[ifm_dim_h, ifm_dim_w], + OFMDim=[ofm_dim_h, ofm_dim_w], + SIMD=simd, + Stride=[stride_h, stride_w], + Dilation=[dilation_h, dilation_w], + inputDataType=idt.name, + outputDataType=odt.name, + depthwise=dw, + ) + graph = helper.make_graph( + nodes=[SlidingWindow_node], + name="slidingwindow_graph", + inputs=[inp], + outputs=[outp], + ) + + model = helper.make_model(graph, producer_name="slidingwindow-model") + model = ModelWrapper(model) + + model.set_tensor_datatype("inp", idt) + model.set_tensor_datatype("outp", odt) + + return model + + +def prepare_inputs(input_tensor): + return {"inp": input_tensor} + + +# input datatype +# @pytest.mark.parametrize("idt", [DataType.BIPOLAR, DataType.INT8]) +@pytest.mark.parametrize("idt", [DataType.INT8]) +# kernel size +@pytest.mark.parametrize("k", [[4, 1]]) +# input dimension +@pytest.mark.parametrize("ifm_dim", [[10, 1]]) +# input channels +@pytest.mark.parametrize("ifm_ch", [1, 4]) +# Stride +@pytest.mark.parametrize("stride", [[1, 1], [2, 1]]) +# Dilation +# @pytest.mark.parametrize("dilation", [[1, 1], [2, 1]]) +@pytest.mark.parametrize("dilation", [[1, 1]]) +# execution mode +@pytest.mark.parametrize("exec_mode", ["cppsim", "rtlsim"]) +# input channel parallelism ("SIMD") +@pytest.mark.parametrize("simd", [1, 4]) +# depthwise +@pytest.mark.parametrize("dw", [0, 1]) +# Flip dimensions +@pytest.mark.parametrize("flip", [False, True]) +@pytest.mark.slow +@pytest.mark.vivado +def test_fpgadataflow_slidingwindow_1d( + idt, k, ifm_dim, ifm_ch, stride, dilation, exec_mode, simd, dw, flip +): + if flip: + k = k[::-1] + ifm_dim = ifm_dim[::-1] + stride = stride[::-1] + dilation = dilation[::-1] + + k_h, k_w = k + ifm_dim_h, ifm_dim_w = ifm_dim + stride_h, stride_w = stride + dilation_h, dilation_w = dilation + + if (dilation_h > 1 or dilation_w > 1) and (stride_h > 1 or stride_w > 1): + pytest.skip( + """Dilation value greater than 1 and stride greater than 1 + currently not supported for 1D convolutions""" + ) + if simd > ifm_ch: + pytest.skip("SIMD cannot be larger than number of input channels") + + ofm_dim_h = compute_conv_output_dim(ifm_dim_h, k_h, stride_h, 0, dilation_h) + ofm_dim_w = compute_conv_output_dim(ifm_dim_w, k_w, stride_w, 0, dilation_w) + ofm_dim = [ofm_dim_h, ofm_dim_w] + + x = gen_finn_dt_tensor(idt, (1, ifm_dim_h, ifm_dim_w, ifm_ch)) + model = make_single_slidingwindow_modelwrapper( + k=k, + ifm_ch=ifm_ch, + ifm_dim=ifm_dim, + ofm_dim=ofm_dim, + simd=simd, + stride=stride, + dilation=dilation, + idt=idt, + dw=dw, + ) + + if exec_mode == "cppsim": + model = model.transform(SetExecMode("cppsim")) + model = model.transform(PrepareCppSim()) + model = model.transform(CompileCppSim()) + elif exec_mode == "rtlsim": + model = model.transform(SetExecMode("rtlsim")) + model = model.transform(GiveUniqueNodeNames()) + model = model.transform(PrepareIP("xc7z020clg400-1", 5)) + model = model.transform(HLSSynthIP()) + model = model.transform(PrepareRTLSim()) + else: + raise Exception("Unknown exec_mode in test_fpgadataflow_slidingwindow") + + # prepare input data + input_dict = prepare_inputs(x) + # execute model + y_produced = oxe.execute_onnx(model, input_dict)["outp"] + golden = make_single_im2col_modelwrapper( + k=k, + ifm_ch=ifm_ch, + ifm_dim=ifm_dim, + ofm_dim=ofm_dim, + simd=simd, + stride=stride, + dilation=dilation, + idt=idt, + ) + y_expected = oxe.execute_onnx(golden, input_dict)["outp"] + + if dw == 0: + assert (y_produced == y_expected).all() + else: + y_expected = y_expected.reshape( + 1, ofm_dim_h, ofm_dim_w, k_h * k_w, 
ifm_ch // simd, simd
+        )
+        y_expected = y_expected.transpose(0, 1, 2, 4, 3, 5)
+        y_expected = y_expected.reshape(1, ofm_dim_h, ofm_dim_w, ifm_ch * k_h * k_w)
+        assert (y_produced == y_expected).all()
+
+    if exec_mode == "rtlsim":
+        node = model.get_nodes_by_op_type("ConvolutionInputGenerator1D")[0]
+        inst = getCustomOp(node)
+        cycles_rtlsim = inst.get_nodeattr("cycles_rtlsim")
+        exp_cycles_dict = model.analysis(exp_cycles_per_layer)
+        exp_cycles = exp_cycles_dict[node.name]
+        assert np.isclose(exp_cycles, cycles_rtlsim, atol=10)
+        assert exp_cycles != 0
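
Note on the depthwise comparison above: the depthwise SWG variant emits its window pixels with the channel blocks interleaved differently from the plain im2col layout (see the layout comment at the top of convolutioninputgenerator1d.py), which is why the test reshapes and transposes `y_expected` before comparing. Below is a minimal standalone numpy sketch of that reordering; `im2col_to_depthwise_layout` and the concrete sizes are illustrative only, not part of this patch.

```python
import numpy as np


def im2col_to_depthwise_layout(y, k_h, k_w, ifm_ch, simd):
    """Reorder im2col output (1, ofm_h, ofm_w, k_h*k_w*ifm_ch) into the
    channel-interleaved layout produced by the depthwise SWG variant."""
    n, ofm_h, ofm_w, _ = y.shape
    y = y.reshape(n, ofm_h, ofm_w, k_h * k_w, ifm_ch // simd, simd)
    # swap the (k_h*k_w) and (ifm_ch/simd) axes, exactly as the test does
    y = y.transpose(0, 1, 2, 4, 3, 5)
    return y.reshape(n, ofm_h, ofm_w, ifm_ch * k_h * k_w)


# illustrative sizes: k=[4, 1], ifm_dim=[10, 1], stride=dilation=1 -> ofm_dim=[7, 1]
k_h, k_w, ifm_ch, simd, ofm_h, ofm_w = 4, 1, 4, 2, 7, 1
y_im2col = np.random.randn(1, ofm_h, ofm_w, k_h * k_w * ifm_ch).astype(np.float32)
y_dws = im2col_to_depthwise_layout(y_im2col, k_h, k_w, ifm_ch, simd)
assert y_dws.shape == y_im2col.shape
```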