diff --git a/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py b/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py
index 54bd01e2baf479789baa06853b6a442ca5ee6ef5..0364b272e773a0f498f559cb768a4148224c3a3d 100644
--- a/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py
+++ b/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py
@@ -385,62 +385,57 @@ class InferPool_Batch(Transformation):
         graph = model.graph
         node_ind = 0
         graph_modified = False
-        for n in graph.node:
+        for node in graph.node:
             node_ind += 1
-            if n.op_type in ["MaxPool", "QuantAvgPool2d", "MaxPoolNHWC"]:
-                # extract pool parameters
+            if node.op_type in ["MaxPool", "QuantAvgPool2d", "MaxPoolNHWC"]:
+                node_input = node.input[0]
+                ishape = model.get_tensor_shape(node_input)
+                node_output = node.output[0]
+                idt = model.get_tensor_datatype(node_input)
+                oshape = model.get_tensor_shape(node_output)
+                # only support 4D input tensors (1D convs need extra dummy dim)
+                if len(ishape) != 4:
+                    continue
 
-                if n.op_type == "MaxPool":
-                    k = get_by_name(n.attribute, "kernel_shape").ints[-1]
-                    stride = get_by_name(n.attribute, "strides").ints[-1]
-                    # assumed datalayout
+                # extract pool parameters
+                if node.op_type == "MaxPool":
+                    kh, kw = list(get_by_name(node.attribute, "kernel_shape").ints)
+                    sh, sw = list(get_by_name(node.attribute, "strides").ints)
                     dlayout = "NCHW"
-                elif n.op_type == "QuantAvgPool2d":
-                    inst = getCustomOp(n)
-                    k = inst.get_nodeattr("kernel")
-                    stride = inst.get_nodeattr("stride")
+                elif node.op_type == "QuantAvgPool2d":
+                    inst = getCustomOp(node)
+                    # QuantAvgPool2d has a single scalar attribute
+                    # for kernel size and stride (implicit square)
+                    kh = kw = inst.get_nodeattr("kernel")
+                    sh = sw = inst.get_nodeattr("stride")
                     dlayout = inst.get_nodeattr("data_layout")
-                elif n.op_type == "MaxPoolNHWC":
-                    inst = getCustomOp(n)
-                    k_shape = inst.get_nodeattr("kernel_shape")
-                    strides = inst.get_nodeattr("strides")
-                    assert k_shape[0] == k_shape[1]
-                    assert strides[0] == strides[1]
-                    k = k_shape[0]
-                    stride = strides[0]
+                elif node.op_type == "MaxPoolNHWC":
+                    inst = getCustomOp(node)
+                    kh, kw = inst.get_nodeattr("kernel_shape")
+                    sh, sw = inst.get_nodeattr("strides")
                     dlayout = "NHWC"
                 try:
-                    pad = get_by_name(n.attribute, "pads").ints[-1]
+                    pad = list(get_by_name(node.attribute, "pads").ints)
                 except AttributeError:
-                    pad = 0
-
-                node_input = n.input[0]
-                node_output = n.output[0]
-                idt = model.get_tensor_datatype(node_input)
+                    pad = [0, 0, 0, 0]
 
                 if not idt.is_integer():
                     continue
 
-                if k < stride:
+                if (kh < sh) or (kw < sw):
+                    # TODO check/implement swg support
                     continue
-                elif k == stride:
-                    warnings.warn(
-                        n.name
-                        + """: Inferring Pool_Batch node for k == stride.
-                        This case can be optimized.
-                        For example, for MaxPool run InferStreamingMaxPool before
-                        InferPool_Batch """
-                    )
 
                 odt = model.get_tensor_datatype(node_output)
 
                 if dlayout == "NCHW":
-                    ifm_ch = model.get_tensor_shape(n.input[0])[1]
+                    _, ifm_ch, ifm_h, ifm_w = ishape
+                    _, ofm_ch, ofm_h, ofm_w = oshape
+                elif dlayout == "NHWC":
+                    _, ifm_h, ifm_w, ifm_ch = ishape
+                    _, ofm_h, ofm_w, ofm_ch = oshape
                 else:
-                    ifm_ch = model.get_tensor_shape(n.input[0])[-1]
-                ofm_ch = ifm_ch
-                ifm_dim = model.get_tensor_shape(n.input[0])[-2]
-                ofm_dim = model.get_tensor_shape(n.output[0])[-2]
+                    raise Exception("Unknown dlayout: " + str(dlayout))
 
                 # if data layout NCHW, we need transpose nodes surrounding
                 # the hls layer
@@ -449,7 +444,7 @@ class InferPool_Batch(Transformation):
                     inp_trans_out = helper.make_tensor_value_info(
                         model.make_new_valueinfo_name(),
                         TensorProto.FLOAT,
-                        (1, ifm_dim, ifm_dim, ifm_ch),  # NHWC
+                        (1, ifm_h, ifm_w, ifm_ch),  # NHWC
                     )
                     graph.value_info.append(inp_trans_out)
                     inp_trans_out = inp_trans_out.name
@@ -458,7 +453,7 @@ class InferPool_Batch(Transformation):
                     pool_output = helper.make_tensor_value_info(
                         model.make_new_valueinfo_name(),
                         TensorProto.FLOAT,
-                        (1, ofm_dim, ofm_dim, ofm_ch),
+                        (1, ofm_h, ofm_w, ofm_ch),
                     )
                     graph.value_info.append(pool_output)
                     pool_output = pool_output.name
@@ -467,7 +462,7 @@ class InferPool_Batch(Transformation):
                 im2col_out = helper.make_tensor_value_info(
                     model.make_new_valueinfo_name(),
                     TensorProto.FLOAT,
-                    (1, ofm_dim, ofm_dim, ifm_ch * k * k),
+                    (1, ofm_h, ofm_w, ifm_ch * kh * kw),
                 )
                 graph.value_info.append(im2col_out)
                 im2col_out = im2col_out.name
@@ -485,24 +480,26 @@ class InferPool_Batch(Transformation):
                     pool_output = node_output
 
                 accum_bits = 0
-                pool_size_param = k
+                pool_size_param = kh  # TODO fix
                 pad_value = 0
-                if n.op_type in ["MaxPool", "MaxPoolNHWC"]:
+                if node.op_type in ["MaxPool", "MaxPoolNHWC"]:
                     pool_fxn = "MaxPool"
                     odt = idt
                     pad_value = idt.min()
-                elif n.op_type == "QuantAvgPool2d":
+                elif node.op_type == "QuantAvgPool2d":
                     assert odt.is_integer(), """Output data type for QuantAvgPool2d
                     needs to be integer"""
                     assert pad == 0, "Padding is not supported for QuantAvgPool2d"
-                    inst = getCustomOp(n)
+                    inst = getCustomOp(node)
                     pool_fxn = "QuantAvgPool"
                     pool_size_param = inst.get_shifts()
                     accum_bits = inst.get_accum_size()
                 else:
                     raise Exception(
-                        "pad_value and pool_fxn not configured for {}".format(n.op_type)
+                        "pad_value and pool_fxn not configured for {}".format(
+                            node.op_type
+                        )
                     )
 
                 # format input tensor
@@ -511,13 +508,13 @@ class InferPool_Batch(Transformation):
                     [im2col_in],
                     [im2col_out],
                     domain="finn.custom_op.general",
-                    stride=[stride, stride],
-                    kernel_size=[k, k],
-                    pad_amount=[pad, pad, pad, pad],
+                    stride=[sh, sw],
+                    kernel_size=[kh, kw],
+                    pad_amount=pad,
                     pad_value=pad_value,
                     depthwise=1,
-                    input_shape="(1,{},{},{})".format(ifm_dim, ifm_dim, ifm_ch),
-                    name="Im2Col_" + n.name,
+                    input_shape="(1,{},{},{})".format(ifm_h, ifm_w, ifm_ch),
+                    name="Im2Col_" + node.name,
                 )
 
                 # Warning PE has to be equal to ifm_ch until Im2Col is replaced by
@@ -534,13 +531,13 @@ class InferPool_Batch(Transformation):
                     OutputDataType=odt.name,
                     Channels=ifm_ch,
                     PE=ifm_ch,
-                    KernelSize=[k, k],
+                    KernelSize=[kh, kw],
                     Function=pool_fxn,
-                    OutImgDim=ofm_dim,
+                    OutImgDim=ofm_h,  # TODO fix
                     AccumBits=accum_bits,
                     Size=pool_size_param,
                     BatchSize=1,
-                    name="Pool_Batch_" + n.name,
+                    name="Pool_Batch_" + node.name,
                 )
 
                 if dlayout == "NCHW":
@@ -559,7 +556,7 @@ class InferPool_Batch(Transformation):
                 graph.node.insert(node_ind, im2col_node)
                 graph.node.insert(node_ind + 1, pool_node)
                 # remove old node
-                graph.node.remove(n)
+                graph.node.remove(node)
                 graph_modified = True
 
         if graph_modified:
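Usage note: a minimal sketch of how the InferPool_Batch transformation touched by this diff is typically applied. It assumes the FINN API of the release this diff targets (in particular, that ModelWrapper lives at finn.core.modelwrapper, which moved in later releases); "model.onnx" and "model_hls_pool.onnx" are placeholder paths.

from finn.core.modelwrapper import ModelWrapper
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.infer_datatypes import InferDataTypes
import finn.transformation.fpgadataflow.convert_to_hls_layers as to_hls

# load a FINN-ready ONNX model ("model.onnx" is a placeholder)
model = ModelWrapper("model.onnx")

# shape and datatype annotations must be present first: InferPool_Batch
# reads tensor shapes and datatypes to decide whether a node is convertible
# (integer input datatype, 4D tensor, kernel >= stride)
model = model.transform(InferShapes())
model = model.transform(InferDataTypes())

# rewrite MaxPool / QuantAvgPool2d / MaxPoolNHWC nodes into Im2Col +
# Pool_Batch pairs, with surrounding Transpose nodes for NCHW layouts
model = model.transform(to_hls.InferPool_Batch())
model.save("model_hls_pool.onnx")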