From 5ba522dbe8ee1265132f65493d777ff9d4bdc9f4 Mon Sep 17 00:00:00 2001
From: Tobi-Alonso <tobi.alonso@gmail.com>
Date: Thu, 2 Jul 2020 09:40:25 +0100
Subject: [PATCH] Merge upsteam changes in absorb.py

---
 src/finn/transformation/streamline/absorb.py | 165 +++++++++++++++++++
 1 file changed, 165 insertions(+)

diff --git a/src/finn/transformation/streamline/absorb.py b/src/finn/transformation/streamline/absorb.py
index 82cc6b42d..65e67721c 100644
--- a/src/finn/transformation/streamline/absorb.py
+++ b/src/finn/transformation/streamline/absorb.py
@@ -28,14 +28,81 @@
 
 import numpy as np
 from onnx import helper as oh
+import warnings
 
 from finn.core.datatype import DataType
+import finn.core.data_layout as DataLayout
 from finn.transformation import Transformation
 from finn.util.basic import get_by_name
 from finn.custom_op.registry import getCustomOp
+from finn.transformation.infer_shapes import InferShapes
 from finn.transformation.infer_datatypes import InferDataTypes
 
 
+class AbsorbSignBiasIntoMultiThreshold(Transformation):
+    """Absorb scalar bias originating from signed int export back into
+    MultiThreshold and re-evaluate the output datatype."""
+
+    def apply(self, model):
+        graph = model.graph
+        node_ind = 0
+        graph_modified = False
+        for n in graph.node:
+            # search for (MultiThreshold, Add) pair
+            node_ind += 1
+            if (
+                n.op_type == "MultiThreshold"
+                and not model.is_fork_node(n)
+                and not model.is_join_node(n)
+            ):
+                consumer = model.find_consumer(n.output[0])
+                if consumer is not None and consumer.op_type == "Add":
+                    mt_node = n
+                    add_node = consumer
+                    threshold_name = mt_node.input[1]
+                    add_weight_name = add_node.input[1]
+                    T = model.get_initializer(threshold_name)
+                    A = model.get_initializer(add_weight_name)
+                    if (A is None) or (T is None):
+                        warnings.warn("Threshold or add bias not constant, skipping")
+                        continue
+                    end_name = add_node.output[0]
+                    # we can only absorb scalar adds
+                    is_scalar = A.ndim == 0 or all(x == 1 for x in A.shape)
+                    if not is_scalar:
+                        continue
+                    bias = A.flatten()[0]
+                    # set MultiThreshold bias property
+                    mt_inst = getCustomOp(mt_node)
+                    bias += mt_inst.get_nodeattr("out_bias")
+                    mt_inst.set_nodeattr("out_bias", bias)
+                    graph_modified = True
+                    # compute new DataType for MultiThreshold output
+                    steps = T.shape[-1]
+                    new_min = bias
+                    new_max = steps + bias
+                    odt = DataType.get_smallest_possible(steps).name.replace(
+                        "UINT", "INT"
+                    )
+                    odt = DataType[odt]
+                    assert odt.allowed(new_max) and odt.allowed(
+                        new_min
+                    ), """Could
+                    not compute new MultiThreshold DataType (min = %d max = %d)""" % (
+                        new_min,
+                        new_max,
+                    )
+                    mt_inst.set_nodeattr("out_dtype", odt.name)
+                    # remove Add node, rewire MultiThreshold
+                    graph.node.remove(add_node)
+                    mt_node.output[0] = end_name
+                    # set datatype
+                    model.set_tensor_datatype(end_name, odt)
+        if graph_modified:
+            model = model.transform(InferDataTypes())
+        return (model, graph_modified)
+
+
 class AbsorbAddIntoMultiThreshold(Transformation):
     """Absorb preceding Add ops into MultiThreshold by updating the threshold
     values. Only scalar/1D add vectors can be absorbed."""
@@ -292,6 +359,104 @@ class AbsorbTransposeIntoMultiThreshold(Transformation):
         return (model, graph_modified)
 
 
