From fcc01b73be2d7f9d9aa6665fd0e41cdb270a4801 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Thu, 7 Nov 2019 22:50:06 +0000
Subject: [PATCH] [Transform] add round_thresholds

---
 src/finn/transformation/streamline.py | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/src/finn/transformation/streamline.py b/src/finn/transformation/streamline.py
index 309177e54..5053b23c3 100644
--- a/src/finn/transformation/streamline.py
+++ b/src/finn/transformation/streamline.py
@@ -363,3 +363,22 @@ def absorb_1bit_mul_into_matmul(model):
                     graph.node.remove(consumer)
                     graph_modified = True
     return (model, graph_modified)
+
+
+def round_thresholds(model):
+    """For MultiThreshold nodes operating on integer inputs, round up
+    thresholds values to the nearest integer."""
+    graph = model.graph
+    graph_modified = False
+    for n in graph.node:
+        if n.op_type == "MultiThreshold":
+            idtype = model.get_tensor_datatype(n.input[0])
+            T = model.get_initializer(n.input[1])
+            Tnew = np.ceil(T)
+            if idtype.is_integer() and (T != Tnew).any():
+                # round up the thresholds to nearest integer
+                model.set_initializer(n.input[1], Tnew)
+                # use same datatype as inputs for thresholds
+                model.set_tensor_datatype(n.input[1], idtype)
+                graph_modified = True
+    return (model, graph_modified)
-- 
GitLab