diff --git a/src/finn/transformation/change_datalayout.py b/src/finn/transformation/change_datalayout.py
index 5b1e70052ccc28551ba6609e4dda72bf9450c2b0..d5b393a25e57122b059a44f70904a6dbe5bbaa3f 100644
--- a/src/finn/transformation/change_datalayout.py
+++ b/src/finn/transformation/change_datalayout.py
@@ -43,7 +43,10 @@ class ChangeDataLayoutQuantAvgPool2d(Transformation):
         graph_modified = False
         for n in graph.node:
             node_ind += 1
-            if n.op_type == "QuantAvgPool2d":
+            if n.op_type == "QuantAvgPool2d" and (
+                get_by_name(n.attribute, "data_layout") is None
+                or get_by_name(n.attribute, "data_layout").s.decode("UTF-8") == "NCHW"
+            ):
                 graph_modified = True
                 node_input = n.input[0]
                 node_output = n.output[0]
@@ -78,7 +81,7 @@ class ChangeDataLayoutQuantAvgPool2d(Transformation):
                     "Transpose", [node_input], [inp_trans_out], perm=[0, 2, 3, 1]
                 )
                 quantavg_node = helper.make_node(
-                    "QuantAvgPool2dNHWC",
+                    "QuantAvgPool2d",
                     [inp_trans_out],
                     [quantavg_out],
                     domain="finn",
@@ -87,6 +90,7 @@ class ChangeDataLayoutQuantAvgPool2d(Transformation):
                     ibits=ibits,
                     obits=obits,
                     signed=signed,
+                    data_layout="NHWC",
                 )
                 # NHWC -> NCHW
                 out_trans_node = helper.make_node(
@@ -99,5 +103,8 @@ class ChangeDataLayoutQuantAvgPool2d(Transformation):
                 # remove old nodes
                 graph.node.remove(n)
 
+                # set shapes
+                model.set_tensor_shape(inp_trans_out, (batchsize, idim, idim, channels))
+                model.set_tensor_shape(quantavg_out, (batchsize, odim, odim, channels))
         model = model.transform(InferShapes())
         return (model, graph_modified)