Skip to content
Snippets Groups Projects
Commit 90cc5159 authored by auphelia's avatar auphelia
Browse files

[QONNX conversion] Update infer quant avg pool 2d

parent c6ee5a6f
No related branches found
No related tags found
No related merge requests found
......@@ -46,7 +46,7 @@ def _get_signed_from_upstream(model, trunc_node):
# Check if the input of this node already has a FINN datatype
signed = None
inp_dt = model.get_tensor_datatype(node.input[0])
if inp_dt is not None and inp_dt is not DataType["FLOAT32"]:
if inp_dt is not None and inp_dt != "FLOAT32":
signed = inp_dt.signed()
# Go further up the graph, since the datatype inference works top down
# these nodes should either be sign preserving ops or they already have a
......@@ -67,23 +67,27 @@ def _get_signed_from_upstream(model, trunc_node):
)
next_node = next_node[0]
out_dt = model.get_tensor_datatype(next_node.output[0])
if out_dt is not None and out_dt is not DataType["FLOAT32"]:
if out_dt is not None and out_dt != "FLOAT32":
signed = out_dt.signed()
break
# Special cases where the node has an internal or intrinsic datatype.
if next_node.op_type == "MultiThreshold":
mt_inst = getCustomOp(next_node)
mt_inst = getCustomOp(
next_node, onnx_opset_version=9, brevitas_exception=True
)
out_dt = DataType[mt_inst.get_nodeattr("out_dtype")]
if out_dt is not None and out_dt is not DataType["FLOAT32"]:
if out_dt is not None and out_dt != "FLOAT32":
signed = out_dt.signed()
break
if next_node.op_type == "BipolarQuant":
signed = True
break
if next_node.op_type == "Quant":
q_inst = getCustomOp(next_node)
q_inst = getCustomOp(
next_node, onnx_opset_version=9, brevitas_exception=True
)
out_dt = q_inst.get_integer_datatype(model)
if out_dt is not None and out_dt is not DataType["FLOAT32"]:
if out_dt is not None and out_dt != "FLOAT32":
signed = out_dt.signed()
break
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment