Skip to content
Snippets Groups Projects
Commit 10756cd8 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[CustomOp] add shape inference for MaxPoolNHWC

parent e9653644
No related branches found
No related tags found
No related merge requests found
......@@ -32,15 +32,49 @@ from onnx import helper, TensorProto
from finn.core.modelwrapper import ModelWrapper
def compute_pool_output_dim(ifm_dim, k, stride, pad=0):
"Return spatial output dimension size for pooling with given params."
return int(((ifm_dim + 2 * pad - k) / stride) + 1)
class MaxPoolNHWC(CustomOp):
# a MaxPool node, but using the NHWC data layout
def get_nodeattr_types(self):
# no specific attributes for MaxPoolNHWC
return {}
return {
"kernel_shape": ("ints", True, []),
"pads": ("ints", True, []),
"strides": ("ints", True, []),
}
def make_shape_compatible_op(self, model):
raise Exception("MaxPoolNHWC does not yet support shape inference")
node = self.onnx_node
iname = node.input[0]
ishape = model.get_tensor_shape(iname)
kernel_shape = self.get_nodeattr("kernel_shape")
pads = self.get_nodeattr("pads")
strides = self.get_nodeattr("strides")
assert len(kernel_shape) == 2, "Non-2D MaxPoolNHWC not supported"
assert pads[0] == pads[2], "Uneven padding not supported"
assert pads[1] == pads[3], "Uneven padding not supported"
(n, hi, wi, c) = ishape
ho = compute_pool_output_dim(hi, kernel_shape[0], strides[0], pads[0])
wo = compute_pool_output_dim(wi, kernel_shape[1], strides[1], pads[2])
oshape = (n, ho, wo, c)
# implement tensor with correct shape
values = np.random.randn(*oshape).astype(np.float32)
return helper.make_node(
"Constant",
inputs=[],
outputs=[self.onnx_node.output[0]],
value=helper.make_tensor(
name="const_tensor",
data_type=TensorProto.FLOAT,
dims=values.shape,
vals=values.flatten().astype(float),
),
)
def infer_node_datatype(self, model):
node = self.onnx_node
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment