diff --git a/docker/Dockerfile.finn b/docker/Dockerfile.finn
index 97f4bf69da28844ec16bbf920da3e18e62e4ebed..2404faafbabe437fc9f36bf77722e8b3c641553f 100644
--- a/docker/Dockerfile.finn
+++ b/docker/Dockerfile.finn
@@ -92,7 +92,7 @@ ARG FINN_EXP_COMMIT="af6102769226b82b639f243dc36f065340991513"
 ARG BREVITAS_COMMIT="a5b71d6de1389d3e7db898fef72e014842670f03"
 ARG PYVERILATOR_COMMIT="0c3eb9343500fc1352a02c020a736c8c2db47e8e"
 ARG CNPY_COMMIT="4e8810b1a8637695171ed346ce68f6984e585ef4"
-ARG HLSLIB_COMMIT="966d17d3fddd801927b2167627d23a9a15ed1461"
+ARG HLSLIB_COMMIT="bcca5d2b69c88e9ad7a86581ec062a9756966367"
 ARG OMX_COMMIT="1dfc4aa2f2895632742cd5751520c6b472feb74e"
 ARG AVNET_BDF_COMMIT="2d49cfc25766f07792c0b314489f21fe916b639b"
 
diff --git a/src/finn/custom_op/fpgadataflow/pool_batch.py b/src/finn/custom_op/fpgadataflow/pool_batch.py
index ba8a446f2cf7541c0bd2e1dff731afe2397942ef..708a3a149abe268d122d339a5c25648630a01ff6 100644
--- a/src/finn/custom_op/fpgadataflow/pool_batch.py
+++ b/src/finn/custom_op/fpgadataflow/pool_batch.py
@@ -38,7 +38,7 @@ class Pool_Batch(HLSCustomOp):
     """Class that corresponds to finn-hlslib Pool_batch function.
     Requires ConvolutionInputGenerator(depthwise == 1) to format its input
 
-    Input shape (BatchSize,OutImgDim,OutImgDim,KernelSize^2*Channels)
+    Input shape (BatchSize,OutImgDim,OutImgDim,TotalKernelSize*Channels)
     Output shape (BatchSize,OutImgDim,OutImgDim,Channels)
 
     Notes:
@@ -56,13 +56,13 @@ class Pool_Batch(HLSCustomOp):
         my_attrs = {
             "Channels": ("i", True, 0),
             "PE": ("i", True, 1),
-            "KernelSize": ("i", True, 0),
+            "KernelSize": ("ints", True, []),
             # Function:
             #  - MaxPool
             #  - QuantAvgPool
             # TODO add support for AvgPool and AccPool
             "Function": ("s", True, "", {"MaxPool", "QuantAvgPool"}),
-            "OutImgDim": ("i", True, 0),
+            "OutImgDims": ("ints", True, []),
             # FINN DataTypes for inputs/outputs
             "InputDataType": ("s", True, ""),
             "OutputDataType": ("s", True, ""),
@@ -100,10 +100,11 @@ class Pool_Batch(HLSCustomOp):
 
     def get_normal_input_shape(self):
         ifm_ch = self.get_nodeattr("Channels")
-        odim = self.get_nodeattr("OutImgDim")
+        odims = self.get_nodeattr("OutImgDims")
         batch_size = self.get_nodeattr("BatchSize")
         k = self.get_nodeattr("KernelSize")
-        ishape = (batch_size, odim, odim, k * k * ifm_ch)
+        k_prod = int(np.prod(k))
+        ishape = (batch_size, *odims, k_prod * ifm_ch)
         return ishape
 
     def get_folded_input_shape(self):
@@ -117,9 +118,9 @@ class Pool_Batch(HLSCustomOp):
 
     def get_normal_output_shape(self):
         ofm_ch = self.get_nodeattr("Channels")
-        odim = self.get_nodeattr("OutImgDim")
+        odims = self.get_nodeattr("OutImgDims")
         batch_size = self.get_nodeattr("BatchSize")
-        oshape = (batch_size, odim, odim, ofm_ch)
+        oshape = (batch_size, *odims, ofm_ch)
         return oshape
 
     def get_folded_output_shape(self):
@@ -140,9 +141,10 @@ class Pool_Batch(HLSCustomOp):
         ifm_ch = self.get_nodeattr("Channels")
         pe = self.get_nodeattr("PE")
         k = self.get_nodeattr("KernelSize")
-        odim = self.get_nodeattr("OutImgDim")
+        k_prod = int(np.prod(k))
+        odims = self.get_nodeattr("OutImgDims")
         batch_size = self.get_nodeattr("BatchSize")
-        exp_cycles = ((ifm_ch * k * k) / pe) * odim * odim * batch_size
+        exp_cycles = ((ifm_ch * k_prod) / pe) * np.prod(odims) * batch_size
         return int(exp_cycles)
 
     def get_instream_width(self):
@@ -211,10 +213,12 @@ class Pool_Batch(HLSCustomOp):
         self.code_gen_dict["$DEFINES$"] += ["#define PE {}".format(pe)]
 
         k = self.get_nodeattr("KernelSize")
-        self.code_gen_dict["$DEFINES$"] += ["#define KernelSize {}".format(k)]
+        k_prod = int(np.prod(k))
+        self.code_gen_dict["$DEFINES$"] += ["#define KernelSize {}".format(k_prod)]
 
-        odim = self.get_nodeattr("OutImgDim")
-        self.code_gen_dict["$DEFINES$"] += ["#define OFMDim {}".format(odim)]
+        odims = self.get_nodeattr("OutImgDims")
+        total_odim = np.prod(odims)
+        self.code_gen_dict["$DEFINES$"] += ["#define OFMDimTotal {}".format(total_odim)]
 
         numReps = self.get_nodeattr("BatchSize")
         self.code_gen_dict["$DEFINES$"] += ["#define numReps {}".format(numReps)]
@@ -275,7 +279,7 @@ class Pool_Batch(HLSCustomOp):
 
         self.code_gen_dict["$DOCOMPUTE$"] += [
             """Pool_batch<Channels, PE, KernelSize,Slice<{} >, Slice< {} > >
-        (in0,out, pool_fxn, OFMDim*OFMDim*numReps);""".format(
+        (in0,out, pool_fxn, OFMDimTotal*numReps);""".format(
                 i_hls_dt, o_hls_dt
             )
         ]
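Since KernelSize and OutImgDims are now list-valued, the node's shapes are built by splatting OutImgDims and multiplying out the kernel elements. For reference, a standalone sketch of the same arithmetic (helper name and numbers are illustrative, not code from the patch):

```python
import numpy as np


def pool_batch_shapes(channels, kernel_size, out_img_dims, batch_size=1):
    # mirrors get_normal_input_shape / get_normal_output_shape above
    k_prod = int(np.prod(kernel_size))
    ishape = (batch_size, *out_img_dims, k_prod * channels)
    oshape = (batch_size, *out_img_dims, channels)
    return ishape, oshape


# square 2x2 pool on an 8x8 output map, 16 channels
print(pool_batch_shapes(16, [2, 2], [8, 8]))  # ((1, 8, 8, 64), (1, 8, 8, 16))
# 1D 4-wide pool expressed with a dummy first output dim
print(pool_batch_shapes(16, [1, 4], [1, 8]))  # ((1, 1, 8, 64), (1, 1, 8, 16))
```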
["#define KernelSize {}".format(k_prod)] - odim = self.get_nodeattr("OutImgDim") - self.code_gen_dict["$DEFINES$"] += ["#define OFMDim {}".format(odim)] + odims = self.get_nodeattr("OutImgDims") + total_odim = np.prod(odims) + self.code_gen_dict["$DEFINES$"] += ["#define OFMDimTotal {}".format(total_odim)] numReps = self.get_nodeattr("BatchSize") self.code_gen_dict["$DEFINES$"] += ["#define numReps {}".format(numReps)] @@ -275,7 +279,7 @@ class Pool_Batch(HLSCustomOp): self.code_gen_dict["$DOCOMPUTE$"] += [ """Pool_batch<Channels, PE, KernelSize,Slice<{} >, Slice< {} > > - (in0,out, pool_fxn, OFMDim*OFMDim*numReps);""".format( + (in0,out, pool_fxn, OFMDimTotal*numReps);""".format( i_hls_dt, o_hls_dt ) ] diff --git a/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py b/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py index 60ae3fc75effde53c3271bfcf85fe9e1c0013fdf..b2f50b1a23f85bf782c553057148173b6f94dde4 100644 --- a/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py +++ b/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py @@ -385,62 +385,57 @@ class InferPool_Batch(Transformation): graph = model.graph node_ind = 0 graph_modified = False - for n in graph.node: + for node in graph.node: node_ind += 1 - if n.op_type in ["MaxPool", "QuantAvgPool2d", "MaxPoolNHWC"]: - # extract pool parameters + if node.op_type in ["MaxPool", "QuantAvgPool2d", "MaxPoolNHWC"]: + node_input = node.input[0] + ishape = model.get_tensor_shape(node_input) + node_output = node.output[0] + idt = model.get_tensor_datatype(node_input) + oshape = model.get_tensor_shape(node_output) + # only support 4D input tensors (1D convs need extra dummy dim) + if len(ishape) != 4: + continue - if n.op_type == "MaxPool": - k = get_by_name(n.attribute, "kernel_shape").ints[-1] - stride = get_by_name(n.attribute, "strides").ints[-1] - # assumed datalayout + # extract pool parameters + if node.op_type == "MaxPool": + kh, kw = list(get_by_name(node.attribute, "kernel_shape").ints) + sh, sw = list(get_by_name(node.attribute, "strides").ints) dlayout = "NCHW" - elif n.op_type == "QuantAvgPool2d": - inst = getCustomOp(n) - k = inst.get_nodeattr("kernel") - stride = inst.get_nodeattr("stride") + elif node.op_type == "QuantAvgPool2d": + inst = getCustomOp(node) + # QuantAvgPool2d has a single scalar attribute + # for kernel size and stride (implicit square) + kh = kw = inst.get_nodeattr("kernel") + sh = sw = inst.get_nodeattr("stride") dlayout = inst.get_nodeattr("data_layout") - elif n.op_type == "MaxPoolNHWC": - inst = getCustomOp(n) - k_shape = inst.get_nodeattr("kernel_shape") - strides = inst.get_nodeattr("strides") - assert k_shape[0] == k_shape[1] - assert strides[0] == strides[1] - k = k_shape[0] - stride = strides[0] + elif node.op_type == "MaxPoolNHWC": + inst = getCustomOp(node) + kh, kw = inst.get_nodeattr("kernel_shape") + sh, sw = inst.get_nodeattr("strides") dlayout = "NHWC" try: - pad = get_by_name(n.attribute, "pads").ints[-1] + pad = list(get_by_name(node.attribute, "pads").ints) except AttributeError: - pad = 0 - - node_input = n.input[0] - node_output = n.output[0] - idt = model.get_tensor_datatype(node_input) + pad = [0, 0, 0, 0] if not idt.is_integer(): continue - if k < stride: + if (kh < sh) or (kw < sw): + # TODO check/implement swg support continue - elif k == stride: - warnings.warn( - n.name - + """: Inferring Pool_Batch node for k == stride. - This case can be optimized. 
@@ -449,7 +444,7 @@ class InferPool_Batch(Transformation):
                     inp_trans_out = helper.make_tensor_value_info(
                         model.make_new_valueinfo_name(),
                         TensorProto.FLOAT,
-                        (1, ifm_dim, ifm_dim, ifm_ch),  # NHWC
+                        (1, ifm_h, ifm_w, ifm_ch),  # NHWC
                     )
                     graph.value_info.append(inp_trans_out)
                     inp_trans_out = inp_trans_out.name
@@ -458,7 +453,7 @@ class InferPool_Batch(Transformation):
                     pool_output = helper.make_tensor_value_info(
                         model.make_new_valueinfo_name(),
                         TensorProto.FLOAT,
-                        (1, ofm_dim, ofm_dim, ofm_ch),
+                        (1, ofm_h, ofm_w, ofm_ch),
                     )
                     graph.value_info.append(pool_output)
                     pool_output = pool_output.name
@@ -467,7 +462,7 @@ class InferPool_Batch(Transformation):
                 im2col_out = helper.make_tensor_value_info(
                     model.make_new_valueinfo_name(),
                     TensorProto.FLOAT,
-                    (1, ofm_dim, ofm_dim, ifm_ch * k * k),
+                    (1, ofm_h, ofm_w, ifm_ch * kh * kw),
                 )
                 graph.value_info.append(im2col_out)
                 im2col_out = im2col_out.name
@@ -485,24 +480,28 @@ class InferPool_Batch(Transformation):
                     pool_output = node_output
 
                 accum_bits = 0
-                pool_size_param = k
+                pool_size_param = 0  # will be overridden if needed
                 pad_value = 0
-                if n.op_type in ["MaxPool", "MaxPoolNHWC"]:
+                if node.op_type in ["MaxPool", "MaxPoolNHWC"]:
                     pool_fxn = "MaxPool"
                     odt = idt
                     pad_value = idt.min()
-                elif n.op_type == "QuantAvgPool2d":
+                elif node.op_type == "QuantAvgPool2d":
                     assert odt.is_integer(), """Output data type for QuantAvgPool2d
                     needs to be integer"""
-                    assert pad == 0, "Padding is not supported for QuantAvgPool2d"
-                    inst = getCustomOp(n)
+                    assert all(
+                        x == 0 for x in pad
+                    ), "Padding is not supported for QuantAvgPool2d"
+                    inst = getCustomOp(node)
                     pool_fxn = "QuantAvgPool"
                     pool_size_param = inst.get_shifts()
                     accum_bits = inst.get_accum_size()
                 else:
                     raise Exception(
-                        "pad_value and pool_fxn not configured for {}".format(n.op_type)
+                        "pad_value and pool_fxn not configured for {}".format(
+                            node.op_type
+                        )
                     )
 
                 # format input tensor
@@ -511,13 +510,13 @@ class InferPool_Batch(Transformation):
                     [im2col_in],
                     [im2col_out],
                     domain="finn.custom_op.general",
-                    stride=[stride, stride],
-                    kernel_size=[k, k],
-                    pad_amount=[pad, pad, pad, pad],
+                    stride=[sh, sw],
+                    kernel_size=[kh, kw],
+                    pad_amount=pad,
                     pad_value=pad_value,
                     depthwise=1,
-                    input_shape="(1,{},{},{})".format(ifm_dim, ifm_dim, ifm_ch),
-                    name="Im2Col_" + n.name,
+                    input_shape="(1,{},{},{})".format(ifm_h, ifm_w, ifm_ch),
+                    name="Im2Col_" + node.name,
                 )
 
                 # Warning PE has to be equal to ifm_ch until Im2Col is replaced by
@@ -534,13 +533,13 @@ class InferPool_Batch(Transformation):
                     OutputDataType=odt.name,
                     Channels=ifm_ch,
                     PE=ifm_ch,
-                    KernelSize=k,
+                    KernelSize=[kh, kw],
                     Function=pool_fxn,
-                    OutImgDim=ofm_dim,
+                    OutImgDims=[ofm_h, ofm_w],
                     AccumBits=accum_bits,
                     Size=pool_size_param,
                     BatchSize=1,
-                    name="Pool_Batch_" + n.name,
+                    name="Pool_Batch_" + node.name,
                 )
 
                 if dlayout == "NCHW":
@@ -559,7 +558,7 @@ class InferPool_Batch(Transformation):
                 graph.node.insert(node_ind, im2col_node)
                 graph.node.insert(node_ind + 1, pool_node)
                 # remove old node
-                graph.node.remove(n)
+                graph.node.remove(node)
                 graph_modified = True
 
         if graph_modified:
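For orientation, the node chain that InferPool_Batch followed by InferConvInpGen is expected to leave behind is spelled out in the updated test below. A tiny illustrative helper that enumerates the same expectation (not code from the transformation itself; the real window generator may be a 1D variant, which is why the test uses startswith):

```python
def expected_op_chain(dlayout, pad):
    # Im2Col becomes a ConvolutionInputGenerator(1D) after InferConvInpGen
    chain = ["ConvolutionInputGenerator", "Pool_Batch"]
    if any(p != 0 for p in pad):
        # nonzero padding ends up in a separate FMPadding_Batch in front of the SWG
        chain = ["FMPadding_Batch"] + chain
    if dlayout == "NCHW":
        # NCHW models are wrapped in Transposes to reach the NHWC hls layout
        chain = ["Transpose"] + chain + ["Transpose"]
    return chain


print(expected_op_chain("NCHW", [0, 0, 0, 0]))
# ['Transpose', 'ConvolutionInputGenerator', 'Pool_Batch', 'Transpose']
print(expected_op_chain("NCHW", [1, 1, 1, 1]))
# ['Transpose', 'FMPadding_Batch', 'ConvolutionInputGenerator', 'Pool_Batch', 'Transpose']
```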
diff --git a/tests/fpgadataflow/test_convert_to_hls_pool_batch.py b/tests/fpgadataflow/test_convert_to_hls_pool_batch.py
index 3efafc040df07a7d56638bf5ce0b1ce01887343c..0dd9991b2ff07a35c923afeda854352213f8ca09 100644
--- a/tests/fpgadataflow/test_convert_to_hls_pool_batch.py
+++ b/tests/fpgadataflow/test_convert_to_hls_pool_batch.py
@@ -48,22 +48,31 @@ from finn.transformation.infer_shapes import InferShapes
 from finn.util.basic import gen_finn_dt_tensor
 
 
-def make_single_maxpool_modelwrapper(k, stride, pad, ifm_ch, ifm_dim, ofm_dim, idt):
+def make_single_maxpool_modelwrapper(
+    k, stride, pad, ifm_ch, ifm_dim, ofm_dim, idt, use_1d=False
+):
     odt = idt
-    inp = helper.make_tensor_value_info(
-        "inp", TensorProto.FLOAT, [1, ifm_ch, ifm_dim, ifm_dim]
-    )
-    outp = helper.make_tensor_value_info(
-        "outp", TensorProto.FLOAT, [1, ifm_ch, ofm_dim, ofm_dim]
-    )
-
+    if use_1d:
+        ishape = [1, ifm_ch, 1, ifm_dim]
+        oshape = [1, ifm_ch, 1, ofm_dim]
+        kshape = [1, k]
+        pads = [0, pad, 0, pad]
+        strides = [1, stride]
+    else:
+        ishape = [1, ifm_ch, ifm_dim, ifm_dim]
+        oshape = [1, ifm_ch, ofm_dim, ofm_dim]
+        kshape = [k, k]
+        pads = [pad, pad, pad, pad]
+        strides = [stride, stride]
+    inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, ishape)
+    outp = helper.make_tensor_value_info("outp", TensorProto.FLOAT, oshape)
     mp_node = helper.make_node(
         "MaxPool",
         ["inp"],
         ["outp"],
-        kernel_shape=[k, k],
-        pads=[pad, pad, pad, pad],
-        strides=[stride, stride],
+        kernel_shape=kshape,
+        pads=pads,
+        strides=strides,
     )
     graph = helper.make_graph(
         nodes=[mp_node], name="mp_graph", inputs=[inp], outputs=[outp]
     )
@@ -128,7 +137,7 @@ def prepare_inputs(input_tensor):
 # number of out channel computed in parallel
 @pytest.mark.parametrize("pe", [1, 2, 4])
 # pool type
-@pytest.mark.parametrize("op_type", ["QuantAvgPool2d", "MaxPool"])
+@pytest.mark.parametrize("op_type", ["QuantAvgPool2d", "MaxPool", "MaxPool1D"])
 # execution mode
 @pytest.mark.parametrize("exec_mode", ["cppsim", "rtlsim"])
 @pytest.mark.slow
@@ -147,7 +156,14 @@ def test_convert_to_hls_pool_batch(
     np.random.seed(0)
     ofm_dim = int(((ifm_dim + 2 * pad - k) / stride) + 1)
 
-    x = gen_finn_dt_tensor(idt, (1, ifm_ch, ifm_dim, ifm_dim))
+    ishape = (1, ifm_ch, ifm_dim, ifm_dim)
+    use_1d = False
+    if op_type == "MaxPool1D":
+        use_1d = True
+        ishape = (1, ifm_ch, 1, ifm_dim)
+        op_type = "MaxPool"
+
+    x = gen_finn_dt_tensor(idt, ishape)
     # prepare input data
     input_dict = prepare_inputs(x)
     if op_type == "MaxPool":
@@ -159,7 +175,7 @@ def test_convert_to_hls_pool_batch(
             pytest.skip("Skipping Maxpool with idt != odt")
 
         model = make_single_maxpool_modelwrapper(
-            k, stride, pad, ifm_ch, ifm_dim, ofm_dim, idt
+            k, stride, pad, ifm_ch, ifm_dim, ofm_dim, idt, use_1d
         )
     elif op_type == "QuantAvgPool2d":
         if pad != 0:
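The use_1d branch above models a 1D max pool by inserting a dummy H axis of size 1, so the standard 4D MaxPool/Im2Col machinery still applies. A self-contained onnx-only sketch of that trick (shapes and names are illustrative, not taken from the test):

```python
from onnx import TensorProto, helper

ifm_ch, ifm_dim, k, stride = 4, 32, 4, 4
ofm_dim = (ifm_dim - k) // stride + 1  # 8

# length-ifm_dim sequence modelled as a 1 x ifm_dim image;
# kernel, strides and pads only act on the W axis
inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, [1, ifm_ch, 1, ifm_dim])
outp = helper.make_tensor_value_info("outp", TensorProto.FLOAT, [1, ifm_ch, 1, ofm_dim])
mp_node = helper.make_node(
    "MaxPool",
    ["inp"],
    ["outp"],
    kernel_shape=[1, k],
    pads=[0, 0, 0, 0],
    strides=[1, stride],
)
graph = helper.make_graph([mp_node], "mp_1d_graph", [inp], [outp])
print(helper.printable_graph(graph))
```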
@@ -178,16 +194,40 @@ def test_convert_to_hls_pool_batch(
     new_model = model.transform(to_hls.InferPool_Batch())
     new_model = new_model.transform(GiveUniqueNodeNames())
 
-    if ifm_ch != pe:
-        new_model = new_model.transform(to_hls.InferConvInpGen())
-        # Folding
-        for n in new_model.graph.node:
-            if n.op_type == "ConvolutionInputGenerator":
-                inst = getCustomOp(n)
-                inst.set_nodeattr("SIMD", pe)
-            elif n.op_type == "Pool_Batch":
-                inst = getCustomOp(n)
-                inst.set_nodeattr("PE", pe)
+    new_model = new_model.transform(to_hls.InferConvInpGen())
+    # Folding
+    for n in new_model.graph.node:
+        if n.op_type.startswith("ConvolutionInputGenerator"):
+            inst = getCustomOp(n)
+            inst.set_nodeattr("SIMD", pe)
+        elif n.op_type == "Pool_Batch":
+            inst = getCustomOp(n)
+            inst.set_nodeattr("PE", pe)
+
+    if stride <= k:
+        if pad == 0:
+            assert len(new_model.graph.node) == 4
+            assert new_model.graph.node[0].op_type == "Transpose"
+            assert new_model.graph.node[1].op_type.startswith(
+                "ConvolutionInputGenerator"
+            )
+            assert new_model.graph.node[2].op_type == "Pool_Batch"
+            assert new_model.graph.node[3].op_type == "Transpose"
+        else:
+            assert len(new_model.graph.node) == 5
+            assert new_model.graph.node[0].op_type == "Transpose"
+            assert new_model.graph.node[1].op_type == "FMPadding_Batch"
+            assert new_model.graph.node[2].op_type.startswith(
+                "ConvolutionInputGenerator"
+            )
+            assert new_model.graph.node[3].op_type == "Pool_Batch"
+            assert new_model.graph.node[4].op_type == "Transpose"
+    else:
+        # not currently converted to HLS, node stays as-is
+        assert len(new_model.graph.node) == 1
+        assert new_model.graph.node[0].op_type in ["MaxPool", "QuantAvgPool2d"]
+        # no need to exec
+        return
 
     if exec_mode == "cppsim":
         new_model = new_model.transform(SetExecMode("cppsim"))
@@ -205,13 +245,6 @@ def test_convert_to_hls_pool_batch(
     # execute new_model
     y_produced = oxe.execute_onnx(new_model, input_dict)["outp"]
     assert (y_produced == y_expected).all()
-    if stride <= k:
-        if pad == 0 or ifm_ch == pe:
-            assert len(new_model.graph.node) == 4
-        else:
-            assert len(new_model.graph.node) == 5
-    else:
-        assert len(new_model.graph.node) == 1
 
     if exec_mode == "rtlsim":
         node = new_model.get_nodes_by_op_type("Pool_Batch")[0]
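Tying the scalar test parameters back to the list-valued Pool_Batch attributes: the expected cycle count follows the get_exp_cycles formula above, with the kernel expressed as [1, k] in the 1D case. A rough, illustrative calculation (helper and parameter values are made up for the example, not taken from the test parametrization):

```python
import numpy as np


def estimated_pool_batch_cycles(ifm_ch, ifm_dim, k, stride, pad, pe, use_1d=False):
    # same arithmetic as Pool_Batch.get_exp_cycles, fed from scalar test-style params
    ofm_dim = (ifm_dim + 2 * pad - k) // stride + 1
    kernel = [1, k] if use_1d else [k, k]
    odims = [1, ofm_dim] if use_1d else [ofm_dim, ofm_dim]
    return int((ifm_ch * np.prod(kernel) / pe) * np.prod(odims))


print(estimated_pool_batch_cycles(4, 32, 4, 4, 0, pe=4))               # 2D: 1024
print(estimated_pool_batch_cycles(4, 32, 4, 4, 0, pe=4, use_1d=True))  # 1D: 32
```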