diff --git a/src/finn/custom_op/quantavgpool2d.py b/src/finn/custom_op/quantavgpool2d.py
index 075d807c0a7686d452ba57140e1fec2115954e01..0a848b85971030e5826a00291a0f1b305377d94b 100644
--- a/src/finn/custom_op/quantavgpool2d.py
+++ b/src/finn/custom_op/quantavgpool2d.py
@@ -3,6 +3,7 @@ from onnx import TensorProto, helper
 import onnxruntime as rt
 
 from finn.custom_op import CustomOp
+from finn.core.datatype import DataType
 
 
 class QuantAvgPool2d(CustomOp):
@@ -32,8 +33,14 @@ class QuantAvgPool2d(CustomOp):
 
     def infer_node_datatype(self, model):
         node = self.onnx_node
-        # data type stays the same
-        dtype = model.get_tensor_datatype(node.input[0])
+        bw = self.get_nodeattr("obits")
+        if bw in [2,4,8,16,32]:
+            if self.get_nodeattr("signed") == 0:
+                dtype = DataType["UINT%d" % bw]
+            else:
+                dtype = DataType["INT%d" % bw]
+        else:
+            raise Exception("Unsupported output datatype for QuantAvgPool2d")
         model.set_tensor_datatype(node.output[0], dtype)
 
     def execute_node(self, context, graph):