From 3acdd188095861d58b9f843b2f84fbd2c81652a0 Mon Sep 17 00:00:00 2001 From: auphelia <jakobapk@web.de> Date: Fri, 12 Jun 2020 09:43:30 +0100 Subject: [PATCH] [CustomOp] Change infer datatype function of QuantAvgPool2d --- src/finn/custom_op/quantavgpool2d.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/finn/custom_op/quantavgpool2d.py b/src/finn/custom_op/quantavgpool2d.py index 075d807c0..0a848b859 100644 --- a/src/finn/custom_op/quantavgpool2d.py +++ b/src/finn/custom_op/quantavgpool2d.py @@ -3,6 +3,7 @@ from onnx import TensorProto, helper import onnxruntime as rt from finn.custom_op import CustomOp +from finn.core.datatype import DataType class QuantAvgPool2d(CustomOp): @@ -32,8 +33,14 @@ class QuantAvgPool2d(CustomOp): def infer_node_datatype(self, model): node = self.onnx_node - # data type stays the same - dtype = model.get_tensor_datatype(node.input[0]) + bw = self.get_nodeattr("obits") + if bw in [2,4,8,16,32]: + if self.get_nodeattr("signed") == 0: + dtype = DataType["UINT%d" % bw] + else: + dtype = DataType["INT%d" % bw] + else: + raise Exception("Unsupported output datatype for QuantAvgPool2d") model.set_tensor_datatype(node.output[0], dtype) def execute_node(self, context, graph): -- GitLab