From 28721620f08ade4e5b672c3731cff55dbceeac7d Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <yamanu@xilinx.com>
Date: Mon, 4 Nov 2019 17:38:33 +0000
Subject: [PATCH] [Transform] restrict absorb_mul_into_multi_threshold

dim restriction: mul must be scalar or 1d
val restriction: mul value must be non-negative
---
 src/finn/transformation/streamline.py | 40 +++++++++++++++------------
 1 file changed, 22 insertions(+), 18 deletions(-)

diff --git a/src/finn/transformation/streamline.py b/src/finn/transformation/streamline.py
index 04c7f05b3..34e926e5b 100644
--- a/src/finn/transformation/streamline.py
+++ b/src/finn/transformation/streamline.py
@@ -260,30 +260,34 @@ def absorb_add_into_multi_threshold(model):
 
 def absorb_mul_into_multi_threshold(model):
     """Absorb preceding Mul ops into MultiThreshold by updating the threshold
-    values."""
+    values. Only *positive* scalar/1D vectors can be absorbed."""
     graph = model.graph
     node_ind = 0
     graph_modified = False
     for n in graph.node:
         node_ind += 1
         if n.op_type == "Mul":
+            mul_weight_name = n.input[1]
+            A = model.get_initializer(mul_weight_name)
+            assert A is not None
+            is_signed = (A < 0).any()
+            is_scalar = np.prod(A.shape) == 1
+            is_1d = len(A.shape) == 2 and A.shape[0] == 1
             consumer = model.find_consumer(n.output[0])
             if consumer is not None and consumer.op_type == "MultiThreshold":
-                mul_weight_name = n.input[1]
-                threshold_name = consumer.input[1]
-                A = model.get_initializer(mul_weight_name)
-                T = model.get_initializer(threshold_name)
-                assert A is not None
-                assert T is not None
-                start_name = n.input[0]
-                # compute new thresholds and set initializer
-                Tnew = T / A.reshape(-1, T.shape[1])
-                # TODO: need to handle negative A values correctly; produce
-                # mul sign mask and merge into preceding matmul?
-                model.set_initializer(threshold_name, Tnew)
-                # wire add input directly to MultiThreshold
-                consumer.input[0] = start_name
-                # remove the mul node
-                graph.node.remove(n)
-                graph_modified = True
+                if not is_signed and (is_1d or is_scalar):
+                    threshold_name = consumer.input[1]
+                    T = model.get_initializer(threshold_name)
+                    assert T is not None
+                    start_name = n.input[0]
+                    # compute new thresholds and set initializer
+                    Tnew = T / A.reshape(-1, T.shape[1])
+                    # TODO: need to handle negative A values correctly; produce
+                    # mul sign mask and merge into preceding matmul?
+                    model.set_initializer(threshold_name, Tnew)
+                    # wire add input directly to MultiThreshold
+                    consumer.input[0] = start_name
+                    # remove the mul node
+                    graph.node.remove(n)
+                    graph_modified = True
     return (model, graph_modified)
-- 
GitLab