From 70787a5b8b736b0050e9a3278c22588e0369cbf5 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Sun, 3 Nov 2019 23:25:28 +0000
Subject: [PATCH] [Transform] fix shape problems in add/mul->thres absorption

---
 src/finn/transformation/streamline.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/finn/transformation/streamline.py b/src/finn/transformation/streamline.py
index 845036f81..04c7f05b3 100644
--- a/src/finn/transformation/streamline.py
+++ b/src/finn/transformation/streamline.py
@@ -248,7 +248,7 @@ def absorb_add_into_multi_threshold(model):
                 assert T is not None
                 start_name = n.input[0]
                 # compute new thresholds and set initializer
-                Tnew = T - A
+                Tnew = T - A.reshape(-1, T.shape[1])
                 model.set_initializer(threshold_name, Tnew)
                 # wire add input directly to MultiThreshold
                 consumer.input[0] = start_name
@@ -277,7 +277,9 @@ def absorb_mul_into_multi_threshold(model):
                 assert T is not None
                 start_name = n.input[0]
                 # compute new thresholds and set initializer
-                Tnew = T / A
+                Tnew = T / A.reshape(-1, T.shape[1])
+                # TODO: need to handle negative A values correctly; produce
+                # mul sign mask and merge into preceding matmul?
                 model.set_initializer(threshold_name, Tnew)
                 # wire add input directly to MultiThreshold
                 consumer.input[0] = start_name
-- 
GitLab