Skip to content
Snippets Groups Projects
Commit 90cc5159 authored by auphelia's avatar auphelia
Browse files

[QONNX conversion] Update infer quant avg pool 2d

parent c6ee5a6f
No related branches found
No related tags found
No related merge requests found
...@@ -46,7 +46,7 @@ def _get_signed_from_upstream(model, trunc_node): ...@@ -46,7 +46,7 @@ def _get_signed_from_upstream(model, trunc_node):
# Check if the input of this node already has a FINN datatype # Check if the input of this node already has a FINN datatype
signed = None signed = None
inp_dt = model.get_tensor_datatype(node.input[0]) inp_dt = model.get_tensor_datatype(node.input[0])
if inp_dt is not None and inp_dt is not DataType["FLOAT32"]: if inp_dt is not None and inp_dt != "FLOAT32":
signed = inp_dt.signed() signed = inp_dt.signed()
# Go further up the graph, since the datatype inference works top down # Go further up the graph, since the datatype inference works top down
# these nodes should either be sign preserving ops or they already have a # these nodes should either be sign preserving ops or they already have a
...@@ -67,23 +67,27 @@ def _get_signed_from_upstream(model, trunc_node): ...@@ -67,23 +67,27 @@ def _get_signed_from_upstream(model, trunc_node):
) )
next_node = next_node[0] next_node = next_node[0]
out_dt = model.get_tensor_datatype(next_node.output[0]) out_dt = model.get_tensor_datatype(next_node.output[0])
if out_dt is not None and out_dt is not DataType["FLOAT32"]: if out_dt is not None and out_dt != "FLOAT32":
signed = out_dt.signed() signed = out_dt.signed()
break break
# Special cases where the node has an internal or intrinsic datatype. # Special cases where the node has an internal or intrinsic datatype.
if next_node.op_type == "MultiThreshold": if next_node.op_type == "MultiThreshold":
mt_inst = getCustomOp(next_node) mt_inst = getCustomOp(
next_node, onnx_opset_version=9, brevitas_exception=True
)
out_dt = DataType[mt_inst.get_nodeattr("out_dtype")] out_dt = DataType[mt_inst.get_nodeattr("out_dtype")]
if out_dt is not None and out_dt is not DataType["FLOAT32"]: if out_dt is not None and out_dt != "FLOAT32":
signed = out_dt.signed() signed = out_dt.signed()
break break
if next_node.op_type == "BipolarQuant": if next_node.op_type == "BipolarQuant":
signed = True signed = True
break break
if next_node.op_type == "Quant": if next_node.op_type == "Quant":
q_inst = getCustomOp(next_node) q_inst = getCustomOp(
next_node, onnx_opset_version=9, brevitas_exception=True
)
out_dt = q_inst.get_integer_datatype(model) out_dt = q_inst.get_integer_datatype(model)
if out_dt is not None and out_dt is not DataType["FLOAT32"]: if out_dt is not None and out_dt != "FLOAT32":
signed = out_dt.signed() signed = out_dt.signed()
break break
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment