diff --git a/docker/finn_entrypoint.sh b/docker/finn_entrypoint.sh
index b312737c317517ca0ab19c74cf22284b5977b661..f701952d5d64e5d9b95aa59170b492fe7722ae02 100644
--- a/docker/finn_entrypoint.sh
+++ b/docker/finn_entrypoint.sh
@@ -15,7 +15,7 @@ gecho () {
 # the repos themselves are cloned in the Dockerfile
 BREVITAS_COMMIT=f9a27226d4acf1661dd38bc449f71f89e0983cce
 CNPY_COMMIT=4e8810b1a8637695171ed346ce68f6984e585ef4
-HLSLIB_COMMIT=8f9f2018762f654f196b666838aeaf6fc730ad9a
+HLSLIB_COMMIT=cfafe11a93b79ab1af7529d68f08886913a6466e
 PYVERILATOR_COMMIT=c97a5ba41bbc7c419d6f25c74cdf3bdc3393174f
 PYNQSHELL_COMMIT=0c82a61b0ec1a07fa275a14146233824ded7a13d
 OMX_COMMIT=1bae737669901e762f581af73348332b5c4b2ada
diff --git a/src/finn/custom_op/fpgadataflow/pool_batch.py b/src/finn/custom_op/fpgadataflow/pool_batch.py
index c7edc24d0e24eef1154293caca2519ab3aa68358..801a634fdba1cd5e16c7c211175c1e7380bf0070 100644
--- a/src/finn/custom_op/fpgadataflow/pool_batch.py
+++ b/src/finn/custom_op/fpgadataflow/pool_batch.py
@@ -39,16 +39,18 @@ class Pool_Batch(HLSCustomOp):
     """Class that corresponds to finn-hlslib Pool_batch function.
     Requires ConvolutionInputGenerator(depthwise == 1) to format its input
 
-    TODO: explain input shape (to reuse im2col code)
     Input shape (BatchSize,OutImgDim,OutImgDim,KernelSize^2*Channels)
     Output shape (BatchSize,OutImgDim,OutImgDim,Channels)
 
-    # note: the actual data layout produced by the hlslib kernels is different
-    # for depthwise ops.
-    # * depthwise SWG: (1, OFMDim, OFMDim, IFMChannels/PE, K, K, PE)
+    Notes:
+    # The input shape was chosen to be compatible with im2col (this only holds
+    when there is no folding).
+
+    # The actual data layout produced by the hlslib kernels is different
+    for depthwise ops:
+     * depthwise SWG: (1, OFMDim, OFMDim, IFMChannels/PE, K, K, PE)
 
     Channels can be folded using PE (SIMD from the input perspective)
-    TODO: doc
     """
 
     def get_nodeattr_types(self):
@@ -63,7 +65,10 @@ class Pool_Batch(HLSCustomOp):
             "Function": ("s", True, ""),
             "OutImgDim": ("i", True, 0),
             # FINN DataTypes for inputs/outputs
-            "dataType": ("s", True, ""),
+            "InputDataType": ("s", True, ""),
+            "OutputDataType": ("s", True, ""),
+            "AccumBits": ("i", False, 0),
+            "Size": ("i", False, 1),
             "BatchSize": ("i", False, 1),
         }
 
@@ -72,17 +77,28 @@ class Pool_Batch(HLSCustomOp):
 
     def get_input_datatype(self):
         """Returns FINN DataType of input."""
-        return DataType[self.get_nodeattr("dataType")]
+        return DataType[self.get_nodeattr("InputDataType")]
 
     def get_output_datatype(self):
         """Returns FINN DataType of output."""
         fxn = self.get_nodeattr("Function")
+        odt = DataType[self.get_nodeattr("OutputDataType")]
+
         if fxn == "MaxPool":
             # Same as input
-            return DataType[self.get_nodeattr("dataType")]
+            idt = DataType[self.get_nodeattr("InputDataType")]
+            assert odt == idt, "Input and output datatypes must match for MaxPool"
+        elif fxn == "QuantAvgPool":
+            idt = DataType[self.get_nodeattr("InputDataType")]
+            assert (
+                idt.signed() == odt.signed()
+            ), """QuantAvgPool: Can't mix signed
+            and unsigned datatypes"""
         else:
             raise Exception("Pool_Batch doesn't currently support " + fxn)
 
+        return odt
+
     def get_normal_input_shape(self):
         ifm_ch = self.get_nodeattr("Channels")
         odim = self.get_nodeattr("OutImgDim")
@@ -123,19 +139,14 @@ class Pool_Batch(HLSCustomOp):
     def get_instream_width(self):
         dt_bits = self.get_input_datatype().bitwidth()
         pe = self.get_nodeattr("PE")
-        # ofm_ch = self.get_nodeattr("Channels")
-        # k = self.get_nodeattr("KernelSize")
-        # assert ifm_ch % pe == 0, "PE must divide input channels"
-        # simd = int(ifm_ch/pe)
         in_width = int(dt_bits * pe)
         return in_width
 
     def get_outstream_width(self):
-        fxn = self.get_nodeattr("Function")
-        if fxn == "MaxPool":
-            return self.get_instream_width()
-        else:
-            raise Exception("Pool_Batch doesn't currently support " + fxn)
+        dt_bits = self.get_output_datatype().bitwidth()
+        pe = self.get_nodeattr("PE")
+        out_width = int(dt_bits * pe)
+        return out_width
 
     def make_shape_compatible_op(self, model):
         exp_ishape = self.get_normal_input_shape()
@@ -187,7 +198,7 @@ class Pool_Batch(HLSCustomOp):
 
         # check supported function
         fnx = self.get_nodeattr("Function")
-        if fnx == "MaxPool":
+        if fnx in ["MaxPool", "QuantAvgPool"]:
             info_messages.append(
                 "Attribute Function contains a supported pool function"
             )
@@ -251,7 +262,8 @@ class Pool_Batch(HLSCustomOp):
         i_hls_dt = idt.get_hls_datatype_str()
         odt = self.get_output_datatype()
         o_hls_dt = odt.get_hls_datatype_str()
-
+        size = self.get_nodeattr("Size")
+        accum_bits = self.get_nodeattr("AccumBits")
         self.code_gen_dict["$DOCOMPUTE$"] = []
 
         fxn = self.get_nodeattr("Function")
@@ -259,6 +271,16 @@ class Pool_Batch(HLSCustomOp):
             self.code_gen_dict["$DOCOMPUTE$"] += [
                 "MaxPoolFunction<{},KernelSize> pool_fxn;".format(i_hls_dt)
             ]
+        elif fxn == "QuantAvgPool":
+            if idt.signed():
+                act_hls_dt = "ap_int<{}>".format(accum_bits)
+            else:
+                act_hls_dt = "ap_uint<{}>".format(accum_bits)
+            self.code_gen_dict["$DOCOMPUTE$"] += [
+                "QuantAvgPoolFunction<{},{},{}> pool_fxn;".format(
+                    act_hls_dt, o_hls_dt, size
+                )
+            ]
         else:
             raise Exception("Pool_Batch doesn't currently support " + fxn)
 
@@ -369,7 +391,7 @@ class Pool_Batch(HLSCustomOp):
             super().reset_rtlsim(sim)
             super().toggle_clk(sim)
             rtlsim_output = self.rtlsim(sim, rtlsim_inp)
-            odt = export_idt
+            odt = self.get_output_datatype()
             target_bits = odt.bitwidth()
             packed_bits = self.get_outstream_width()
             out_npy_path = "{}/output.npy".format(code_gen_dir)
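For readers new to the stream-width attributes touched above, here is a minimal standalone sketch (plain Python, not part of the PR) of how get_instream_width()/get_outstream_width() relate PE folding to the element bitwidths; the attribute values are made up for illustration.

# Hedged sketch of the stream-width arithmetic in Pool_Batch; the
# attribute values below are hypothetical, not taken from the PR.
input_bits = 4    # bitwidth of InputDataType, e.g. INT4
output_bits = 4   # bitwidth of OutputDataType, e.g. UINT4
pe = 4            # channels processed in parallel (PE attribute)

in_stream_width = input_bits * pe    # mirrors get_instream_width()
out_stream_width = output_bits * pe  # mirrors get_outstream_width()
print(in_stream_width, out_stream_width)  # -> 16 16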
diff --git a/src/finn/custom_op/quantavgpool2d.py b/src/finn/custom_op/quantavgpool2d.py
index fb5c78bc0c8419ba519c5c3113d9b0c7ae2dd3b7..28d01069264d883f3afc400808470f5f303be799 100644
--- a/src/finn/custom_op/quantavgpool2d.py
+++ b/src/finn/custom_op/quantavgpool2d.py
@@ -75,6 +75,19 @@ class QuantAvgPool2d(CustomOp):
             raise Exception("Unsupported output datatype for QuantAvgPool2d")
         model.set_tensor_datatype(node.output[0], dtype)
 
+    def get_accum_size(self):
+        ibits = self.get_nodeattr("ibits")
+        k = self.get_nodeattr("kernel")
+        max_value = 2 ** ibits - 1
+        max_value = max_value * k * k
+        max_bit_width = int(max_value).bit_length()
+        return max_bit_width
+
+    def get_shifts(self):
+        shift_bits = self.get_accum_size() - self.get_nodeattr("obits")
+        shift_bits = shift_bits if shift_bits >= 0 else 0
+        return shift_bits
+
     def execute_node(self, context, graph):
         # create a standard average pooling node to help calculate the result
         node = self.onnx_node
@@ -107,12 +120,7 @@ class QuantAvgPool2d(CustomOp):
         result_temp = sess.run(None, idict)
         # remove scaling introduced by average
         result_temp = result_temp[0] * (k * k)
-        ibits = self.get_nodeattr("ibits")
-        max_value = 2 ** ibits - 1
-        max_value = max_value * k * k
-        max_bit_width = int(max_value).bit_length()
-        shift_bits = max_bit_width - self.get_nodeattr("obits")
-        result = np.right_shift(result_temp.astype(int), shift_bits)
+        result = np.right_shift(result_temp.astype(int), self.get_shifts())
         if self.get_nodeattr("data_layout") == "NHWC":
             result = result.transpose(0, 2, 3, 1)
         context[node.output[0]] = result.astype(np.float32)
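A worked example of the accumulator-size and shift arithmetic above may help; the numbers below are illustrative only.

# Hedged worked example of get_accum_size() / get_shifts() with
# illustrative values: ibits=4, kernel=2, obits=4.
ibits, k, obits = 4, 2, 4
max_value = (2 ** ibits - 1) * k * k        # 15 * 4 = 60
accum_bits = int(max_value).bit_length()    # 60 fits in 6 bits
shift_bits = max(accum_bits - obits, 0)     # 6 - 4 = 2 right-shifts
print(accum_bits, shift_bits)               # -> 6 2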
diff --git a/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py b/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py
index 700db1436744585ada72455a568d75569740a8b3..877b4a0dc9bfb0e21dfba0ac885af0fc110c9fe1 100644
--- a/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py
+++ b/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py
@@ -209,13 +209,16 @@ class InferPool_Batch(Transformation):
         graph_modified = False
         for n in graph.node:
             node_ind += 1
-            if n.op_type in ["MaxPool"]:
+            if n.op_type in ["MaxPool", "QuantAvgPool2d"]:
                 # extract pool parameters
-                k = get_by_name(n.attribute, "kernel_shape").ints[-1]
-                stride = get_by_name(n.attribute, "strides").ints[-1]
 
-                if k <= stride:
-                    continue
+                if n.op_type == "MaxPool":
+                    k = get_by_name(n.attribute, "kernel_shape").ints[-1]
+                    stride = get_by_name(n.attribute, "strides").ints[-1]
+                elif n.op_type == "QuantAvgPool2d":
+                    inst = getCustomOp(n)
+                    k = inst.get_nodeattr("kernel")
+                    stride = inst.get_nodeattr("stride")
 
                 try:
                     pad = get_by_name(n.attribute, "pads").ints[-1]
@@ -225,10 +228,21 @@ class InferPool_Batch(Transformation):
                 node_input = n.input[0]
                 node_output = n.output[0]
                 idt = model.get_tensor_datatype(node_input)
+
                 if not idt.is_integer():
                     continue
 
-                # odt = model.get_tensor_datatype(node_output)
+                if k < stride:
+                    continue
+                elif k == stride:
+                    warnings.warn(
+                        """Inferring Pool_Batch node for k == stride.
+                        This case can be optimized.
+                        For example, for MaxPool run InferStreamingMaxPool before
+                        InferPool_Batch """
+                    )
+
+                odt = model.get_tensor_datatype(node_output)
 
                 ifm_ch = model.get_tensor_shape(n.input[0])[1]  # assume NCHW
                 ofm_ch = ifm_ch
@@ -268,9 +282,22 @@ class InferPool_Batch(Transformation):
                     "Transpose", [node_input], [inp_trans_out], perm=[0, 2, 3, 1]
                 )
 
+                accum_bits = 0
+                pool_size_param = k
+                pad_value = 0
                 if n.op_type == "MaxPool":
                     pool_fxn = "MaxPool"
+                    odt = idt
                     pad_value = idt.min()
+                elif n.op_type == "QuantAvgPool2d":
+                    assert odt.is_integer(), """Output data type for QuantAvgPool2d
+                    needs to be integer"""
+                    assert pad == 0, "Padding is not supported for QuantAvgPool2d"
+                    inst = getCustomOp(n)
+                    pool_fxn = "QuantAvgPool"
+                    pool_size_param = inst.get_shifts()
+                    accum_bits = inst.get_accum_size()
+
                 else:
                     raise Exception(
                         "pad_value and pool_fxn not configured for {}".format(n.op_type)
@@ -300,12 +327,15 @@ class InferPool_Batch(Transformation):
                     [pool_output],
                     domain="finn",
                     backend="fpgadataflow",
-                    dataType=idt.name,
+                    InputDataType=idt.name,
+                    OutputDataType=odt.name,
                     Channels=ifm_ch,
                     PE=ifm_ch,
                     KernelSize=k,
                     Function=pool_fxn,
                     OutImgDim=ofm_dim,
+                    AccumBits=accum_bits,
+                    Size=pool_size_param,
                     BatchSize=1,
                 )
 
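As a usage sketch, the extended transformation would typically be applied like this (the .onnx filename is hypothetical; the imports are existing FINN modules):

# Hedged usage sketch for the extended InferPool_Batch transformation.
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.infer_datatypes import InferDataTypes
import finn.transformation.fpgadataflow.convert_to_hls_layers as to_hls

model = ModelWrapper("model_with_quantavgpool.onnx")  # hypothetical file
model = model.transform(InferShapes())
model = model.transform(InferDataTypes())
# MaxPool / QuantAvgPool2d nodes become Transpose -> Pool_Batch -> Transpose
model = model.transform(to_hls.InferPool_Batch())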
diff --git a/tests/fpgadataflow/test_convert_to_hls_pool_batch.py b/tests/fpgadataflow/test_convert_to_hls_pool_batch.py
index c9f78dcea1a1ce364d0657ad64de7d440d41b822..aba973051cb14e3e428e4de72a57924884c831de 100644
--- a/tests/fpgadataflow/test_convert_to_hls_pool_batch.py
+++ b/tests/fpgadataflow/test_convert_to_hls_pool_batch.py
@@ -77,27 +77,63 @@ def make_single_maxpool_modelwrapper(k, stride, pad, ifm_ch, ifm_dim, ofm_dim, i
     return model
 
 
+def make_single_quantavpool_modelwrapper(k, stride, ifm_ch, ifm_dim, ofm_dim, idt, odt):
+    inp = helper.make_tensor_value_info(
+        "inp", TensorProto.FLOAT, [1, ifm_ch, ifm_dim, ifm_dim]
+    )
+    outp = helper.make_tensor_value_info(
+        "outp", TensorProto.FLOAT, [1, ifm_ch, ofm_dim, ofm_dim]
+    )
+
+    mp_node = helper.make_node(
+        "QuantAvgPool2d",
+        ["inp"],
+        ["outp"],
+        domain="finn",
+        stride=stride,
+        kernel=k,
+        ibits=idt.bitwidth(),
+        obits=odt.bitwidth(),
+        signed=1 if idt.signed() else 0,
+        data_layout="NCHW",
+    )
+    graph = helper.make_graph(
+        nodes=[mp_node], name="mp_graph", inputs=[inp], outputs=[outp]
+    )
+
+    model = helper.make_model(graph, producer_name="mp-model")
+    model = ModelWrapper(model)
+
+    model.set_tensor_datatype("inp", idt)
+    model.set_tensor_datatype("outp", odt)
+    model = model.transform(InferShapes())
+
+    return model
+
+
 def prepare_inputs(input_tensor):
     return {"inp": input_tensor}
 
 
 # input datatype
-@pytest.mark.parametrize("idt", [DataType.UINT4, DataType.INT4])
+@pytest.mark.parametrize("idt", [DataType.UINT4, DataType.INT4, DataType.INT8])
+# output datatype
+@pytest.mark.parametrize("odt", [DataType.UINT4, DataType.INT4])
 # pool configuration:                   ( k,stride, pad, ifm_dim )
-@pytest.mark.parametrize(
-    "pool_config", [(3, 2, 0, 5), (3, 2, 1, 5), (2, 2, 0, 8), (5, 2, 2, 7)]
-)
+@pytest.mark.parametrize("pool_config", [(7, 7, 0, 7), (3, 2, 1, 5)])
 # input channels
-@pytest.mark.parametrize("ifm_ch", [1, 4, 20])
+@pytest.mark.parametrize("ifm_ch", [1, 4])
 # number of out channel computed in parallel
-@pytest.mark.parametrize("pe", [1, 4, 20])
+@pytest.mark.parametrize("pe", [1, 2, 4])
+# pool type
+@pytest.mark.parametrize("op_type", ["QuantAvgPool2d", "MaxPool"])
 # execution mode
 @pytest.mark.parametrize("exec_mode", ["cppsim", "rtlsim"])
-# pool type
-@pytest.mark.parametrize("op_type", ["MaxPool"])
 @pytest.mark.slow
 @pytest.mark.vivado
-def test_convert_to_hls_pool_batch(idt, pool_config, ifm_ch, pe, exec_mode, op_type):
+def test_convert_to_hls_pool_batch(
+    idt, odt, pool_config, ifm_ch, pe, op_type, exec_mode
+):
     k, stride, pad, ifm_dim = pool_config
 
     if ifm_ch % pe != 0:
@@ -113,9 +149,25 @@ def test_convert_to_hls_pool_batch(idt, pool_config, ifm_ch, pe, exec_mode, op_t
     # prepare input data
     input_dict = prepare_inputs(x)
     if op_type == "MaxPool":
+        # if idt.signed():
+        #     pytest.skip("""No support for signed input (see accu initialization
+        #         in Pool_batch HLSLIB function). Skipping""")
+
+        if idt != odt:
+            pytest.skip("Skipping Maxpool with idt != odt")
+
         model = make_single_maxpool_modelwrapper(
             k, stride, pad, ifm_ch, ifm_dim, ofm_dim, idt
         )
+    elif op_type == "QuantAvgPool2d":
+        if pad != 0:
+            pytest.skip("No padding support for QuantAvgPool2d. Skipping")
+
+        if idt.signed() != odt.signed():
+            pytest.skip("Skipping QuantAvgPool2d with idt.signed() != odt.signed()")
+        model = make_single_quantavpool_modelwrapper(
+            k, stride, ifm_ch, ifm_dim, ofm_dim, idt, odt
+        )
     else:
         assert False, "{} is not a supported op_type".format(op_type)
 
@@ -151,7 +203,7 @@ def test_convert_to_hls_pool_batch(idt, pool_config, ifm_ch, pe, exec_mode, op_t
     # execute new_model
     y_produced = oxe.execute_onnx(new_model, input_dict)["outp"]
     assert (y_produced == y_expected).all()
-    if stride != k:
+    if stride <= k:
         if pad == 0 or ifm_ch == pe:
             assert len(new_model.graph.node) == 4
         else:
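For reference, a short standalone check of the output spatial sizes implied by the new pool_config tuples (the formula is the standard pooling output-size computation, not code quoted from this test):

# Hedged worked example for the (k, stride, pad, ifm_dim) tuples above.
def out_dim(k, stride, pad, ifm_dim):
    return (ifm_dim + 2 * pad - k) // stride + 1

print(out_dim(7, 7, 0, 7))  # -> 1, global-pooling-like case
print(out_dim(3, 2, 1, 5))  # -> 3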