From b5bcd32035c01098b5c3107c3223eecedd56b5a6 Mon Sep 17 00:00:00 2001
From: Hendrik Borras <hendrikborras@web.de>
Date: Thu, 7 Oct 2021 16:56:44 +0100
Subject: [PATCH] Catch unsupported constant folding for SCALED datatypes.
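
The FoldQuantWeights transformation moves the scale factor of a SCALED
Quant weight node past the consuming operator. This only works for a
known set of successor ops: Mul, Div, Conv and MatMul need just an
extra Mul behind the node, while Add and Sub additionally need a Div
in front, e.g. for a weight w = s * w_int:

    x * (s * w_int) == s * (x * w_int)          Mul-like successor
    x +  s * w_int  == s * (x / s + w_int)      Add-like successor

Any other successor type previously went unhandled at this point.
Invert the check so that an unsupported successor now raises a
ValueError up front, and de-indent the folding logic accordingly.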

---
 .../qonnx/convert_qonnx_to_finn.py            | 138 +++++++++---------
 1 file changed, 70 insertions(+), 68 deletions(-)
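
Not part of the patch: a minimal numpy sketch of the two rewrite rules
the folding relies on (all names here are illustrative only, not FINN
API calls):

    import numpy as np

    scale = np.float32(0.5)                                # Quant scale s
    w_int = np.array([2.0, -3.0, 1.0], dtype=np.float32)   # integer weights
    x = np.array([1.5, 0.25, -2.0], dtype=np.float32)      # activation

    # Mul-like successor: the scale factor commutes out, so a single
    # Mul inserted behind the node restores the original result.
    assert np.allclose(x * (scale * w_int), (x * w_int) * scale)

    # Add-like successor: an extra Div in front is needed before the
    # scale can be re-applied behind the node.
    assert np.allclose(x + scale * w_int, ((x / scale) + w_int) * scale)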

diff --git a/src/finn/transformation/qonnx/convert_qonnx_to_finn.py b/src/finn/transformation/qonnx/convert_qonnx_to_finn.py
index ec222ea66..0c606e36c 100644
--- a/src/finn/transformation/qonnx/convert_qonnx_to_finn.py
+++ b/src/finn/transformation/qonnx/convert_qonnx_to_finn.py
@@ -119,92 +119,94 @@ class FoldQuantWeights(Transformation):
                             target_node = target_node[0]
 
                         # Check next operator type
-                        # ToDo: CHECK what this is merged into:
-                        #  Conv, MatMul and Mul nodes need
-                        #  only the extra multiplication behind
-                        #  BUT Add needs an extra Div in front!
                         mul_like_nodes = ["Mul", "Div", "Conv", "MatMul"]
                         add_like_nodes = ["Add", "Sub"]
+                        all_supported_ops = mul_like_nodes.copy()
+                        all_supported_ops.extend(add_like_nodes)
 
-                        if (
-                            target_node.op_type in mul_like_nodes
-                            or target_node.op_type in add_like_nodes
-                        ):
-                            # Move the scale factor behind the next operator
-                            scale = model.get_initializer(n.input[1])
-                            model.set_initializer(node_out, q_node_output / scale)
-                            new_dtype = DataType[dtype.name.replace("SCALED", "")]
-                            model.set_tensor_datatype(node_out, new_dtype)
+                        if target_node.op_type not in all_supported_ops:
+                            raise ValueError(
+                                f"Can't constant fold Quant weight node "
+                                f"into node type {target_node.op_type} "
+                                f"at node: {target_node}."
+                            )
+
+                        # For both Mul-like and Add-like nodes:
+                        # Move the scale factor behind the next operator
+                        scale = model.get_initializer(n.input[1])
+                        model.set_initializer(node_out, q_node_output / scale)
+                        new_dtype = DataType[dtype.name.replace("SCALED", "")]
+                        model.set_tensor_datatype(node_out, new_dtype)
+
+                        if target_node.op_type == "Conv" and len(scale.shape) > 0:
+                            bias_shape = [1] * len(scale.shape)
+                            bias_shape[1] = -1
+                            scale = scale.reshape(bias_shape)
+
+                        if scale.shape == (1,):
+                            scale = scale[0]
+                            mul_shape = tuple()
+                        else:
+                            mul_shape = scale.shape
+                        mul_tensor = helper.make_tensor_value_info(
+                            model.make_new_valueinfo_name(),
+                            TensorProto.FLOAT,
+                            mul_shape,
+                        )
+                        graph.value_info.append(mul_tensor)
+                        model.set_initializer(mul_tensor.name, scale)
+
+                        successor = model.find_consumers(node_out)
+                        if successor is None:
+                            raise RuntimeError(
+                                "Can only constant fold scaled Quant weights "
+                                "if a successor exists."
+                            )
+                        successor = successor[0]
+                        succ_output_name = successor.output[0]
+
+                        output_shape = model.get_tensor_shape(successor.output[0])
+                        act_mul_tensor = helper.make_tensor_value_info(
+                            model.make_new_valueinfo_name(),
+                            TensorProto.FLOAT,
+                            output_shape,
+                        )
+                        graph.value_info.append(act_mul_tensor)
+                        successor.output[0] = act_mul_tensor.name
 
-                            if target_node.op_type == "Conv" and len(scale.shape) > 0:
-                                bias_shape = [1] * len(scale.shape)
-                                bias_shape[1] = -1
-                                scale = scale.reshape(bias_shape)
+                        mul_node = helper.make_node(
+                            "Mul",
+                            [act_mul_tensor.name, mul_tensor.name],
+                            [succ_output_name],
+                        )
+                        graph.node.insert(node_ind, mul_node)
 
-                            if scale.shape == (1,):
-                                scale = scale[0]
-                                mul_shape = tuple()
-                            else:
-                                mul_shape = scale.shape
-                            mul_tensor = helper.make_tensor_value_info(
+                        if target_node.op_type in add_like_nodes:
+                            # For Add-like nodes the scale factor must
+                            # also be divided out in front of the node
+                            div_tensor = helper.make_tensor_value_info(
                                 model.make_new_valueinfo_name(),
                                 TensorProto.FLOAT,
                                 mul_shape,
                             )
-                            graph.value_info.append(mul_tensor)
-                            model.set_initializer(mul_tensor.name, scale)
+                            graph.value_info.append(div_tensor)
+                            model.set_initializer(div_tensor.name, scale)
 
-                            successor = model.find_consumers(node_out)
-                            if successor is None:
-                                raise RuntimeError(
-                                    "Can only constant fold scaled Quant weights "
-                                    "if a successor exists."
-                                )
-                            successor = successor[0]
-                            succ_output_name = successor.output[0]
-
-                            output_shape = model.get_tensor_shape(successor.output[0])
+                            succ_input_name = successor.input[0]
                             act_mul_tensor = helper.make_tensor_value_info(
                                 model.make_new_valueinfo_name(),
                                 TensorProto.FLOAT,
                                 output_shape,
                             )
                             graph.value_info.append(act_mul_tensor)
-                            successor.output[0] = act_mul_tensor.name
+                            successor.input[0] = act_mul_tensor.name
 
-                            mul_node = helper.make_node(
-                                "Mul",
-                                [act_mul_tensor.name, mul_tensor.name],
-                                [succ_output_name],
+                            div_node = helper.make_node(
+                                "Div",
+                                [succ_input_name, div_tensor.name],
+                                [act_mul_tensor.name],
                             )
-                            graph.node.insert(node_ind, mul_node)
-
-                            if target_node.op_type in add_like_nodes:
-                                # Move the scale factor behind also in-front of
-                                # the next operator
-                                div_tensor = helper.make_tensor_value_info(
-                                    model.make_new_valueinfo_name(),
-                                    TensorProto.FLOAT,
-                                    mul_shape,
-                                )
-                                graph.value_info.append(div_tensor)
-                                model.set_initializer(div_tensor.name, scale)
-
-                                succ_input_name = successor.input[0]
-                                act_mul_tensor = helper.make_tensor_value_info(
-                                    model.make_new_valueinfo_name(),
-                                    TensorProto.FLOAT,
-                                    output_shape,
-                                )
-                                graph.value_info.append(act_mul_tensor)
-                                successor.input[0] = act_mul_tensor.name
-
-                                div_node = helper.make_node(
-                                    "Div",
-                                    [succ_input_name, div_tensor.name],
-                                    [act_mul_tensor.name],
-                                )
-                                graph.node.insert(node_ind, div_node)
+                            graph.node.insert(node_ind, div_node)
 
                     else:
                         # use the execution result as an initializer
-- 
GitLab