Skip to content
Snippets Groups Projects
Commit 985b14ed authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Transform] add absorb mul/add into MultiThreshold

parent 2564cd42
No related branches found
No related tags found
No related merge requests found
......@@ -193,3 +193,61 @@ def move_scalar_add_past_matmul(model):
graph_modified = True
model = model.transform_single(si.infer_shapes)
return (model, graph_modified)
def absorb_add_into_multi_threshold(model):
"""Absorb preceding Add ops into MultiThreshold by updating the threshold
values."""
graph = model.graph
node_ind = 0
graph_modified = False
for n in graph.node:
node_ind += 1
if n.op_type == "Add":
consumer = model.find_consumer(n.output[0])
if consumer is not None and consumer.op_type == "MultiThreshold":
add_weight_name = n.input[1]
threshold_name = consumer.input[1]
A = model.get_initializer(add_weight_name)
T = model.get_initializer(threshold_name)
assert A is not None
assert T is not None
start_name = n.input[0]
# compute new thresholds and set initializer
Tnew = T - A
model.set_initializer(threshold_name, Tnew)
# wire add input directly to MultiThreshold
consumer.input[0] = start_name
# remove the add node
graph.node.remove(n)
graph_modified = True
return (model, graph_modified)
def absorb_mul_into_multi_threshold(model):
"""Absorb preceding Mul ops into MultiThreshold by updating the threshold
values."""
graph = model.graph
node_ind = 0
graph_modified = False
for n in graph.node:
node_ind += 1
if n.op_type == "Mul":
consumer = model.find_consumer(n.output[0])
if consumer is not None and consumer.op_type == "MultiThreshold":
mul_weight_name = n.input[1]
threshold_name = consumer.input[1]
A = model.get_initializer(mul_weight_name)
T = model.get_initializer(threshold_name)
assert A is not None
assert T is not None
start_name = n.input[0]
# compute new thresholds and set initializer
Tnew = T / A
model.set_initializer(threshold_name, Tnew)
# wire add input directly to MultiThreshold
consumer.input[0] = start_name
# remove the mul node
graph.node.remove(n)
graph_modified = True
return (model, graph_modified)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment