Skip to content
Snippets Groups Projects
Commit 28721620 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Transform] restrict absorb_mul_into_multi_threshold

dim restriction: mul must be scalar or 1d
val restriction: mul value must be non-negative
parent 2b51e014
No related branches found
No related tags found
No related merge requests found
......@@ -260,30 +260,34 @@ def absorb_add_into_multi_threshold(model):
def absorb_mul_into_multi_threshold(model):
"""Absorb preceding Mul ops into MultiThreshold by updating the threshold
values."""
values. Only *positive* scalar/1D vectors can be absorbed."""
graph = model.graph
node_ind = 0
graph_modified = False
for n in graph.node:
node_ind += 1
if n.op_type == "Mul":
mul_weight_name = n.input[1]
A = model.get_initializer(mul_weight_name)
assert A is not None
is_signed = (A < 0).any()
is_scalar = np.prod(A.shape) == 1
is_1d = len(A.shape) == 2 and A.shape[0] == 1
consumer = model.find_consumer(n.output[0])
if consumer is not None and consumer.op_type == "MultiThreshold":
mul_weight_name = n.input[1]
threshold_name = consumer.input[1]
A = model.get_initializer(mul_weight_name)
T = model.get_initializer(threshold_name)
assert A is not None
assert T is not None
start_name = n.input[0]
# compute new thresholds and set initializer
Tnew = T / A.reshape(-1, T.shape[1])
# TODO: need to handle negative A values correctly; produce
# mul sign mask and merge into preceding matmul?
model.set_initializer(threshold_name, Tnew)
# wire add input directly to MultiThreshold
consumer.input[0] = start_name
# remove the mul node
graph.node.remove(n)
graph_modified = True
if not is_signed and (is_1d or is_scalar):
threshold_name = consumer.input[1]
T = model.get_initializer(threshold_name)
assert T is not None
start_name = n.input[0]
# compute new thresholds and set initializer
Tnew = T / A.reshape(-1, T.shape[1])
# TODO: need to handle negative A values correctly; produce
# mul sign mask and merge into preceding matmul?
model.set_initializer(threshold_name, Tnew)
# wire add input directly to MultiThreshold
consumer.input[0] = start_name
# remove the mul node
graph.node.remove(n)
graph_modified = True
return (model, graph_modified)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment