diff --git a/src/finn/transformation/streamline.py b/src/finn/transformation/streamline.py index 309177e548be490a7de05aad6006b838b207f3bc..5053b23c3e09bc5f32807cdf16a997c4647dd165 100644 --- a/src/finn/transformation/streamline.py +++ b/src/finn/transformation/streamline.py @@ -363,3 +363,22 @@ def absorb_1bit_mul_into_matmul(model): graph.node.remove(consumer) graph_modified = True return (model, graph_modified) + + +def round_thresholds(model): + """For MultiThreshold nodes operating on integer inputs, round up + thresholds values to the nearest integer.""" + graph = model.graph + graph_modified = False + for n in graph.node: + if n.op_type == "MultiThreshold": + idtype = model.get_tensor_datatype(n.input[0]) + T = model.get_initializer(n.input[1]) + Tnew = np.ceil(T) + if idtype.is_integer() and (T != Tnew).any(): + # round up the thresholds to nearest integer + model.set_initializer(n.input[1], Tnew) + # use same datatype as inputs for thresholds + model.set_tensor_datatype(n.input[1], idtype) + graph_modified = True + return (model, graph_modified)