From fcc01b73be2d7f9d9aa6665fd0e41cdb270a4801 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <maltanar@gmail.com> Date: Thu, 7 Nov 2019 22:50:06 +0000 Subject: [PATCH] [Transform] add round_thresholds --- src/finn/transformation/streamline.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/finn/transformation/streamline.py b/src/finn/transformation/streamline.py index 309177e54..5053b23c3 100644 --- a/src/finn/transformation/streamline.py +++ b/src/finn/transformation/streamline.py @@ -363,3 +363,22 @@ def absorb_1bit_mul_into_matmul(model): graph.node.remove(consumer) graph_modified = True return (model, graph_modified) + + +def round_thresholds(model): + """For MultiThreshold nodes operating on integer inputs, round up + thresholds values to the nearest integer.""" + graph = model.graph + graph_modified = False + for n in graph.node: + if n.op_type == "MultiThreshold": + idtype = model.get_tensor_datatype(n.input[0]) + T = model.get_initializer(n.input[1]) + Tnew = np.ceil(T) + if idtype.is_integer() and (T != Tnew).any(): + # round up the thresholds to nearest integer + model.set_initializer(n.input[1], Tnew) + # use same datatype as inputs for thresholds + model.set_tensor_datatype(n.input[1], idtype) + graph_modified = True + return (model, graph_modified) -- GitLab