diff --git a/src/finn/transformation/infer_datatypes.py b/src/finn/transformation/infer_datatypes.py
index 2c6d75a1d029fbfea3d5b408cc6e8eb5f70dc7b5..60ea43c97d0298969ab4b3a280ed9bd3f62cbab8 100644
--- a/src/finn/transformation/infer_datatypes.py
+++ b/src/finn/transformation/infer_datatypes.py
@@ -1,5 +1,5 @@
+import finn.custom_op.registry as registry
 from finn.core.datatype import DataType
-from finn.core.utils import get_by_name
 from finn.transformation import Transformation
 
 
@@ -8,36 +8,37 @@ def _infer_node_datatype(model, node):
     changes were made."""
     idtypes = list(map(lambda x: model.get_tensor_datatype(x), node.input))
     odtypes = list(map(lambda x: model.get_tensor_datatype(x), node.output))
-    if node.op_type == "MultiThreshold":
+    op_type = node.op_type
+    if node.domain == "finn":
+        # handle DataType inference for CustomOp
         try:
-            odt = get_by_name(node.attribute, "out_dtype").s.decode("utf-8")
-            model.set_tensor_datatype(node.output[0], DataType[odt])
-        except AttributeError:
-            # number of thresholds decides # output bits
-            # use get_smallest_possible, assuming unsigned
-            n_thres = model.get_tensor_shape(node.input[1])[1]
-            odtype = DataType.get_smallest_possible(n_thres)
-            model.set_tensor_datatype(node.output[0], odtype)
-    elif node.op_type == "Sign":
-        # always produces bipolar outputs
-        model.set_tensor_datatype(node.output[0], DataType.BIPOLAR)
-    elif node.op_type == "MatMul":
-        if len(list(filter(lambda x: x == DataType.FLOAT32, idtypes))) != 0:
-            # node has at least one float input, output is also float
-            model.set_tensor_datatype(node.output[0], DataType.FLOAT32)
-        else:
-            # TODO compute minimum / maximum result to minimize bitwidth
-            # use (u)int32 accumulators for now
-            has_signed_inp = len(list(filter(lambda x: x.signed(), idtypes))) != 0
-            if has_signed_inp:
-                odtype = DataType.INT32
-            else:
-                odtype = DataType.UINT32
-            model.set_tensor_datatype(node.output[0], odtype)
+            # lookup op_type in registry of CustomOps
+            inst = registry.custom_op[op_type]()
+            inst.infer_node_datatype(node, model)
+        except KeyError:
+            # exception if op_type is not supported
+            raise Exception("Custom op_type %s is currently not supported." % op_type)
     else:
-        # unknown, assume node produces float32 outputs
-        for o in node.output:
-            model.set_tensor_datatype(o, DataType.FLOAT32)
+        if node.op_type == "Sign":
+            # always produces bipolar outputs
+            model.set_tensor_datatype(node.output[0], DataType.BIPOLAR)
+        elif node.op_type == "MatMul":
+            if len(list(filter(lambda x: x == DataType.FLOAT32, idtypes))) != 0:
+                # node has at least one float input, output is also float
+                model.set_tensor_datatype(node.output[0], DataType.FLOAT32)
+            else:
+                # TODO compute minimum / maximum result to minimize bitwidth
+                # use (u)int32 accumulators for now
+                has_signed_inp = len(list(filter(lambda x: x.signed(), idtypes))) != 0
+                if has_signed_inp:
+                    odtype = DataType.INT32
+                else:
+                    odtype = DataType.UINT32
+                model.set_tensor_datatype(node.output[0], odtype)
+        else:
+            # unknown, assume node produces float32 outputs
+            for o in node.output:
+                model.set_tensor_datatype(o, DataType.FLOAT32)
     # compare old and new output dtypes to see if anything changed
     new_odtypes = list(map(lambda x: model.get_tensor_datatype(x), node.output))
     graph_modified = new_odtypes != odtypes
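
Note: with this change, per-op DataType inference for CustomOps moves out of
this transformation and into the CustomOp implementations themselves, looked
up via finn.custom_op.registry. As a hedged sketch (not part of this diff),
the removed MultiThreshold logic is expected to end up in a method matching
the call site above, inst.infer_node_datatype(node, model); the class body
below is illustrative only, not the actual finn.custom_op code:

    from finn.core.datatype import DataType
    from finn.core.utils import get_by_name


    class MultiThreshold:
        """Illustrative CustomOp sketch; the real class lives in finn.custom_op."""

        def infer_node_datatype(self, node, model):
            try:
                # if the node carries an explicit out_dtype attribute, use it
                odt = get_by_name(node.attribute, "out_dtype").s.decode("utf-8")
                model.set_tensor_datatype(node.output[0], DataType[odt])
            except AttributeError:
                # otherwise the number of thresholds decides the output
                # bitwidth; use get_smallest_possible, assuming unsigned
                n_thres = model.get_tensor_shape(node.input[1])[1]
                odtype = DataType.get_smallest_possible(n_thres)
                model.set_tensor_datatype(node.output[0], odtype)

Registering such a class under its op_type string (e.g.
registry.custom_op["MultiThreshold"] = MultiThreshold, assuming the registry
is a plain mapping as the lookup in the diff suggests) makes it reachable from
the domain == "finn" branch; unregistered op_types fall through to the
KeyError handler and raise the "currently not supported" exception.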