diff --git a/src/finn/transformation/qonnx/infer_quant_avg_pool_2d.py b/src/finn/transformation/qonnx/infer_quant_avg_pool_2d.py
index bd3ff156455c68a76792c768ba6c2a64a14941d6..d2aaee59a4767f36d2e948fa97b1deecb7f365ac 100644
--- a/src/finn/transformation/qonnx/infer_quant_avg_pool_2d.py
+++ b/src/finn/transformation/qonnx/infer_quant_avg_pool_2d.py
@@ -46,7 +46,7 @@ def _get_signed_from_upstream(model, trunc_node):
     # Check if the input of this node already has a FINN datatype
     signed = None
     inp_dt = model.get_tensor_datatype(node.input[0])
-    if inp_dt is not None and inp_dt != "FLOAT32":
+    if inp_dt is not None and inp_dt != DataType["FLOAT32"]:
         signed = inp_dt.signed()
     # Go further up the graph, since the datatype inference works top down
     # these nodes should either be sign preserving ops or they already have a
@@ -67,27 +67,23 @@ def _get_signed_from_upstream(model, trunc_node):
                 )
             next_node = next_node[0]
             out_dt = model.get_tensor_datatype(next_node.output[0])
-            if out_dt is not None and out_dt != "FLOAT32":
+            if out_dt is not None and out_dt != DataType["FLOAT32"]:
                 signed = out_dt.signed()
                 break
             # Special cases where the node has an internal or intrinsic datatype.
             if next_node.op_type == "MultiThreshold":
-                mt_inst = getCustomOp(
-                    next_node, onnx_opset_version=9, brevitas_exception=True
-                )
+                mt_inst = getCustomOp(next_node, onnx_opset_version=9)
                 out_dt = DataType[mt_inst.get_nodeattr("out_dtype")]
-                if out_dt is not None and out_dt != "FLOAT32":
+                if out_dt is not None and out_dt != DataType["FLOAT32"]:
                     signed = out_dt.signed()
                     break
             if next_node.op_type == "BipolarQuant":
                 signed = True
                 break
             if next_node.op_type == "Quant":
-                q_inst = getCustomOp(
-                    next_node, onnx_opset_version=9, brevitas_exception=True
-                )
+                q_inst = getCustomOp(next_node, onnx_opset_version=9)
                 out_dt = q_inst.get_integer_datatype(model)
-                if out_dt is not None and out_dt != "FLOAT32":
+                if out_dt is not None and out_dt != DataType["FLOAT32"]:
                     signed = out_dt.signed()
                     break