Skip to content
Snippets Groups Projects
Commit ba43d9c9 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Transform] add convert_sign_to_thres

parent 26d1088e
No related branches found
No related tags found
No related merge requests found
......@@ -2,6 +2,24 @@ import numpy as np
from onnx import helper as oh
import finn.transformation.infer_shapes as si
from finn.core.datatype import DataType
def convert_sign_to_thres(model):
"""Convert Sign node instances to MultiThreshold with threshold at 0."""
graph = model.graph
graph_modified = False
for n in graph.node:
if n.op_type == "Sign":
n.op_type = "MultiThreshold"
thres_param_name = model.make_new_valueinfo_name()
thres_param = np.asarray([[0]], dtype=np.float32)
n.input.append(thres_param_name)
model.set_initializer(thres_param_name, thres_param)
# mark output tensor as bipolar
model.set_tensor_datatype(n.output[0], DataType.BIPOLAR)
graph_modified = True
return (model, graph_modified)
def collapse_repeated_op(model, op_name, make_collapsed_param_fxn):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment