diff --git a/src/finn/transformation/qonnx/convert_qonnx_to_finn.py b/src/finn/transformation/qonnx/convert_qonnx_to_finn.py index 2a9442c35ac4c6da483a671a58c91461a56d301f..5b218f2c38592afff3b790395154454e563028bb 100644 --- a/src/finn/transformation/qonnx/convert_qonnx_to_finn.py +++ b/src/finn/transformation/qonnx/convert_qonnx_to_finn.py @@ -34,6 +34,7 @@ from finn.transformation.base import Transformation from finn.transformation.infer_datatypes import InferDataTypes from finn.transformation.infer_shapes import InferShapes from finn.transformation.qonnx.qonnx_activation_handlers import QuantActBaseHandler +from finn.util.basic import get_by_name class ConvertQONNXtoFINN(Transformation): @@ -51,6 +52,23 @@ class ConvertQONNXtoFINN(Transformation): model = model.transform(FoldQuantWeights()) # Convert activations model = model.transform(ConvertQuantActToMultiThreshold()) + # Infer types again + model = model.transform(InferDataTypes()) + + # Unset FINN datatypes from MultiThreshold node output tensors to avoid warnings + mt_nodes = model.get_nodes_by_op_type("MultiThreshold") + qnt_annotations = model._model_proto.graph.quantization_annotation + for n in mt_nodes: + ret = get_by_name(qnt_annotations, n.output[0], "tensor_name") + if ret is not None: + ret_dt = get_by_name( + ret.quant_parameter_tensor_names, "finn_datatype", "key" + ) + if ret_dt is not None: + ret_dt.Clear() + # ToDo: This might be supported by finn-base in the future, + # by calling the following: + # model.set_tensor_datatype(n.output[0], None) return (model, False) diff --git a/src/finn/transformation/qonnx/qonnx_activation_handlers.py b/src/finn/transformation/qonnx/qonnx_activation_handlers.py index 9e7ebed9359427c1d7ba32e36671c64244ce1541..d5c00c73dab68a479a1f0f7cce7e7395d4fb6bd4 100644 --- a/src/finn/transformation/qonnx/qonnx_activation_handlers.py +++ b/src/finn/transformation/qonnx/qonnx_activation_handlers.py @@ -32,7 +32,6 @@ from onnx import TensorProto, helper from finn.core.modelwrapper import ModelWrapper from finn.custom_op.registry import getCustomOp -from finn.util.basic import get_by_name class QuantActBaseHandler(ABC): @@ -156,18 +155,6 @@ class QuantActBaseHandler(ABC): graph.node.insert(running_node_index, outp_trans_node) running_node_index += 1 - # Unset the FINN datatype - qnt_annotations = model._model_proto.graph.quantization_annotation - ret = get_by_name(qnt_annotations, n.output[0], "tensor_name") - if ret is not None: - ret_dt = get_by_name( - ret.quant_parameter_tensor_names, "finn_datatype", "key" - ) - if ret_dt is not None: - ret_dt.Clear() - # ToDo: This should be supported by finn-base, by calling the following: - # model.set_tensor_datatype(n.output[0], None) - # Insert Add node if adder_bias.shape == (1,): adder_bias = adder_bias[0]