From 68c7f4e05291edd438bb026838c5cc823f368aea Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <yamanu@xilinx.com> Date: Thu, 7 Nov 2019 15:46:24 +0000 Subject: [PATCH] [Transform] add a first version of infer_datatypes --- src/finn/transformation/infer_datatypes.py | 47 ++++++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100644 src/finn/transformation/infer_datatypes.py diff --git a/src/finn/transformation/infer_datatypes.py b/src/finn/transformation/infer_datatypes.py new file mode 100644 index 000000000..089d0ff81 --- /dev/null +++ b/src/finn/transformation/infer_datatypes.py @@ -0,0 +1,47 @@ +from finn.core.datatype import DataType + + +def _infer_node_datatype(model, node): + """Infer output datatype(s) for a particular node. Returns True if any + changes were made.""" + idtypes = map(lambda x: model.get_tensor_datatype(x), node.input) + odtypes = map(lambda x: model.get_tensor_datatype(x), node.output) + if node.op_type == "MultiThreshold": + # number of thresholds decides # output buts, use get_smallest_possible + n_thres = model.get_tensor_shape(node.input[1])[1] + odtype = DataType.get_smallest_possible(n_thres) + model.set_tensor_datatype(node.output[0], odtype) + elif node.op_type == "Sign": + # always produces bipolar outputs + model.set_tensor_datatype(node.output[0], DataType.BIPOLAR) + elif node.op_type == "MatMul": + if len(filter(lambda x: x == DataType.FLOAT32, idtypes)) != 0: + # node has at least one float input, output is also float + model.set_tensor_datatype(node.output[0], DataType.FLOAT32) + else: + # TODO compute minimum / maximum result to minimize bitwidth + # use (u)int32 accumulators for now + has_signed_inp = len(filter(lambda x: x.signed(), idtypes)) != 0 + if has_signed_inp: + odtype = DataType.INT32 + else: + odtype = DataType.UINT32 + model.set_tensor_datatype(node.output[0], odtype) + else: + # unknown, assume node produces float32 outputs + for o in node.output: + model.set_tensor_datatype(o, DataType.FLOAT32) + # compare old and new output dtypes to see if anything changed + new_odtypes = map(lambda x: model.get_tensor_datatype(x), node.output) + graph_modified = new_odtypes == odtypes + return graph_modified + + +def infer_datatypes(model): + """Infer FINN DataType info for all intermediate/output tensors based on + inputs and node type.""" + graph = model.graph + graph_modified = False + for node in graph.node: + graph_modified |= _infer_node_datatype(model, node) + return (model, graph_modified) -- GitLab