Skip to content
Snippets Groups Projects
Commit 70787a5b authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Transform] fix shape problems in add/mul->thres absorption

parent 63c467c8
No related branches found
No related tags found
No related merge requests found
......@@ -248,7 +248,7 @@ def absorb_add_into_multi_threshold(model):
assert T is not None
start_name = n.input[0]
# compute new thresholds and set initializer
Tnew = T - A
Tnew = T - A.reshape(-1, T.shape[1])
model.set_initializer(threshold_name, Tnew)
# wire add input directly to MultiThreshold
consumer.input[0] = start_name
......@@ -277,7 +277,9 @@ def absorb_mul_into_multi_threshold(model):
assert T is not None
start_name = n.input[0]
# compute new thresholds and set initializer
Tnew = T / A
Tnew = T / A.reshape(-1, T.shape[1])
# TODO: need to handle negative A values correctly; produce
# mul sign mask and merge into preceding matmul?
model.set_initializer(threshold_name, Tnew)
# wire add input directly to MultiThreshold
consumer.input[0] = start_name
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment