diff --git a/src/finn/core/execute_custom_node.py b/src/finn/core/execute_custom_node.py index e78e9d5de34e013fdbe43aae70c37955b95532b9..1923a64237c3277b30bb0bbc8276858203e9ef28 100644 --- a/src/finn/core/execute_custom_node.py +++ b/src/finn/core/execute_custom_node.py @@ -1,31 +1,16 @@ # import onnx.helper as helper -import finn.core.multithreshold as multiThresh -from finn.core.utils import get_by_name +import finn.custom_op.registry as registry def execute_custom_node(node, context, graph): """Call custom implementation to execute a single custom node. Input/output provided via context.""" - - if node.op_type == "MultiThreshold": - # save inputs - v = context[node.input[0]] - thresholds = context[node.input[1]] - # retrieve attributes if output scaling is used - try: - out_scale = get_by_name(node.attribute, "out_scale").f - except AttributeError: - out_scale = None - try: - out_bias = get_by_name(node.attribute, "out_bias").f - except AttributeError: - out_bias = None - # calculate output - output = multiThresh.execute(v, thresholds, out_scale, out_bias) - # setting context according to output - context[node.output[0]] = output - - else: + op_type = node.op_type + try: + # lookup op_type in registry of CustomOps + inst = registry.custom_op[op_type]() + inst.execute_node(node, context, graph) + except KeyError: # exception if op_type is not supported - raise Exception("This custom node is currently not supported.") + raise Exception("Custom op_type %s is currently not supported." % op_type) diff --git a/src/finn/core/multithreshold.py b/src/finn/core/multithreshold.py deleted file mode 100755 index de38971502f1930dbe2090f3d88aad62e209478a..0000000000000000000000000000000000000000 --- a/src/finn/core/multithreshold.py +++ /dev/null @@ -1,63 +0,0 @@ -import numpy as np - - -def compare(x, y): - if x >= y: - return 1.0 - else: - return 0.0 - - -def execute(v, thresholds, out_scale=None, out_bias=None): - - # the inputs are expected to be in the shape (N,C,H,W) - # N : Batch size - # C : Number of channels - # H : Heigth of the input images - # W : Width of the input images - # - # the thresholds are expected to be in the shape (C, B) - # C : Number of channels (must be the same value as C in input tensor or 1 - # if all channels use the same threshold value) - # B : Desired activation steps => i.e. for 4-bit activation, B=7 (2^(n)-1 and n=4) - - # the output tensor will be scaled by out_scale and biased by out_bias - - # assert threshold shape - is_global_threshold = thresholds.shape[0] == 1 - assert (v.shape[1] == thresholds.shape[0]) or is_global_threshold - - # save the required shape sizes for the loops (N, C and B) - num_batch = v.shape[0] - num_channel = v.shape[1] - - num_act = thresholds.shape[1] - - # reshape inputs to enable channel-wise reading - vr = v.reshape((v.shape[0], v.shape[1], -1)) - - # save the new shape size of the images - num_img_elem = vr.shape[2] - - # initiate output tensor - ret = np.zeros_like(vr) - - # iterate over thresholds channel-wise - for t in range(num_channel): - channel_thresh = thresholds[0] if is_global_threshold else thresholds[t] - - # iterate over batches - for b in range(num_batch): - - # iterate over image elements on which the thresholds should be applied - for elem in range(num_img_elem): - - # iterate over the different thresholds that correspond to one channel - for a in range(num_act): - # apply successive thresholding to every element of one channel - ret[b][t][elem] += compare(vr[b][t][elem], channel_thresh[a]) - if out_scale is None: - out_scale = 1.0 - if out_bias is None: - out_bias = 0.0 - return out_scale * ret.reshape(v.shape) + out_bias diff --git a/src/finn/custom_op/__init__.py b/src/finn/custom_op/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fd7024affedd15eb564b4e588a2b79343e95d371 --- /dev/null +++ b/src/finn/custom_op/__init__.py @@ -0,0 +1,18 @@ +from abc import ABC, abstractmethod + + +class CustomOp(ABC): + def __init__(self): + super().__init__() + + @abstractmethod + def make_shape_compatible_op(self, node): + pass + + @abstractmethod + def infer_node_datatype(self, node, model): + pass + + @abstractmethod + def execute_node(self, node, context, graph): + pass diff --git a/src/finn/custom_op/multithreshold.py b/src/finn/custom_op/multithreshold.py new file mode 100644 index 0000000000000000000000000000000000000000..ec8f1ad4ceded8a40e182af5d79ccd803e0b8c33 --- /dev/null +++ b/src/finn/custom_op/multithreshold.py @@ -0,0 +1,91 @@ +import numpy as np +import onnx.helper as helper + +from finn.core.datatype import DataType +from finn.core.utils import get_by_name +from finn.custom_op import CustomOp + + +class MultiThreshold(CustomOp): + def make_shape_compatible_op(self, node): + return helper.make_node("Relu", [node.input[0]], [node.output[0]]) + + def infer_node_datatype(self, node, model): + try: + odt = get_by_name(node.attribute, "out_dtype").s.decode("utf-8") + model.set_tensor_datatype(node.output[0], DataType[odt]) + except AttributeError: + # number of thresholds decides # output bits + # use get_smallest_possible, assuming unsigned + n_thres = model.get_tensor_shape(node.input[1])[1] + odtype = DataType.get_smallest_possible(n_thres) + model.set_tensor_datatype(node.output[0], odtype) + + def execute_node(self, node, context, graph): + # save inputs + v = context[node.input[0]] + thresholds = context[node.input[1]] + # retrieve attributes if output scaling is used + try: + out_scale = get_by_name(node.attribute, "out_scale").f + except AttributeError: + out_scale = None + try: + out_bias = get_by_name(node.attribute, "out_bias").f + except AttributeError: + out_bias = None + # calculate output + output = self._execute(v, thresholds, out_scale, out_bias) + # setting context according to output + context[node.output[0]] = output + + def _compare(self, x, y): + if x >= y: + return 1.0 + else: + return 0.0 + + def _execute(self, v, thresholds, out_scale=None, out_bias=None): + # the inputs are expected to be in the shape (N,C,H,W) + # N : Batch size + # C : Number of channels + # H : Heigth of the input images + # W : Width of the input images + # + # the thresholds are expected to be in the shape (C, B) + # C : Number of channels (must be the same value as C in input tensor + # or 1 if all channels use the same threshold value) + # B : Desired activation steps => i.e. for 4-bit activation, + # B=7 (2^(n)-1 and n=4) + # the output tensor will be scaled by out_scale and biased by out_bias + # assert threshold shape + is_global_threshold = thresholds.shape[0] == 1 + assert (v.shape[1] == thresholds.shape[0]) or is_global_threshold + # save the required shape sizes for the loops (N, C and B) + num_batch = v.shape[0] + num_channel = v.shape[1] + num_act = thresholds.shape[1] + # reshape inputs to enable channel-wise reading + vr = v.reshape((v.shape[0], v.shape[1], -1)) + # save the new shape size of the images + num_img_elem = vr.shape[2] + # initiate output tensor + ret = np.zeros_like(vr) + # iterate over thresholds channel-wise + for t in range(num_channel): + channel_thresh = thresholds[0] if is_global_threshold else thresholds[t] + # iterate over batches + for b in range(num_batch): + # iterate over image elements on which the thresholds will be applied + for elem in range(num_img_elem): + # iterate over the different thresholds for one channel + for a in range(num_act): + # apply successive thresholding to every element + ret[b][t][elem] += self._compare( + vr[b][t][elem], channel_thresh[a] + ) + if out_scale is None: + out_scale = 1.0 + if out_bias is None: + out_bias = 0.0 + return out_scale * ret.reshape(v.shape) + out_bias diff --git a/src/finn/custom_op/registry.py b/src/finn/custom_op/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..07bdbf1cab6848dc25b7bb556000829617620747 --- /dev/null +++ b/src/finn/custom_op/registry.py @@ -0,0 +1,8 @@ +# make sure new CustomOp subclasses are imported here so that they get +# registered and plug in correctly into the infrastructure +from finn.custom_op.multithreshold import MultiThreshold + +# create a mapping of all known CustomOp names and classes +custom_op = {} + +custom_op["MultiThreshold"] = MultiThreshold diff --git a/src/finn/transformation/infer_datatypes.py b/src/finn/transformation/infer_datatypes.py index 2c6d75a1d029fbfea3d5b408cc6e8eb5f70dc7b5..60ea43c97d0298969ab4b3a280ed9bd3f62cbab8 100644 --- a/src/finn/transformation/infer_datatypes.py +++ b/src/finn/transformation/infer_datatypes.py @@ -1,5 +1,5 @@ +import finn.custom_op.registry as registry from finn.core.datatype import DataType -from finn.core.utils import get_by_name from finn.transformation import Transformation @@ -8,36 +8,37 @@ def _infer_node_datatype(model, node): changes were made.""" idtypes = list(map(lambda x: model.get_tensor_datatype(x), node.input)) odtypes = list(map(lambda x: model.get_tensor_datatype(x), node.output)) - if node.op_type == "MultiThreshold": + op_type = node.op_type + if node.domain == "finn": + # handle DataType inference for CustomOp try: - odt = get_by_name(node.attribute, "out_dtype").s.decode("utf-8") - model.set_tensor_datatype(node.output[0], DataType[odt]) - except AttributeError: - # number of thresholds decides # output bits - # use get_smallest_possible, assuming unsigned - n_thres = model.get_tensor_shape(node.input[1])[1] - odtype = DataType.get_smallest_possible(n_thres) - model.set_tensor_datatype(node.output[0], odtype) - elif node.op_type == "Sign": - # always produces bipolar outputs - model.set_tensor_datatype(node.output[0], DataType.BIPOLAR) - elif node.op_type == "MatMul": - if len(list(filter(lambda x: x == DataType.FLOAT32, idtypes))) != 0: - # node has at least one float input, output is also float - model.set_tensor_datatype(node.output[0], DataType.FLOAT32) - else: - # TODO compute minimum / maximum result to minimize bitwidth - # use (u)int32 accumulators for now - has_signed_inp = len(list(filter(lambda x: x.signed(), idtypes))) != 0 - if has_signed_inp: - odtype = DataType.INT32 - else: - odtype = DataType.UINT32 - model.set_tensor_datatype(node.output[0], odtype) + # lookup op_type in registry of CustomOps + inst = registry.custom_op[op_type]() + inst.infer_node_datatype(node, model) + except KeyError: + # exception if op_type is not supported + raise Exception("Custom op_type %s is currently not supported." % op_type) else: - # unknown, assume node produces float32 outputs - for o in node.output: - model.set_tensor_datatype(o, DataType.FLOAT32) + if node.op_type == "Sign": + # always produces bipolar outputs + model.set_tensor_datatype(node.output[0], DataType.BIPOLAR) + elif node.op_type == "MatMul": + if len(list(filter(lambda x: x == DataType.FLOAT32, idtypes))) != 0: + # node has at least one float input, output is also float + model.set_tensor_datatype(node.output[0], DataType.FLOAT32) + else: + # TODO compute minimum / maximum result to minimize bitwidth + # use (u)int32 accumulators for now + has_signed_inp = len(list(filter(lambda x: x.signed(), idtypes))) != 0 + if has_signed_inp: + odtype = DataType.INT32 + else: + odtype = DataType.UINT32 + model.set_tensor_datatype(node.output[0], odtype) + else: + # unknown, assume node produces float32 outputs + for o in node.output: + model.set_tensor_datatype(o, DataType.FLOAT32) # compare old and new output dtypes to see if anything changed new_odtypes = list(map(lambda x: model.get_tensor_datatype(x), node.output)) graph_modified = new_odtypes != odtypes diff --git a/src/finn/transformation/infer_shapes.py b/src/finn/transformation/infer_shapes.py index 3e79587d4d6b9d021ddd11b3edb11b42e73c1297..2772ae38524f8339d7b66825ff05656aae4f606d 100644 --- a/src/finn/transformation/infer_shapes.py +++ b/src/finn/transformation/infer_shapes.py @@ -1,6 +1,6 @@ -import onnx.helper as helper import onnx.shape_inference as si +import finn.custom_op.registry as registry from finn.core.modelwrapper import ModelWrapper from finn.transformation import Transformation @@ -9,10 +9,14 @@ def _make_shape_compatible_op(node): """Return a shape-compatible non-FINN op for a given FINN op. Used for shape inference with custom ops.""" assert node.domain == "finn" - if node.op_type == "MultiThreshold": - return helper.make_node("Relu", [node.input[0]], [node.output[0]]) - else: - raise Exception("No known shape-compatible op for %s" % node.op_type) + op_type = node.op_type + try: + # lookup op_type in registry of CustomOps + inst = registry.custom_op[op_type]() + return inst.make_shape_compatible_op(node) + except KeyError: + # exception if op_type is not supported + raise Exception("Custom op_type %s is currently not supported." % op_type) def _hide_finn_ops(model): diff --git a/tests/test_multi_thresholding.py b/tests/test_multi_thresholding.py index 5fd7b4309d68ad568773dbdb39d0df979a767e73..b16d1602a65284ba366e9c331370c15096981ece 100644 --- a/tests/test_multi_thresholding.py +++ b/tests/test_multi_thresholding.py @@ -1,6 +1,6 @@ import numpy as np -import finn.core.multithreshold as multi_thresh +from finn.custom_op.multithreshold import MultiThreshold def test_execute_multi_thresholding(): @@ -194,10 +194,11 @@ def test_execute_multi_thresholding(): ), ) - results = multi_thresh.execute(inputs, thresholds) + multi_thresh = MultiThreshold() + results = multi_thresh._execute(inputs, thresholds) assert (results == outputs).all() - results_scaled = multi_thresh.execute(inputs, thresholds, 2.0, -1.0) + results_scaled = multi_thresh._execute(inputs, thresholds, 2.0, -1.0) outputs_scaled = 2.0 * outputs - 1.0 assert (results_scaled == outputs_scaled).all()