From 58d28b11e553d510f87b10e8c195732bdd96b1b4 Mon Sep 17 00:00:00 2001
From: Hendrik Borras <hendrikborras@web.de>
Date: Fri, 1 Oct 2021 15:39:13 +0100
Subject: [PATCH] Moved removal of FINN datatypes into the main QONNX-to-FINN
 transformation.

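Previously, each QuantActBaseHandler cleared the finn_datatype annotation on
its MultiThreshold output tensor itself. This is now done once in
ConvertQONNXtoFINN, after InferDataTypes has run, so that stale FINN
datatypes on MultiThreshold outputs no longer trigger warnings.

A minimal usage sketch of the transformation (file paths are illustrative):

    from finn.core.modelwrapper import ModelWrapper
    from finn.transformation.qonnx.convert_qonnx_to_finn import (
        ConvertQONNXtoFINN,
    )

    # Load a QONNX model and convert it to FINN's internal dialect;
    # the input filename here is a hypothetical example.
    model = ModelWrapper("quant_model.onnx")
    model = model.transform(ConvertQONNXtoFINN())
    model.save("finn_model.onnx")
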
---
 .../qonnx/convert_qonnx_to_finn.py             | 18 ++++++++++++++++++
 .../qonnx/qonnx_activation_handlers.py         | 13 -------------
 2 files changed, 18 insertions(+), 13 deletions(-)

diff --git a/src/finn/transformation/qonnx/convert_qonnx_to_finn.py b/src/finn/transformation/qonnx/convert_qonnx_to_finn.py
index 2a9442c35..5b218f2c3 100644
--- a/src/finn/transformation/qonnx/convert_qonnx_to_finn.py
+++ b/src/finn/transformation/qonnx/convert_qonnx_to_finn.py
@@ -34,6 +34,7 @@ from finn.transformation.base import Transformation
 from finn.transformation.infer_datatypes import InferDataTypes
 from finn.transformation.infer_shapes import InferShapes
 from finn.transformation.qonnx.qonnx_activation_handlers import QuantActBaseHandler
+from finn.util.basic import get_by_name
 
 
 class ConvertQONNXtoFINN(Transformation):
@@ -51,6 +52,23 @@ class ConvertQONNXtoFINN(Transformation):
         model = model.transform(FoldQuantWeights())
         # Convert activations
         model = model.transform(ConvertQuantActToMultiThreshold())
+        # Infer types again
+        model = model.transform(InferDataTypes())
+
+        # Unset FINN datatypes from MultiThreshold node output tensors to avoid warnings
+        mt_nodes = model.get_nodes_by_op_type("MultiThreshold")
+        qnt_annotations = model._model_proto.graph.quantization_annotation
+        for n in mt_nodes:
+            ret = get_by_name(qnt_annotations, n.output[0], "tensor_name")
+            if ret is not None:
+                ret_dt = get_by_name(
+                    ret.quant_parameter_tensor_names, "finn_datatype", "key"
+                )
+                if ret_dt is not None:
+                    ret_dt.Clear()
+        # TODO: This might be supported by finn-base in the future,
+        #  by calling the following:
+        # model.set_tensor_datatype(n.output[0], None)
 
         return (model, False)
 
diff --git a/src/finn/transformation/qonnx/qonnx_activation_handlers.py b/src/finn/transformation/qonnx/qonnx_activation_handlers.py
index 9e7ebed93..d5c00c73d 100644
--- a/src/finn/transformation/qonnx/qonnx_activation_handlers.py
+++ b/src/finn/transformation/qonnx/qonnx_activation_handlers.py
@@ -32,7 +32,6 @@ from onnx import TensorProto, helper
 
 from finn.core.modelwrapper import ModelWrapper
 from finn.custom_op.registry import getCustomOp
-from finn.util.basic import get_by_name
 
 
 class QuantActBaseHandler(ABC):
@@ -156,18 +155,6 @@ class QuantActBaseHandler(ABC):
         graph.node.insert(running_node_index, outp_trans_node)
         running_node_index += 1
 
-        # Unset the FINN datatype
-        qnt_annotations = model._model_proto.graph.quantization_annotation
-        ret = get_by_name(qnt_annotations, n.output[0], "tensor_name")
-        if ret is not None:
-            ret_dt = get_by_name(
-                ret.quant_parameter_tensor_names, "finn_datatype", "key"
-            )
-            if ret_dt is not None:
-                ret_dt.Clear()
-        # ToDo: This should be supported by finn-base, by calling the following:
-        # model.set_tensor_datatype(n.output[0], None)
-
         # Insert Add node
         if adder_bias.shape == (1,):
             adder_bias = adder_bias[0]
-- 
GitLab