Skip to content
Snippets Groups Projects
Commit fcc01b73 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Transform] add round_thresholds

parent 43b1ea05
No related branches found
No related tags found
No related merge requests found
......@@ -363,3 +363,22 @@ def absorb_1bit_mul_into_matmul(model):
graph.node.remove(consumer)
graph_modified = True
return (model, graph_modified)
def round_thresholds(model):
"""For MultiThreshold nodes operating on integer inputs, round up
thresholds values to the nearest integer."""
graph = model.graph
graph_modified = False
for n in graph.node:
if n.op_type == "MultiThreshold":
idtype = model.get_tensor_datatype(n.input[0])
T = model.get_initializer(n.input[1])
Tnew = np.ceil(T)
if idtype.is_integer() and (T != Tnew).any():
# round up the thresholds to nearest integer
model.set_initializer(n.input[1], Tnew)
# use same datatype as inputs for thresholds
model.set_tensor_datatype(n.input[1], idtype)
graph_modified = True
return (model, graph_modified)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment