diff --git a/src/finn/custom_op/quantavgpool2d.py b/src/finn/custom_op/quantavgpool2d.py
index 3bc328a9f4f6670041d33491d58af6c553bafac9..5a81d8081d13853cb2b660ead980a7de89821d89 100644
--- a/src/finn/custom_op/quantavgpool2d.py
+++ b/src/finn/custom_op/quantavgpool2d.py
@@ -4,6 +4,7 @@ import onnxruntime as rt
 
 from finn.custom_op import CustomOp
 from finn.core.datatype import DataType
+from finn.custom_op.maxpoolnhwc import compute_pool_output_dim
 
 
 class QuantAvgPool2d(CustomOp):
@@ -17,18 +18,44 @@ class QuantAvgPool2d(CustomOp):
             "ibits": ("i", True, 1),
             "obits": ("i", True, 1),
             "signed": ("i", True, 0),
+            "data_layout": ("s", False, "NCHW"),
         }
 
     def make_shape_compatible_op(self, model):
         node = self.onnx_node
+        iname = node.input[0]
+        ishape = model.get_tensor_shape(iname)
         k = self.get_nodeattr("kernel")
         s = self.get_nodeattr("stride")
+        data_layout = self.get_nodeattr("data_layout")
+        if data_layout == "NCHW":
+            (n, c, hi, wi) = ishape
+            ho = compute_pool_output_dim(hi, k, s)
+            wo = compute_pool_output_dim(wi, k, s)
+            oshape = (n, c, ho, wo)
+        elif data_layout == "NHWC":
+            (n, hi, wi, c) = ishape
+            ho = compute_pool_output_dim(hi, k, s)
+            wo = compute_pool_output_dim(wi, k, s)
+            oshape = (n, ho, wo, c)
+        else:
+            raise Exception(
+                """Datalayout for QuantAvgPool2d is set to an unvalid value.
+                    Has to be set to "NCHW" or "NHWC"."""
+            )
+
+        # implement tensor with correct shape
+        values = np.random.randn(*oshape).astype(np.float32)
         return helper.make_node(
-            "AveragePool",
-            inputs=[node.input[0]],
+            "Constant",
+            inputs=[],
             outputs=[node.output[0]],
-            kernel_shape=[k, k],
-            strides=[s, s],
+            value=helper.make_tensor(
+                name="const_tensor",
+                data_type=TensorProto.FLOAT,
+                dims=values.shape,
+                vals=values.flatten().astype(float),
+            ),
         )
 
     def infer_node_datatype(self, model):
@@ -48,8 +75,12 @@ class QuantAvgPool2d(CustomOp):
         node = self.onnx_node
         k = self.get_nodeattr("kernel")
         s = self.get_nodeattr("stride")
-        ishape = context[node.input[0]].shape
+        inp_values = context[node.input[0]]
         oshape = context[node.output[0]].shape
+        if self.get_nodeattr("data_layout") == "NHWC":
+            inp_values = inp_values.transpose(0, 3, 1, 2)
+            oshape = (context[node.output[0]]).transpose(0, 3, 1, 2).shape
+        ishape = inp_values.shape
         inp = helper.make_tensor_value_info(node.input[0], TensorProto.FLOAT, ishape)
         outp = helper.make_tensor_value_info(node.output[0], TensorProto.FLOAT, oshape)
         node_avgpool = helper.make_node(
@@ -66,7 +97,7 @@ class QuantAvgPool2d(CustomOp):
             outputs=[outp],
         )
         model_avgpool = helper.make_model(graph_avgpool)
-        idict = {node.input[0]: context[node.input[0]]}
+        idict = {node.input[0]: inp_values}
         sess = rt.InferenceSession(model_avgpool.SerializeToString())
         result_temp = sess.run(None, idict)
         # remove scaling introduced by average
@@ -77,7 +108,16 @@ class QuantAvgPool2d(CustomOp):
         max_bit_width = int(max_value).bit_length()
         shift_bits = max_bit_width - self.get_nodeattr("obits")
         result = np.right_shift(result_temp.astype(int), shift_bits)
+        if self.get_nodeattr("data_layout") == "NHWC":
+            result = result.transpose(0, 2, 3, 1)
         context[node.output[0]] = result.astype(np.float32)
 
     def verify_node(self):
-        pass
+        info_messages = []
+        # verify that "domain" is set to "finn"
+        domain_value = self.onnx_node.domain
+        if domain_value == "finn":
+            info_messages.append("Attribute domain is set correctly")
+        else:
+            info_messages.append('Attribute domain should be set to "finn"')
+        return info_messages