diff --git a/src/finn/custom_op/maxpoolnhwc.py b/src/finn/custom_op/maxpoolnhwc.py index 7586a859c17db690080f790e4ee5dae9610336cd..c623e40075e0ed6836dc9494ee5effb4539a46af 100644 --- a/src/finn/custom_op/maxpoolnhwc.py +++ b/src/finn/custom_op/maxpoolnhwc.py @@ -32,15 +32,49 @@ from onnx import helper, TensorProto from finn.core.modelwrapper import ModelWrapper +def compute_pool_output_dim(ifm_dim, k, stride, pad=0): + "Return spatial output dimension size for pooling with given params." + return int(((ifm_dim + 2 * pad - k) / stride) + 1) + + class MaxPoolNHWC(CustomOp): # a MaxPool node, but using the NHWC data layout def get_nodeattr_types(self): # no specific attributes for MaxPoolNHWC - return {} + return { + "kernel_shape": ("ints", True, []), + "pads": ("ints", True, []), + "strides": ("ints", True, []), + } def make_shape_compatible_op(self, model): - raise Exception("MaxPoolNHWC does not yet support shape inference") + node = self.onnx_node + iname = node.input[0] + ishape = model.get_tensor_shape(iname) + kernel_shape = self.get_nodeattr("kernel_shape") + pads = self.get_nodeattr("pads") + strides = self.get_nodeattr("strides") + assert len(kernel_shape) == 2, "Non-2D MaxPoolNHWC not supported" + assert pads[0] == pads[2], "Uneven padding not supported" + assert pads[1] == pads[3], "Uneven padding not supported" + (n, hi, wi, c) = ishape + ho = compute_pool_output_dim(hi, kernel_shape[0], strides[0], pads[0]) + wo = compute_pool_output_dim(wi, kernel_shape[1], strides[1], pads[2]) + oshape = (n, ho, wo, c) + # implement tensor with correct shape + values = np.random.randn(*oshape).astype(np.float32) + return helper.make_node( + "Constant", + inputs=[], + outputs=[self.onnx_node.output[0]], + value=helper.make_tensor( + name="const_tensor", + data_type=TensorProto.FLOAT, + dims=values.shape, + vals=values.flatten().astype(float), + ), + ) def infer_node_datatype(self, model): node = self.onnx_node