diff --git a/src/finn/core/execute_custom_node.py b/src/finn/core/execute_custom_node.py index 62d8b7fc1ba0ce0d6c83c2b56ef1d75613be974b..e78e9d5de34e013fdbe43aae70c37955b95532b9 100644 --- a/src/finn/core/execute_custom_node.py +++ b/src/finn/core/execute_custom_node.py @@ -1,6 +1,7 @@ # import onnx.helper as helper import finn.core.multithreshold as multiThresh +from finn.core.utils import get_by_name def execute_custom_node(node, context, graph): @@ -11,8 +12,17 @@ def execute_custom_node(node, context, graph): # save inputs v = context[node.input[0]] thresholds = context[node.input[1]] + # retrieve attributes if output scaling is used + try: + out_scale = get_by_name(node.attribute, "out_scale").f + except AttributeError: + out_scale = None + try: + out_bias = get_by_name(node.attribute, "out_bias").f + except AttributeError: + out_bias = None # calculate output - output = multiThresh.execute(v, thresholds) + output = multiThresh.execute(v, thresholds, out_scale, out_bias) # setting context according to output context[node.output[0]] = output diff --git a/src/finn/core/multithreshold.py b/src/finn/core/multithreshold.py index 009259c577879a8aa09ac44ace704af55ca2593d..de38971502f1930dbe2090f3d88aad62e209478a 100755 --- a/src/finn/core/multithreshold.py +++ b/src/finn/core/multithreshold.py @@ -8,7 +8,7 @@ def compare(x, y): return 0.0 -def execute(v, thresholds): +def execute(v, thresholds, out_scale=None, out_bias=None): # the inputs are expected to be in the shape (N,C,H,W) # N : Batch size @@ -21,6 +21,8 @@ def execute(v, thresholds): # if all channels use the same threshold value) # B : Desired activation steps => i.e. for 4-bit activation, B=7 (2^(n)-1 and n=4) + # the output tensor will be scaled by out_scale and biased by out_bias + # assert threshold shape is_global_threshold = thresholds.shape[0] == 1 assert (v.shape[1] == thresholds.shape[0]) or is_global_threshold @@ -54,5 +56,8 @@ def execute(v, thresholds): for a in range(num_act): # apply successive thresholding to every element of one channel ret[b][t][elem] += compare(vr[b][t][elem], channel_thresh[a]) - - return ret.reshape(v.shape) + if out_scale is None: + out_scale = 1.0 + if out_bias is None: + out_bias = 0.0 + return out_scale * ret.reshape(v.shape) + out_bias diff --git a/src/finn/transformation/infer_datatypes.py b/src/finn/transformation/infer_datatypes.py index a311012fc6631e76e75a37b8dc4d1b99d21ce7c7..2c6d75a1d029fbfea3d5b408cc6e8eb5f70dc7b5 100644 --- a/src/finn/transformation/infer_datatypes.py +++ b/src/finn/transformation/infer_datatypes.py @@ -1,4 +1,5 @@ from finn.core.datatype import DataType +from finn.core.utils import get_by_name from finn.transformation import Transformation @@ -8,10 +9,15 @@ def _infer_node_datatype(model, node): idtypes = list(map(lambda x: model.get_tensor_datatype(x), node.input)) odtypes = list(map(lambda x: model.get_tensor_datatype(x), node.output)) if node.op_type == "MultiThreshold": - # number of thresholds decides # output buts, use get_smallest_possible - n_thres = model.get_tensor_shape(node.input[1])[1] - odtype = DataType.get_smallest_possible(n_thres) - model.set_tensor_datatype(node.output[0], odtype) + try: + odt = get_by_name(node.attribute, "out_dtype").s.decode("utf-8") + model.set_tensor_datatype(node.output[0], DataType[odt]) + except AttributeError: + # number of thresholds decides # output bits + # use get_smallest_possible, assuming unsigned + n_thres = model.get_tensor_shape(node.input[1])[1] + odtype = DataType.get_smallest_possible(n_thres) + model.set_tensor_datatype(node.output[0], odtype) elif node.op_type == "Sign": # always produces bipolar outputs model.set_tensor_datatype(node.output[0], DataType.BIPOLAR) diff --git a/src/finn/transformation/streamline.py b/src/finn/transformation/streamline.py index fb9e530d641063187dd75de30c8ca49936565bae..95e84e9907717aa8106a3a16319202bc29610f65 100644 --- a/src/finn/transformation/streamline.py +++ b/src/finn/transformation/streamline.py @@ -16,50 +16,30 @@ class ConvertSignToThres(Transformation): for n in graph.node: node_ind += 1 if n.op_type == "Sign": + sign_in_name = n.input[0] sign_out_name = n.output[0] # find consumer consumer = model.find_consumer(sign_out_name) assert consumer is not None - # change op type and create threshold - n.op_type = "MultiThreshold" + # create thresholds thres_param_name = model.make_new_valueinfo_name() thres_param = np.asarray([[0]], dtype=np.float32) - n.input.append(thres_param_name) - n.domain = "finn" model.set_initializer(thres_param_name, thres_param) - # convert 0,1 -> -1,+1 with 2*x-1 - out_shape = model.get_tensor_shape(sign_out_name) - # make a mul node - # note how set_initializer or set_tensor_shape is called before - # calling make_new_valueinfo_name again - mul_param_name = model.make_new_valueinfo_name() - model.set_initializer( - mul_param_name, np.asarray([[2]], dtype=np.float32) + # create a new node + mt_node = oh.make_node( + "MultiThreshold", + [sign_in_name, thres_param_name], + [sign_out_name], + domain="finn", + out_scale=2.0, + out_bias=-1.0, + out_dtype="BIPOLAR", ) - mul_out_name = model.make_new_valueinfo_name() - model.set_tensor_shape(mul_out_name, out_shape) - mul_node = oh.make_node( - "Mul", [sign_out_name, mul_param_name], [mul_out_name] - ) - # make an add node - add_param_name = model.make_new_valueinfo_name() - model.set_initializer( - add_param_name, np.asarray([[-1]], dtype=np.float32) - ) - add_out_name = model.make_new_valueinfo_name() - model.set_tensor_shape(add_out_name, out_shape) - add_node = oh.make_node( - "Add", [mul_out_name, add_param_name], [add_out_name] - ) - # add new nodes to graph at correct position - graph.node.insert(node_ind, mul_node) - graph.node.insert(node_ind + 1, add_node) - # rewrite consumer's input - consumer.input[0] = add_out_name + # remove old node, add new node to graph at correct position + graph.node.insert(node_ind, mt_node) + graph.node.remove(n) # add quantization annotations - model.set_tensor_datatype(sign_out_name, DataType.BINARY) - model.set_tensor_datatype(mul_out_name, DataType.UINT2) - model.set_tensor_datatype(add_out_name, DataType.BIPOLAR) + model.set_tensor_datatype(sign_out_name, DataType.BIPOLAR) graph_modified = True return (model, graph_modified) diff --git a/tests/test_custom_onnx_exec.py b/tests/test_custom_onnx_exec.py index 0d07b9888d1330753446d589619149c1f8f316cf..e1ff552572e8a6d4d55e204cd21f17e4984ce30d 100644 --- a/tests/test_custom_onnx_exec.py +++ b/tests/test_custom_onnx_exec.py @@ -211,3 +211,18 @@ def test_execute_custom_node_multithreshold(): ) assert (execution_context["out"] == outputs).all() + + # test the optional output scaling features on MultiThreshold + node_def = helper.make_node( + "MultiThreshold", + ["v", "thresholds"], + ["out"], + domain="finn", + out_scale=2.0, + out_bias=-1.0, + ) + + graph_def = helper.make_graph([node_def], "test_model", [v, thresholds], [out]) + ex_cu_node.execute_custom_node(node_def, execution_context, graph_def) + outputs_scaled = 2.0 * outputs - 1.0 + assert (execution_context["out"] == outputs_scaled).all() diff --git a/tests/test_multi_thresholding.py b/tests/test_multi_thresholding.py index 49305e5572f35eb4a2e5f7678c73038777eb8b92..5fd7b4309d68ad568773dbdb39d0df979a767e73 100644 --- a/tests/test_multi_thresholding.py +++ b/tests/test_multi_thresholding.py @@ -197,3 +197,7 @@ def test_execute_multi_thresholding(): results = multi_thresh.execute(inputs, thresholds) assert (results == outputs).all() + + results_scaled = multi_thresh.execute(inputs, thresholds, 2.0, -1.0) + outputs_scaled = 2.0 * outputs - 1.0 + assert (results_scaled == outputs_scaled).all()