diff --git a/src/finn/transformation/qonnx/infer_quant_avg_pool_2d.py b/src/finn/transformation/qonnx/infer_quant_avg_pool_2d.py index bd3ff156455c68a76792c768ba6c2a64a14941d6..d2aaee59a4767f36d2e948fa97b1deecb7f365ac 100644 --- a/src/finn/transformation/qonnx/infer_quant_avg_pool_2d.py +++ b/src/finn/transformation/qonnx/infer_quant_avg_pool_2d.py @@ -46,7 +46,7 @@ def _get_signed_from_upstream(model, trunc_node): # Check if the input of this node already has a FINN datatype signed = None inp_dt = model.get_tensor_datatype(node.input[0]) - if inp_dt is not None and inp_dt != "FLOAT32": + if inp_dt is not None and inp_dt != DataType["FLOAT32"]: signed = inp_dt.signed() # Go further up the graph, since the datatype inference works top down # these nodes should either be sign preserving ops or they already have a @@ -67,27 +67,23 @@ def _get_signed_from_upstream(model, trunc_node): ) next_node = next_node[0] out_dt = model.get_tensor_datatype(next_node.output[0]) - if out_dt is not None and out_dt != "FLOAT32": + if out_dt is not None and out_dt != DataType["FLOAT32"]: signed = out_dt.signed() break # Special cases where the node has an internal or intrinsic datatype. if next_node.op_type == "MultiThreshold": - mt_inst = getCustomOp( - next_node, onnx_opset_version=9, brevitas_exception=True - ) + mt_inst = getCustomOp(next_node, onnx_opset_version=9) out_dt = DataType[mt_inst.get_nodeattr("out_dtype")] - if out_dt is not None and out_dt != "FLOAT32": + if out_dt is not None and out_dt != DataType["FLOAT32"]: signed = out_dt.signed() break if next_node.op_type == "BipolarQuant": signed = True break if next_node.op_type == "Quant": - q_inst = getCustomOp( - next_node, onnx_opset_version=9, brevitas_exception=True - ) + q_inst = getCustomOp(next_node, onnx_opset_version=9) out_dt = q_inst.get_integer_datatype(model) - if out_dt is not None and out_dt != "FLOAT32": + if out_dt is not None and out_dt != DataType["FLOAT32"]: signed = out_dt.signed() break