diff --git a/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py b/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py
index 54bd01e2baf479789baa06853b6a442ca5ee6ef5..0364b272e773a0f498f559cb768a4148224c3a3d 100644
--- a/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py
+++ b/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py
@@ -385,62 +385,60 @@ class InferPool_Batch(Transformation):
         graph = model.graph
         node_ind = 0
         graph_modified = False
-        for n in graph.node:
+        for node in graph.node:
             node_ind += 1
-            if n.op_type in ["MaxPool", "QuantAvgPool2d", "MaxPoolNHWC"]:
-                # extract pool parameters
+            if node.op_type in ["MaxPool", "QuantAvgPool2d", "MaxPoolNHWC"]:
+                node_input = node.input[0]
+                ishape = model.get_tensor_shape(node_input)
+                node_output = node.output[0]
+                idt = model.get_tensor_datatype(node_input)
+                oshape = model.get_tensor_shape(node_output)
+                # only support 4D input tensors (1D pooling needs an extra dummy dim)
+                if len(ishape) != 4:
+                    continue
 
-                if n.op_type == "MaxPool":
-                    k = get_by_name(n.attribute, "kernel_shape").ints[-1]
-                    stride = get_by_name(n.attribute, "strides").ints[-1]
-                    # assumed datalayout
+                # extract pool parameters
+                if node.op_type == "MaxPool":
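+                    # ONNX MaxPool stores kernel_shape and strides as [H, W]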
+                    kh, kw = list(get_by_name(node.attribute, "kernel_shape").ints)
+                    sh, sw = list(get_by_name(node.attribute, "strides").ints)
                     dlayout = "NCHW"
-                elif n.op_type == "QuantAvgPool2d":
-                    inst = getCustomOp(n)
-                    k = inst.get_nodeattr("kernel")
-                    stride = inst.get_nodeattr("stride")
+                elif node.op_type == "QuantAvgPool2d":
+                    inst = getCustomOp(node)
+                    # QuantAvgPool2d stores kernel size and stride as
+                    # single scalar attributes, implying a square window
+                    kh = kw = inst.get_nodeattr("kernel")
+                    sh = sw = inst.get_nodeattr("stride")
                     dlayout = inst.get_nodeattr("data_layout")
-                elif n.op_type == "MaxPoolNHWC":
-                    inst = getCustomOp(n)
-                    k_shape = inst.get_nodeattr("kernel_shape")
-                    strides = inst.get_nodeattr("strides")
-                    assert k_shape[0] == k_shape[1]
-                    assert strides[0] == strides[1]
-                    k = k_shape[0]
-                    stride = strides[0]
+                elif node.op_type == "MaxPoolNHWC":
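+                    # MaxPoolNHWC is FINN's NHWC custom op, attrs already [H, W]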
+                    inst = getCustomOp(node)
+                    kh, kw = inst.get_nodeattr("kernel_shape")
+                    sh, sw = inst.get_nodeattr("strides")
                     dlayout = "NHWC"
                 try:
-                    pad = get_by_name(n.attribute, "pads").ints[-1]
+                    pad = list(get_by_name(node.attribute, "pads").ints)
                 except AttributeError:
-                    pad = 0
-
-                node_input = n.input[0]
-                node_output = n.output[0]
-                idt = model.get_tensor_datatype(node_input)
+                    pad = [0, 0, 0, 0]
 
                 if not idt.is_integer():
                     continue
 
-                if k < stride:
+                if (kh < sh) or (kw < sw):
+                    # TODO check/implement sliding window generator (swg) support
                     continue
-                elif k == stride:
-                    warnings.warn(
-                        n.name
-                        + """: Inferring Pool_Batch node for k == stride.
-                        This case can be optimized.
-                        For example, for MaxPool run InferStreamingMaxPool before
-                        InferPool_Batch """
-                    )
 
                 odt = model.get_tensor_datatype(node_output)
 
                 if dlayout == "NCHW":
-                    ifm_ch = model.get_tensor_shape(n.input[0])[1]
+                    _, ifm_ch, ifm_h, ifm_w = ishape
+                    _, ofm_ch, ofm_h, ofm_w = oshape
+                elif dlayout == "NHWC":
+                    _, ifm_h, ifm_w, ifm_ch = ishape
+                    _, ofm_h, ofm_w, ofm_ch = oshape
                 else:
-                    ifm_ch = model.get_tensor_shape(n.input[0])[-1]
-                ofm_ch = ifm_ch
-                ifm_dim = model.get_tensor_shape(n.input[0])[-2]
-                ofm_dim = model.get_tensor_shape(n.output[0])[-2]
+                    raise Exception("Unknown dlayout: " + str(dlayout))
 
                 # if data layout NCHW, we need transpose nodes surrounding
                 # the hls layer
@@ -449,7 +447,7 @@ class InferPool_Batch(Transformation):
                     inp_trans_out = helper.make_tensor_value_info(
                         model.make_new_valueinfo_name(),
                         TensorProto.FLOAT,
-                        (1, ifm_dim, ifm_dim, ifm_ch),  # NHWC
+                        (1, ifm_h, ifm_w, ifm_ch),  # NHWC
                     )
                     graph.value_info.append(inp_trans_out)
                     inp_trans_out = inp_trans_out.name
@@ -458,7 +456,7 @@ class InferPool_Batch(Transformation):
                     pool_output = helper.make_tensor_value_info(
                         model.make_new_valueinfo_name(),
                         TensorProto.FLOAT,
-                        (1, ofm_dim, ofm_dim, ofm_ch),
+                        (1, ofm_h, ofm_w, ofm_ch),
                     )
                     graph.value_info.append(pool_output)
                     pool_output = pool_output.name
@@ -467,7 +465,8 @@ class InferPool_Batch(Transformation):
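+                # im2col packs each kh x kw window into ifm_ch * kh * kw channels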
                 im2col_out = helper.make_tensor_value_info(
                     model.make_new_valueinfo_name(),
                     TensorProto.FLOAT,
-                    (1, ofm_dim, ofm_dim, ifm_ch * k * k),
+                    (1, ofm_h, ofm_w, ifm_ch * kh * kw),
                 )
                 graph.value_info.append(im2col_out)
                 im2col_out = im2col_out.name
@@ -485,24 +484,30 @@ class InferPool_Batch(Transformation):
                     pool_output = node_output
 
                 accum_bits = 0
-                pool_size_param = k
+                pool_size_param = kh  # TODO: fix, assumes square kernel (kh == kw)
                 pad_value = 0
-                if n.op_type in ["MaxPool", "MaxPoolNHWC"]:
+                if node.op_type in ["MaxPool", "MaxPoolNHWC"]:
                     pool_fxn = "MaxPool"
                     odt = idt
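+                    # pad with idt.min() so padded pixels never win the max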
                     pad_value = idt.min()
-                elif n.op_type == "QuantAvgPool2d":
+                elif node.op_type == "QuantAvgPool2d":
                     assert odt.is_integer(), """Output data type for QuantAvgPool2d
                     needs to be integer"""
-                    assert pad == 0, "Padding is not supported for QuantAvgPool2d"
+                    assert all(
+                        v == 0 for v in pad
+                    ), "Padding is not supported for QuantAvgPool2d"
-                    inst = getCustomOp(n)
+                    inst = getCustomOp(node)
                     pool_fxn = "QuantAvgPool"
                     pool_size_param = inst.get_shifts()
                     accum_bits = inst.get_accum_size()
 
                 else:
                     raise Exception(
-                        "pad_value and pool_fxn not configured for {}".format(n.op_type)
+                        "pad_value and pool_fxn not configured for {}".format(
+                            node.op_type
+                        )
                     )
 
                 # format input tensor
@@ -511,13 +516,14 @@ class InferPool_Batch(Transformation):
                     [im2col_in],
                     [im2col_out],
                     domain="finn.custom_op.general",
-                    stride=[stride, stride],
-                    kernel_size=[k, k],
-                    pad_amount=[pad, pad, pad, pad],
+                    stride=[sh, sw],
+                    kernel_size=[kh, kw],
+                    pad_amount=pad,
                     pad_value=pad_value,
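+                    # pooling is channelwise, so im2col runs in depthwise mode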
                     depthwise=1,
-                    input_shape="(1,{},{},{})".format(ifm_dim, ifm_dim, ifm_ch),
-                    name="Im2Col_" + n.name,
+                    input_shape="(1,{},{},{})".format(ifm_h, ifm_w, ifm_ch),
+                    name="Im2Col_" + node.name,
                 )
 
                 # Warning PE has to be equal to ifm_ch until Im2Col is replaced by
@@ -534,13 +540,13 @@ class InferPool_Batch(Transformation):
                     OutputDataType=odt.name,
                     Channels=ifm_ch,
                     PE=ifm_ch,
-                    KernelSize=[k, k],
+                    KernelSize=[kh, kw],
                     Function=pool_fxn,
-                    OutImgDim=ofm_dim,
+                    OutImgDim=ofm_h,  # TODO: fix, assumes square output (ofm_h == ofm_w)
                     AccumBits=accum_bits,
                     Size=pool_size_param,
                     BatchSize=1,
-                    name="Pool_Batch_" + n.name,
+                    name="Pool_Batch_" + node.name,
                 )
 
                 if dlayout == "NCHW":
@@ -559,7 +565,7 @@ class InferPool_Batch(Transformation):
                     graph.node.insert(node_ind, im2col_node)
                     graph.node.insert(node_ind + 1, pool_node)
                 # remove old node
-                graph.node.remove(n)
+                graph.node.remove(node)
                 graph_modified = True
 
         if graph_modified: