From 744a43dc5db11df8d44eaf3ae7e08c21ac67d7de Mon Sep 17 00:00:00 2001 From: auphelia <jakobapk@web.de> Date: Wed, 5 Apr 2023 14:17:11 +0100 Subject: [PATCH] [Transform] Update check for dt in infer quant avg pool --- .../qonnx/infer_quant_avg_pool_2d.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/src/finn/transformation/qonnx/infer_quant_avg_pool_2d.py b/src/finn/transformation/qonnx/infer_quant_avg_pool_2d.py index bd3ff1564..d2aaee59a 100644 --- a/src/finn/transformation/qonnx/infer_quant_avg_pool_2d.py +++ b/src/finn/transformation/qonnx/infer_quant_avg_pool_2d.py @@ -46,7 +46,7 @@ def _get_signed_from_upstream(model, trunc_node): # Check if the input of this node already has a FINN datatype signed = None inp_dt = model.get_tensor_datatype(node.input[0]) - if inp_dt is not None and inp_dt != "FLOAT32": + if inp_dt is not None and inp_dt != DataType["FLOAT32"]: signed = inp_dt.signed() # Go further up the graph, since the datatype inference works top down # these nodes should either be sign preserving ops or they already have a @@ -67,27 +67,23 @@ def _get_signed_from_upstream(model, trunc_node): ) next_node = next_node[0] out_dt = model.get_tensor_datatype(next_node.output[0]) - if out_dt is not None and out_dt != "FLOAT32": + if out_dt is not None and out_dt != DataType["FLOAT32"]: signed = out_dt.signed() break # Special cases where the node has an internal or intrinsic datatype. if next_node.op_type == "MultiThreshold": - mt_inst = getCustomOp( - next_node, onnx_opset_version=9, brevitas_exception=True - ) + mt_inst = getCustomOp(next_node, onnx_opset_version=9) out_dt = DataType[mt_inst.get_nodeattr("out_dtype")] - if out_dt is not None and out_dt != "FLOAT32": + if out_dt is not None and out_dt != DataType["FLOAT32"]: signed = out_dt.signed() break if next_node.op_type == "BipolarQuant": signed = True break if next_node.op_type == "Quant": - q_inst = getCustomOp( - next_node, onnx_opset_version=9, brevitas_exception=True - ) + q_inst = getCustomOp(next_node, onnx_opset_version=9) out_dt = q_inst.get_integer_datatype(model) - if out_dt is not None and out_dt != "FLOAT32": + if out_dt is not None and out_dt != DataType["FLOAT32"]: signed = out_dt.signed() break -- GitLab