Skip to content
Snippets Groups Projects
Commit 58d28b11 authored by Hendrik Borras's avatar Hendrik Borras
Browse files

Moved removal of FINN datatypes into main QONNX to FINN transformation.

parent 18547a2c
No related branches found
No related tags found
No related merge requests found
......@@ -34,6 +34,7 @@ from finn.transformation.base import Transformation
from finn.transformation.infer_datatypes import InferDataTypes
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.qonnx.qonnx_activation_handlers import QuantActBaseHandler
from finn.util.basic import get_by_name
class ConvertQONNXtoFINN(Transformation):
......@@ -51,6 +52,23 @@ class ConvertQONNXtoFINN(Transformation):
model = model.transform(FoldQuantWeights())
# Convert activations
model = model.transform(ConvertQuantActToMultiThreshold())
# Infer types again
model = model.transform(InferDataTypes())
# Unset FINN datatypes from MultiThreshold node output tensors to avoid warnings
mt_nodes = model.get_nodes_by_op_type("MultiThreshold")
qnt_annotations = model._model_proto.graph.quantization_annotation
for n in mt_nodes:
ret = get_by_name(qnt_annotations, n.output[0], "tensor_name")
if ret is not None:
ret_dt = get_by_name(
ret.quant_parameter_tensor_names, "finn_datatype", "key"
)
if ret_dt is not None:
ret_dt.Clear()
# ToDo: This might be supported by finn-base in the future,
# by calling the following:
# model.set_tensor_datatype(n.output[0], None)
return (model, False)
......
......@@ -32,7 +32,6 @@ from onnx import TensorProto, helper
from finn.core.modelwrapper import ModelWrapper
from finn.custom_op.registry import getCustomOp
from finn.util.basic import get_by_name
class QuantActBaseHandler(ABC):
......@@ -156,18 +155,6 @@ class QuantActBaseHandler(ABC):
graph.node.insert(running_node_index, outp_trans_node)
running_node_index += 1
# Unset the FINN datatype
qnt_annotations = model._model_proto.graph.quantization_annotation
ret = get_by_name(qnt_annotations, n.output[0], "tensor_name")
if ret is not None:
ret_dt = get_by_name(
ret.quant_parameter_tensor_names, "finn_datatype", "key"
)
if ret_dt is not None:
ret_dt.Clear()
# ToDo: This should be supported by finn-base, by calling the following:
# model.set_tensor_datatype(n.output[0], None)
# Insert Add node
if adder_bias.shape == (1,):
adder_bias = adder_bias[0]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment