From ba43d9c9f3851bc4fdfc47d917dac49159dadfaf Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <yamanu@xilinx.com> Date: Fri, 1 Nov 2019 14:25:49 +0000 Subject: [PATCH] [Transform] add convert_sign_to_thres --- src/finn/transformation/streamline.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/src/finn/transformation/streamline.py b/src/finn/transformation/streamline.py index f54b78f04..462005caa 100644 --- a/src/finn/transformation/streamline.py +++ b/src/finn/transformation/streamline.py @@ -2,6 +2,24 @@ import numpy as np from onnx import helper as oh import finn.transformation.infer_shapes as si +from finn.core.datatype import DataType + + +def convert_sign_to_thres(model): + """Convert Sign node instances to MultiThreshold with threshold at 0.""" + graph = model.graph + graph_modified = False + for n in graph.node: + if n.op_type == "Sign": + n.op_type = "MultiThreshold" + thres_param_name = model.make_new_valueinfo_name() + thres_param = np.asarray([[0]], dtype=np.float32) + n.input.append(thres_param_name) + model.set_initializer(thres_param_name, thres_param) + # mark output tensor as bipolar + model.set_tensor_datatype(n.output[0], DataType.BIPOLAR) + graph_modified = True + return (model, graph_modified) def collapse_repeated_op(model, op_name, make_collapsed_param_fxn): -- GitLab