Skip to content
Snippets Groups Projects
Commit a9112a73 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Transform] convert_sign_to thres needs 2x-1 after MultiThreshold

parent 29767241
No related branches found
No related tags found
No related merge requests found
......@@ -9,16 +9,50 @@ def convert_sign_to_thres(model):
"""Convert Sign node instances to MultiThreshold with threshold at 0."""
graph = model.graph
graph_modified = False
node_ind = 0
for n in graph.node:
node_ind += 1
if n.op_type == "Sign":
sign_out_name = n.output[0]
# find consumer
consumer = model.find_consumer(sign_out_name)
assert consumer is not None
# change op type and create threshold
n.op_type = "MultiThreshold"
thres_param_name = model.make_new_valueinfo_name()
thres_param = np.asarray([[0]], dtype=np.float32)
n.input.append(thres_param_name)
n.domain = "finn"
model.set_initializer(thres_param_name, thres_param)
# mark output tensor as bipolar
model.set_tensor_datatype(n.output[0], DataType.BIPOLAR)
# convert 0,1 -> -1,+1 with 2*x-1
out_shape = model.get_tensor_shape(sign_out_name)
# make a mul node
# note how set_initializer or set_tensor_shape is called before
# calling make_new_valueinfo_name again
mul_param_name = model.make_new_valueinfo_name()
model.set_initializer(mul_param_name, np.asarray([[2]], dtype=np.float32))
mul_out_name = model.make_new_valueinfo_name()
model.set_tensor_shape(mul_out_name, out_shape)
mul_node = oh.make_node(
"Mul", [sign_out_name, mul_param_name], [mul_out_name]
)
# make an add node
add_param_name = model.make_new_valueinfo_name()
model.set_initializer(add_param_name, np.asarray([[-1]], dtype=np.float32))
add_out_name = model.make_new_valueinfo_name()
model.set_tensor_shape(add_out_name, out_shape)
add_node = oh.make_node(
"Add", [mul_out_name, add_param_name], [add_out_name]
)
# add new nodes to graph at correct position
graph.node.insert(node_ind, mul_node)
graph.node.insert(node_ind + 1, add_node)
# rewrite consumer's input
consumer.input[0] = add_out_name
# add quantization annotations
model.set_tensor_datatype(sign_out_name, DataType.BINARY)
model.set_tensor_datatype(mul_out_name, DataType.UINT2)
model.set_tensor_datatype(add_out_name, DataType.BIPOLAR)
graph_modified = True
return (model, graph_modified)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment