Skip to content
Snippets Groups Projects
Commit 3acdd188 authored by auphelia's avatar auphelia
Browse files

[CustomOp] Change infer datatype function of QuantAvgPool2d

parent 9bbe9761
No related branches found
No related tags found
No related merge requests found
......@@ -3,6 +3,7 @@ from onnx import TensorProto, helper
import onnxruntime as rt
from finn.custom_op import CustomOp
from finn.core.datatype import DataType
class QuantAvgPool2d(CustomOp):
......@@ -32,8 +33,14 @@ class QuantAvgPool2d(CustomOp):
def infer_node_datatype(self, model):
node = self.onnx_node
# data type stays the same
dtype = model.get_tensor_datatype(node.input[0])
bw = self.get_nodeattr("obits")
if bw in [2,4,8,16,32]:
if self.get_nodeattr("signed") == 0:
dtype = DataType["UINT%d" % bw]
else:
dtype = DataType["INT%d" % bw]
else:
raise Exception("Unsupported output datatype for QuantAvgPool2d")
model.set_tensor_datatype(node.output[0], dtype)
def execute_node(self, context, graph):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment