Skip to content
Snippets Groups Projects
Commit 744a43dc authored by auphelia's avatar auphelia
Browse files

[Transform] Update check for dt in infer quant avg pool

parent 90cc5159
No related merge requests found
......@@ -46,7 +46,7 @@ def _get_signed_from_upstream(model, trunc_node):
# Check if the input of this node already has a FINN datatype
signed = None
inp_dt = model.get_tensor_datatype(node.input[0])
if inp_dt is not None and inp_dt != "FLOAT32":
if inp_dt is not None and inp_dt != DataType["FLOAT32"]:
signed = inp_dt.signed()
# Go further up the graph, since the datatype inference works top down
# these nodes should either be sign preserving ops or they already have a
......@@ -67,27 +67,23 @@ def _get_signed_from_upstream(model, trunc_node):
)
next_node = next_node[0]
out_dt = model.get_tensor_datatype(next_node.output[0])
if out_dt is not None and out_dt != "FLOAT32":
if out_dt is not None and out_dt != DataType["FLOAT32"]:
signed = out_dt.signed()
break
# Special cases where the node has an internal or intrinsic datatype.
if next_node.op_type == "MultiThreshold":
mt_inst = getCustomOp(
next_node, onnx_opset_version=9, brevitas_exception=True
)
mt_inst = getCustomOp(next_node, onnx_opset_version=9)
out_dt = DataType[mt_inst.get_nodeattr("out_dtype")]
if out_dt is not None and out_dt != "FLOAT32":
if out_dt is not None and out_dt != DataType["FLOAT32"]:
signed = out_dt.signed()
break
if next_node.op_type == "BipolarQuant":
signed = True
break
if next_node.op_type == "Quant":
q_inst = getCustomOp(
next_node, onnx_opset_version=9, brevitas_exception=True
)
q_inst = getCustomOp(next_node, onnx_opset_version=9)
out_dt = q_inst.get_integer_datatype(model)
if out_dt is not None and out_dt != "FLOAT32":
if out_dt is not None and out_dt != DataType["FLOAT32"]:
signed = out_dt.signed()
break
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment