Skip to content
Snippets Groups Projects
Commit ed535029 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

Merge branch 'feature/multithreshold_odtype' into dev

parents e709d19d f7fc06cb
No related branches found
No related tags found
No related merge requests found
# import onnx.helper as helper
import finn.core.multithreshold as multiThresh
from finn.core.utils import get_by_name
def execute_custom_node(node, context, graph):
......@@ -11,8 +12,17 @@ def execute_custom_node(node, context, graph):
# save inputs
v = context[node.input[0]]
thresholds = context[node.input[1]]
# retrieve attributes if output scaling is used
try:
out_scale = get_by_name(node.attribute, "out_scale").f
except AttributeError:
out_scale = None
try:
out_bias = get_by_name(node.attribute, "out_bias").f
except AttributeError:
out_bias = None
# calculate output
output = multiThresh.execute(v, thresholds)
output = multiThresh.execute(v, thresholds, out_scale, out_bias)
# setting context according to output
context[node.output[0]] = output
......
......@@ -8,7 +8,7 @@ def compare(x, y):
return 0.0
def execute(v, thresholds):
def execute(v, thresholds, out_scale=None, out_bias=None):
# the inputs are expected to be in the shape (N,C,H,W)
# N : Batch size
......@@ -21,6 +21,8 @@ def execute(v, thresholds):
# if all channels use the same threshold value)
# B : Desired activation steps => i.e. for 4-bit activation, B=7 (2^(n)-1 and n=4)
# the output tensor will be scaled by out_scale and biased by out_bias
# assert threshold shape
is_global_threshold = thresholds.shape[0] == 1
assert (v.shape[1] == thresholds.shape[0]) or is_global_threshold
......@@ -54,5 +56,8 @@ def execute(v, thresholds):
for a in range(num_act):
# apply successive thresholding to every element of one channel
ret[b][t][elem] += compare(vr[b][t][elem], channel_thresh[a])
return ret.reshape(v.shape)
if out_scale is None:
out_scale = 1.0
if out_bias is None:
out_bias = 0.0
return out_scale * ret.reshape(v.shape) + out_bias
from finn.core.datatype import DataType
from finn.core.utils import get_by_name
from finn.transformation import Transformation
......@@ -8,10 +9,15 @@ def _infer_node_datatype(model, node):
idtypes = list(map(lambda x: model.get_tensor_datatype(x), node.input))
odtypes = list(map(lambda x: model.get_tensor_datatype(x), node.output))
if node.op_type == "MultiThreshold":
# number of thresholds decides # output bits, use get_smallest_possible
n_thres = model.get_tensor_shape(node.input[1])[1]
odtype = DataType.get_smallest_possible(n_thres)
model.set_tensor_datatype(node.output[0], odtype)
try:
odt = get_by_name(node.attribute, "out_dtype").s.decode("utf-8")
model.set_tensor_datatype(node.output[0], DataType[odt])
except AttributeError:
# number of thresholds decides # output bits
# use get_smallest_possible, assuming unsigned
n_thres = model.get_tensor_shape(node.input[1])[1]
odtype = DataType.get_smallest_possible(n_thres)
model.set_tensor_datatype(node.output[0], odtype)
elif node.op_type == "Sign":
# always produces bipolar outputs
model.set_tensor_datatype(node.output[0], DataType.BIPOLAR)
......
......@@ -16,50 +16,30 @@ class ConvertSignToThres(Transformation):
for n in graph.node:
node_ind += 1
if n.op_type == "Sign":
sign_in_name = n.input[0]
sign_out_name = n.output[0]
# find consumer
consumer = model.find_consumer(sign_out_name)
assert consumer is not None
# change op type and create threshold
n.op_type = "MultiThreshold"
# create thresholds
thres_param_name = model.make_new_valueinfo_name()
thres_param = np.asarray([[0]], dtype=np.float32)
n.input.append(thres_param_name)
n.domain = "finn"
model.set_initializer(thres_param_name, thres_param)
# convert 0,1 -> -1,+1 with 2*x-1
out_shape = model.get_tensor_shape(sign_out_name)
# make a mul node
# note how set_initializer or set_tensor_shape is called before
# calling make_new_valueinfo_name again
mul_param_name = model.make_new_valueinfo_name()
model.set_initializer(
mul_param_name, np.asarray([[2]], dtype=np.float32)
# create a new node
mt_node = oh.make_node(
"MultiThreshold",
[sign_in_name, thres_param_name],
[sign_out_name],
domain="finn",
out_scale=2.0,
out_bias=-1.0,
out_dtype="BIPOLAR",
)
mul_out_name = model.make_new_valueinfo_name()
model.set_tensor_shape(mul_out_name, out_shape)
mul_node = oh.make_node(
"Mul", [sign_out_name, mul_param_name], [mul_out_name]
)
# make an add node
add_param_name = model.make_new_valueinfo_name()
model.set_initializer(
add_param_name, np.asarray([[-1]], dtype=np.float32)
)
add_out_name = model.make_new_valueinfo_name()
model.set_tensor_shape(add_out_name, out_shape)
add_node = oh.make_node(
"Add", [mul_out_name, add_param_name], [add_out_name]
)
# add new nodes to graph at correct position
graph.node.insert(node_ind, mul_node)
graph.node.insert(node_ind + 1, add_node)
# rewrite consumer's input
consumer.input[0] = add_out_name
# remove old node, add new node to graph at correct position
graph.node.insert(node_ind, mt_node)
graph.node.remove(n)
# add quantization annotations
model.set_tensor_datatype(sign_out_name, DataType.BINARY)
model.set_tensor_datatype(mul_out_name, DataType.UINT2)
model.set_tensor_datatype(add_out_name, DataType.BIPOLAR)
model.set_tensor_datatype(sign_out_name, DataType.BIPOLAR)
graph_modified = True
return (model, graph_modified)
......
......@@ -211,3 +211,18 @@ def test_execute_custom_node_multithreshold():
)
assert (execution_context["out"] == outputs).all()
# test the optional output scaling features on MultiThreshold
node_def = helper.make_node(
"MultiThreshold",
["v", "thresholds"],
["out"],
domain="finn",
out_scale=2.0,
out_bias=-1.0,
)
graph_def = helper.make_graph([node_def], "test_model", [v, thresholds], [out])
ex_cu_node.execute_custom_node(node_def, execution_context, graph_def)
outputs_scaled = 2.0 * outputs - 1.0
assert (execution_context["out"] == outputs_scaled).all()
......@@ -197,3 +197,7 @@ def test_execute_multi_thresholding():
results = multi_thresh.execute(inputs, thresholds)
assert (results == outputs).all()
results_scaled = multi_thresh.execute(inputs, thresholds, 2.0, -1.0)
outputs_scaled = 2.0 * outputs - 1.0
assert (results_scaled == outputs_scaled).all()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment