From d54cb43cb66f4d5b913eefe791325cfc9d3d1c7e Mon Sep 17 00:00:00 2001
From: auphelia <jakobapk@web.de>
Date: Thu, 25 Jun 2020 15:29:10 +0100
Subject: [PATCH] [Streamline] Add AbsorbScalarMulIntoTopK transformation

---
 src/finn/transformation/streamline/absorb.py | 37 ++++++++++++++++++++
 1 file changed, 37 insertions(+)

diff --git a/src/finn/transformation/streamline/absorb.py b/src/finn/transformation/streamline/absorb.py
index dbcf97361..4266488c7 100644
--- a/src/finn/transformation/streamline/absorb.py
+++ b/src/finn/transformation/streamline/absorb.py
@@ -28,11 +28,13 @@
 
 import numpy as np
 from onnx import helper as oh
+import warnings
 
 from finn.core.datatype import DataType
 from finn.transformation import Transformation
 from finn.util.basic import get_by_name
 from finn.custom_op.registry import getCustomOp
+from finn.transformation.infer_shapes import InferShapes
 from finn.transformation.infer_datatypes import InferDataTypes
 
 
@@ -290,3 +292,38 @@ class AbsorbTransposeIntoMultiThreshold(Transformation):
         if graph_modified:
             model = model.transform(InferDataTypes())
         return (model, graph_modified)
+
+
+class AbsorbScalarMulIntoTopK(Transformation):
+    """Absorb a mul node into a suceeding topk node if the mul is scalar."""
+
+    def apply(self, model):
+        graph = model.graph
+        node_ind = 0
+        graph_modified = False
+        for n in graph.node:
+            node_ind += 1
+            if n.op_type == "TopK":
+                prod = model.find_producer(n.input[0])
+                if prod is not None and prod.op_type == "Mul":
+                    prod_input = prod.input[0]
+                    param_name = prod.input[1]
+                    A = model.get_initializer(param_name)
+                    if A is None:
+                        warnings.warn("Param is not constant, skipping")
+                        continue
+                    if all(x == 1 for x in A.shape) and A > 0:
+                        # if the mul is scalar and positive, we can just delete the
+                        # mul node and rewire the top k node. Because the top k node
+                        # works with probabilities and their relation to each other
+                        # the relation doesn't change if every value is multiplied
+                        # with a scalar
+                        graph.node.remove(prod)
+                        n.input[0] = prod_input
+                        # to avoid error the dataype is set to float32
+                        model.set_tensor_datatype(n.input[0], DataType.FLOAT32)
+                        graph_modified = True
+        if graph_modified:
+            model = model.transform(InferShapes())
+            model = model.transform(InferDataTypes())
+        return (model, graph_modified)
-- 
GitLab