diff --git a/src/finn/transformation/convert_qonnx_to_finn.py b/src/finn/transformation/convert_qonnx_to_finn.py index eea00b7fa45c1efd590c065a38589c6155e42845..814c7f05950ea9676a72f8c585d2b73d31284768 100644 --- a/src/finn/transformation/convert_qonnx_to_finn.py +++ b/src/finn/transformation/convert_qonnx_to_finn.py @@ -37,6 +37,7 @@ from finn.custom_op.registry import getCustomOp from finn.transformation.base import Transformation from finn.transformation.infer_datatypes import InferDataTypes from finn.transformation.infer_shapes import InferShapes +from finn.util.basic import get_by_name class ConvertQONNXtoFINN(Transformation): @@ -54,8 +55,6 @@ class ConvertQONNXtoFINN(Transformation): model = model.transform(FoldQuantWeights()) # Convert activations model = model.transform(ConvertQuantActToMultiThreshold()) - # Some datatypes have changed - model = model.transform(InferDataTypes()) return (model, False) @@ -344,6 +343,13 @@ class QuantActBaseHandler(ABC): graph.node.insert(running_node_index, outp_trans_node) running_node_index += 1 + # Unset the FINN datatype + qnt_annotations = model._model_proto.graph.quantization_annotation + ret = get_by_name(qnt_annotations, n.output[0], "tensor_name") + ret.Clear() + # ToDo: This should be supported by finn-base, by calling the following: + # model.set_tensor_datatype(n.output[0], None) + # Insert Add node if adder_bias.shape == (1,): adder_bias = adder_bias[0]