diff --git a/src/finn/custom_op/quantavgpool2d.py b/src/finn/custom_op/quantavgpool2d.py index 075d807c0a7686d452ba57140e1fec2115954e01..0a848b85971030e5826a00291a0f1b305377d94b 100644 --- a/src/finn/custom_op/quantavgpool2d.py +++ b/src/finn/custom_op/quantavgpool2d.py @@ -3,6 +3,7 @@ from onnx import TensorProto, helper import onnxruntime as rt from finn.custom_op import CustomOp +from finn.core.datatype import DataType class QuantAvgPool2d(CustomOp): @@ -32,8 +33,14 @@ class QuantAvgPool2d(CustomOp): def infer_node_datatype(self, model): node = self.onnx_node - # data type stays the same - dtype = model.get_tensor_datatype(node.input[0]) + bw = self.get_nodeattr("obits") + if bw in [2,4,8,16,32]: + if self.get_nodeattr("signed") == 0: + dtype = DataType["UINT%d" % bw] + else: + dtype = DataType["INT%d" % bw] + else: + raise Exception("Unsupported output datatype for QuantAvgPool2d") model.set_tensor_datatype(node.output[0], dtype) def execute_node(self, context, graph):