From ba43d9c9f3851bc4fdfc47d917dac49159dadfaf Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <yamanu@xilinx.com>
Date: Fri, 1 Nov 2019 14:25:49 +0000
Subject: [PATCH] [Transform] add convert_sign_to_thres

---
 src/finn/transformation/streamline.py | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/src/finn/transformation/streamline.py b/src/finn/transformation/streamline.py
index f54b78f04..462005caa 100644
--- a/src/finn/transformation/streamline.py
+++ b/src/finn/transformation/streamline.py
@@ -2,6 +2,24 @@ import numpy as np
 from onnx import helper as oh
 
 import finn.transformation.infer_shapes as si
+from finn.core.datatype import DataType
+
+
+def convert_sign_to_thres(model):
+    """Convert Sign node instances to MultiThreshold with threshold at 0."""
+    graph = model.graph
+    graph_modified = False
+    for n in graph.node:
+        if n.op_type == "Sign":
+            n.op_type = "MultiThreshold"
+            thres_param_name = model.make_new_valueinfo_name()
+            thres_param = np.asarray([[0]], dtype=np.float32)
+            n.input.append(thres_param_name)
+            model.set_initializer(thres_param_name, thres_param)
+            # mark output tensor as bipolar
+            model.set_tensor_datatype(n.output[0], DataType.BIPOLAR)
+            graph_modified = True
+    return (model, graph_modified)
 
 
 def collapse_repeated_op(model, op_name, make_collapsed_param_fxn):
-- 
GitLab