diff --git a/src/finn/custom_op/maxpoolnhwc.py b/src/finn/custom_op/maxpoolnhwc.py
index 7586a859c17db690080f790e4ee5dae9610336cd..c623e40075e0ed6836dc9494ee5effb4539a46af 100644
--- a/src/finn/custom_op/maxpoolnhwc.py
+++ b/src/finn/custom_op/maxpoolnhwc.py
@@ -32,15 +32,49 @@ from onnx import helper, TensorProto
 from finn.core.modelwrapper import ModelWrapper
 
 
+def compute_pool_output_dim(ifm_dim, k, stride, pad=0):
+    "Return spatial output dimension size for pooling with given params."
+    return int(((ifm_dim + 2 * pad - k) / stride) + 1)
+
+
 class MaxPoolNHWC(CustomOp):
     # a MaxPool node, but using the NHWC data layout
 
     def get_nodeattr_types(self):
         # no specific attributes for MaxPoolNHWC
-        return {}
+        return {
+            "kernel_shape": ("ints", True, []),
+            "pads": ("ints", True, []),
+            "strides": ("ints", True, []),
+        }
 
     def make_shape_compatible_op(self, model):
-        raise Exception("MaxPoolNHWC does not yet support shape inference")
+        node = self.onnx_node
+        iname = node.input[0]
+        ishape = model.get_tensor_shape(iname)
+        kernel_shape = self.get_nodeattr("kernel_shape")
+        pads = self.get_nodeattr("pads")
+        strides = self.get_nodeattr("strides")
+        assert len(kernel_shape) == 2, "Non-2D MaxPoolNHWC not supported"
+        assert pads[0] == pads[2], "Uneven padding not supported"
+        assert pads[1] == pads[3], "Uneven padding not supported"
+        (n, hi, wi, c) = ishape
+        ho = compute_pool_output_dim(hi, kernel_shape[0], strides[0], pads[0])
+        wo = compute_pool_output_dim(wi, kernel_shape[1], strides[1], pads[2])
+        oshape = (n, ho, wo, c)
+        # implement tensor with correct shape
+        values = np.random.randn(*oshape).astype(np.float32)
+        return helper.make_node(
+            "Constant",
+            inputs=[],
+            outputs=[self.onnx_node.output[0]],
+            value=helper.make_tensor(
+                name="const_tensor",
+                data_type=TensorProto.FLOAT,
+                dims=values.shape,
+                vals=values.flatten().astype(float),
+            ),
+        )
 
     def infer_node_datatype(self, model):
         node = self.onnx_node