From 3acdd188095861d58b9f843b2f84fbd2c81652a0 Mon Sep 17 00:00:00 2001
From: auphelia <jakobapk@web.de>
Date: Fri, 12 Jun 2020 09:43:30 +0100
Subject: [PATCH] [CustomOp] Change infer datatype function of QuantAvgPool2d

---
 src/finn/custom_op/quantavgpool2d.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/src/finn/custom_op/quantavgpool2d.py b/src/finn/custom_op/quantavgpool2d.py
index 075d807c0..0a848b859 100644
--- a/src/finn/custom_op/quantavgpool2d.py
+++ b/src/finn/custom_op/quantavgpool2d.py
@@ -3,6 +3,7 @@ from onnx import TensorProto, helper
 import onnxruntime as rt
 
 from finn.custom_op import CustomOp
+from finn.core.datatype import DataType
 
 
 class QuantAvgPool2d(CustomOp):
@@ -32,8 +33,14 @@ class QuantAvgPool2d(CustomOp):
 
     def infer_node_datatype(self, model):
         node = self.onnx_node
-        # data type stays the same
-        dtype = model.get_tensor_datatype(node.input[0])
+        bw = self.get_nodeattr("obits")
+        if bw in [2,4,8,16,32]:
+            if self.get_nodeattr("signed") == 0:
+                dtype = DataType["UINT%d" % bw]
+            else:
+                dtype = DataType["INT%d" % bw]
+        else:
+            raise Exception("Unsupported output datatype for QuantAvgPool2d")
         model.set_tensor_datatype(node.output[0], dtype)
 
     def execute_node(self, context, graph):
-- 
GitLab