diff --git a/src/finn/custom_op/quantavgpool2d.py b/src/finn/custom_op/quantavgpool2d.py index 44e7d52a3a1af373e06058fef55f7aa3aba51e0f..4841ce77baca5bd8c69c8acf6b1b102d17fb1d69 100644 --- a/src/finn/custom_op/quantavgpool2d.py +++ b/src/finn/custom_op/quantavgpool2d.py @@ -3,7 +3,6 @@ from onnx import TensorProto, helper import onnxruntime as rt from finn.custom_op import CustomOp -from finn.custom_op.im2col import compute_conv_output_dim class QuantAvgPool2d(CustomOp): @@ -21,31 +20,14 @@ class QuantAvgPool2d(CustomOp): def make_shape_compatible_op(self, model): node = self.onnx_node - inp = node.input[0] - ishape = model.get_tensor_shape(inp) - # we assume that the shape is (NCHW) and H=W - assert len(ishape) == 4, "Unexpected input shape for QuantAvgPool2d" - assert ( - ishape[2] == ishape[3] - ), "QuantAvgPool2d for non-square images unsupported" - ch = ishape[1] - idim = ishape[2] k = self.get_nodeattr("kernel") - stride = self.get_nodeattr("stride") - odim = compute_conv_output_dim(idim, k, stride) - - # implement tensor with correct shape - values = np.random.randn(1, ch, odim, odim).astype(np.float32) + s = self.get_nodeattr("stride") return helper.make_node( - "Constant", - inputs=[], + "AveragePool", + inputs=[node.input[0]], outputs=[node.output[0]], - value=helper.make_tensor( - name="const_tensor", - data_type=TensorProto.FLOAT, - dims=values.shape, - vals=values.flatten().astype(float), - ), + kernel_shape=[k, k], + strides=[s, s], ) def infer_node_datatype(self, model):