From 744a43dc5db11df8d44eaf3ae7e08c21ac67d7de Mon Sep 17 00:00:00 2001
From: auphelia <jakobapk@web.de>
Date: Wed, 5 Apr 2023 14:17:11 +0100
Subject: [PATCH] [Transform] Update check for dt in infer quant avg pool

---
 .../qonnx/infer_quant_avg_pool_2d.py             | 16 ++++++----------
 1 file changed, 6 insertions(+), 10 deletions(-)

diff --git a/src/finn/transformation/qonnx/infer_quant_avg_pool_2d.py b/src/finn/transformation/qonnx/infer_quant_avg_pool_2d.py
index bd3ff1564..d2aaee59a 100644
--- a/src/finn/transformation/qonnx/infer_quant_avg_pool_2d.py
+++ b/src/finn/transformation/qonnx/infer_quant_avg_pool_2d.py
@@ -46,7 +46,7 @@ def _get_signed_from_upstream(model, trunc_node):
     # Check if the input of this node already has a FINN datatype
     signed = None
     inp_dt = model.get_tensor_datatype(node.input[0])
-    if inp_dt is not None and inp_dt != "FLOAT32":
+    if inp_dt is not None and inp_dt != DataType["FLOAT32"]:
         signed = inp_dt.signed()
     # Go further up the graph, since the datatype inference works top down
     # these nodes should either be sign preserving ops or they already have a
@@ -67,27 +67,23 @@ def _get_signed_from_upstream(model, trunc_node):
                 )
             next_node = next_node[0]
             out_dt = model.get_tensor_datatype(next_node.output[0])
-            if out_dt is not None and out_dt != "FLOAT32":
+            if out_dt is not None and out_dt != DataType["FLOAT32"]:
                 signed = out_dt.signed()
                 break
             # Special cases where the node has an internal or intrinsic datatype.
             if next_node.op_type == "MultiThreshold":
-                mt_inst = getCustomOp(
-                    next_node, onnx_opset_version=9, brevitas_exception=True
-                )
+                mt_inst = getCustomOp(next_node, onnx_opset_version=9)
                 out_dt = DataType[mt_inst.get_nodeattr("out_dtype")]
-                if out_dt is not None and out_dt != "FLOAT32":
+                if out_dt is not None and out_dt != DataType["FLOAT32"]:
                     signed = out_dt.signed()
                     break
             if next_node.op_type == "BipolarQuant":
                 signed = True
                 break
             if next_node.op_type == "Quant":
-                q_inst = getCustomOp(
-                    next_node, onnx_opset_version=9, brevitas_exception=True
-                )
+                q_inst = getCustomOp(next_node, onnx_opset_version=9)
                 out_dt = q_inst.get_integer_datatype(model)
-                if out_dt is not None and out_dt != "FLOAT32":
+                if out_dt is not None and out_dt != DataType["FLOAT32"]:
                     signed = out_dt.signed()
                     break
 
-- 
GitLab