+class AbsorbTransposeIntoFlatten(Transformation):
+    """Absorb transpose node into succeeding flatten node, if H=W=1 and the first
+    dimension stays the same. Can also be applied if flatten is implemented implicitly
+    by a reshape node with shape [1, -1] and the first input dimension is 1"""
+
+    def apply(self, model):
+        graph = model.graph
+        graph_modified = False
+        node_ind = 0
+        for n in graph.node:
+            node_ind += 1
+            if (
+                n.op_type == "Reshape"
+                and (model.get_initializer(n.input[1]) == [1, -1]).all()
+            ) or n.op_type == "Flatten":
+                prod = model.find_producer(n.input[0])
+                if (
+                    prod is not None
+                    and prod.op_type == "Transpose"
+                    # we ensure that the first dimension is not changed from the
+                    # transpose operation
+                    and get_by_name(prod.attribute, "perm").ints[0] == 0
+                ):
+                    data_layout = model.get_tensor_layout(prod.input[0])
+                    # check for the data layout to interpret input shape correctly
+                    if data_layout is None:
+                        warnings.warn(
+                            """Data layout for input tensor of Transpose node is not set.
+                                To use AbsorbTransposeIntoFlatten transformation
+                                please set tensor data layout."""
+                        )
+                        continue
+                    elif data_layout == DataLayout.NCHW:
+                        (b, c, h, w) = model.get_tensor_shape(prod.input[0])
+                        # if h=w=1 the transposition can be absorbed, otherwise
+                        # the absorption would lead to an error in the behavior
+                        if h != 1 or w != 1:
+                            continue
+                        # the flatten node from onnx keeps by default the first
+                        # dim and flattens the rest, that is why this transformation
+                        # can only work with b != 1 if the model contains already a
+                        # flatten node and not a reshape node with shape = [1, -1].
+                        # If the first  dim of the input tensor is not 1, flatten and
+                        # reshape (with shape = [1, -1]) would lead to different results
+                        if n.op_type == "Reshape" and b != 1:
+                            continue
+                    elif data_layout == DataLayout.NHWC:
+                        (b, h, w, c) = model.get_tensor_shape(prod.input[0])
+                        if h != 1 or w != 1:
+                            continue
+                        if n.op_type == "Reshape" and b != 1:
+                            continue
+                    # create single flatten node and remove obsolete nodes
+                    node = oh.make_node("Flatten", [prod.input[0]], [n.output[0]])
+                    graph.node.remove(n)
+                    graph.node.remove(prod)
+                    graph.node.insert(node_ind, node)
+                    graph_modified = True
+        if graph_modified:
+            model = model.transform(InferDataTypes())
+        return (model, graph_modified)
+
+
+class AbsorbScalarMulIntoTopK(Transformation):
+    """Absorb a mul node into a suceeding topk node if the mul is scalar."""
+
+    def apply(self, model):
+        graph = model.graph
+        node_ind = 0
+        graph_modified = False
+        for n in graph.node:
+            node_ind += 1
+            if n.op_type == "TopK":
+                prod = model.find_producer(n.input[0])
+                if prod is not None and prod.op_type == "Mul":
+                    prod_input = prod.input[0]
+                    param_name = prod.input[1]
+                    A = model.get_initializer(param_name)
+                    if A is None:
+                        warnings.warn("Param is not constant, skipping")
+                        continue
+                    if all(x == 1 for x in A.shape) and A > 0:
+                        # if the mul is scalar and positive, we can just delete the
+                        # mul node and rewire the top k node. Because the top k node
+                        # works with probabilities and their relation to each other
+                        # the relation doesn't change if every value is multiplied
+                        # with a scalar
+                        graph.node.remove(prod)
+                        n.input[0] = prod_input
+                        # to avoid error the dataype is set to float32
+                        model.set_tensor_datatype(n.input[0], DataType.FLOAT32)
+                        graph_modified = True
+        if graph_modified:
+            model = model.transform(InferShapes())
+            model = model.transform(InferDataTypes())
+        return (model, graph_modified)
+
+
 class AbsorbConsecutiveTransposes(Transformation):
     """Remove (Transpose -> Transpose) patterns when the input and output
     of the pattern have the same layout."""
-- 
GitLab