Skip to content
Snippets Groups Projects
Commit be1ef4f4 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Transform] use bipolar MultiThreshold in ConvertSignToThres

parent e8199bbf
No related branches found
No related tags found
No related merge requests found
......@@ -16,50 +16,30 @@ class ConvertSignToThres(Transformation):
for n in graph.node:
node_ind += 1
if n.op_type == "Sign":
sign_in_name = n.input[0]
sign_out_name = n.output[0]
# find consumer
consumer = model.find_consumer(sign_out_name)
assert consumer is not None
# change op type and create threshold
n.op_type = "MultiThreshold"
# create thresholds
thres_param_name = model.make_new_valueinfo_name()
thres_param = np.asarray([[0]], dtype=np.float32)
n.input.append(thres_param_name)
n.domain = "finn"
model.set_initializer(thres_param_name, thres_param)
# convert 0,1 -> -1,+1 with 2*x-1
out_shape = model.get_tensor_shape(sign_out_name)
# make a mul node
# note how set_initializer or set_tensor_shape is called before
# calling make_new_valueinfo_name again
mul_param_name = model.make_new_valueinfo_name()
model.set_initializer(
mul_param_name, np.asarray([[2]], dtype=np.float32)
# create a new node
mt_node = oh.make_node(
"MultiThreshold",
[sign_in_name, thres_param_name],
[sign_out_name],
domain="finn",
out_scale=2.0,
out_bias=-1.0,
out_dtype="BIPOLAR",
)
mul_out_name = model.make_new_valueinfo_name()
model.set_tensor_shape(mul_out_name, out_shape)
mul_node = oh.make_node(
"Mul", [sign_out_name, mul_param_name], [mul_out_name]
)
# make an add node
add_param_name = model.make_new_valueinfo_name()
model.set_initializer(
add_param_name, np.asarray([[-1]], dtype=np.float32)
)
add_out_name = model.make_new_valueinfo_name()
model.set_tensor_shape(add_out_name, out_shape)
add_node = oh.make_node(
"Add", [mul_out_name, add_param_name], [add_out_name]
)
# add new nodes to graph at correct position
graph.node.insert(node_ind, mul_node)
graph.node.insert(node_ind + 1, add_node)
# rewrite consumer's input
consumer.input[0] = add_out_name
# remove old node, add new node to graph at correct position
graph.node.insert(node_ind, mt_node)
graph.node.remove(n)
# add quantization annotations
model.set_tensor_datatype(sign_out_name, DataType.BINARY)
model.set_tensor_datatype(mul_out_name, DataType.UINT2)
model.set_tensor_datatype(add_out_name, DataType.BIPOLAR)
model.set_tensor_datatype(sign_out_name, DataType.BIPOLAR)
graph_modified = True
return (model, graph_modified)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment