Skip to content
Snippets Groups Projects
Commit 28721620 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Transform] restrict absorb_mul_into_multi_threshold

dim restriction: mul must be scalar or 1d
val restriction: mul value must be non-negative
parent 2b51e014
No related branches found
No related tags found
No related merge requests found
...@@ -260,30 +260,34 @@ def absorb_add_into_multi_threshold(model): ...@@ -260,30 +260,34 @@ def absorb_add_into_multi_threshold(model):
def absorb_mul_into_multi_threshold(model): def absorb_mul_into_multi_threshold(model):
"""Absorb preceding Mul ops into MultiThreshold by updating the threshold """Absorb preceding Mul ops into MultiThreshold by updating the threshold
values.""" values. Only *positive* scalar/1D vectors can be absorbed."""
graph = model.graph graph = model.graph
node_ind = 0 node_ind = 0
graph_modified = False graph_modified = False
for n in graph.node: for n in graph.node:
node_ind += 1 node_ind += 1
if n.op_type == "Mul": if n.op_type == "Mul":
mul_weight_name = n.input[1]
A = model.get_initializer(mul_weight_name)
assert A is not None
is_signed = (A < 0).any()
is_scalar = np.prod(A.shape) == 1
is_1d = len(A.shape) == 2 and A.shape[0] == 1
consumer = model.find_consumer(n.output[0]) consumer = model.find_consumer(n.output[0])
if consumer is not None and consumer.op_type == "MultiThreshold": if consumer is not None and consumer.op_type == "MultiThreshold":
mul_weight_name = n.input[1] if not is_signed and (is_1d or is_scalar):
threshold_name = consumer.input[1] threshold_name = consumer.input[1]
A = model.get_initializer(mul_weight_name) T = model.get_initializer(threshold_name)
T = model.get_initializer(threshold_name) assert T is not None
assert A is not None start_name = n.input[0]
assert T is not None # compute new thresholds and set initializer
start_name = n.input[0] Tnew = T / A.reshape(-1, T.shape[1])
# compute new thresholds and set initializer # TODO: need to handle negative A values correctly; produce
Tnew = T / A.reshape(-1, T.shape[1]) # mul sign mask and merge into preceding matmul?
# TODO: need to handle negative A values correctly; produce model.set_initializer(threshold_name, Tnew)
# mul sign mask and merge into preceding matmul? # wire add input directly to MultiThreshold
model.set_initializer(threshold_name, Tnew) consumer.input[0] = start_name
# wire add input directly to MultiThreshold # remove the mul node
consumer.input[0] = start_name graph.node.remove(n)
# remove the mul node graph_modified = True
graph.node.remove(n)
graph_modified = True
return (model, graph_modified) return (model, graph_modified)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment