Skip to content
Snippets Groups Projects
Commit 5d90aa66 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Transform] use new custom op infra for datatype inference

parent 0dccb881
No related branches found
No related tags found
No related merge requests found
import finn.custom_op.registry as registry
from finn.core.datatype import DataType
from finn.core.utils import get_by_name
from finn.transformation import Transformation
......@@ -8,36 +8,37 @@ def _infer_node_datatype(model, node):
changes were made."""
idtypes = list(map(lambda x: model.get_tensor_datatype(x), node.input))
odtypes = list(map(lambda x: model.get_tensor_datatype(x), node.output))
if node.op_type == "MultiThreshold":
op_type = node.op_type
if node.domain == "finn":
# handle DataType inference for CustomOp
try:
odt = get_by_name(node.attribute, "out_dtype").s.decode("utf-8")
model.set_tensor_datatype(node.output[0], DataType[odt])
except AttributeError:
# number of thresholds decides # output bits
# use get_smallest_possible, assuming unsigned
n_thres = model.get_tensor_shape(node.input[1])[1]
odtype = DataType.get_smallest_possible(n_thres)
model.set_tensor_datatype(node.output[0], odtype)
elif node.op_type == "Sign":
# always produces bipolar outputs
model.set_tensor_datatype(node.output[0], DataType.BIPOLAR)
elif node.op_type == "MatMul":
if len(list(filter(lambda x: x == DataType.FLOAT32, idtypes))) != 0:
# node has at least one float input, output is also float
model.set_tensor_datatype(node.output[0], DataType.FLOAT32)
else:
# TODO compute minimum / maximum result to minimize bitwidth
# use (u)int32 accumulators for now
has_signed_inp = len(list(filter(lambda x: x.signed(), idtypes))) != 0
if has_signed_inp:
odtype = DataType.INT32
else:
odtype = DataType.UINT32
model.set_tensor_datatype(node.output[0], odtype)
# lookup op_type in registry of CustomOps
inst = registry.custom_op[op_type]()
inst.infer_node_datatype(node, model)
except KeyError:
# exception if op_type is not supported
raise Exception("Custom op_type %s is currently not supported." % op_type)
else:
# unknown, assume node produces float32 outputs
for o in node.output:
model.set_tensor_datatype(o, DataType.FLOAT32)
if node.op_type == "Sign":
# always produces bipolar outputs
model.set_tensor_datatype(node.output[0], DataType.BIPOLAR)
elif node.op_type == "MatMul":
if len(list(filter(lambda x: x == DataType.FLOAT32, idtypes))) != 0:
# node has at least one float input, output is also float
model.set_tensor_datatype(node.output[0], DataType.FLOAT32)
else:
# TODO compute minimum / maximum result to minimize bitwidth
# use (u)int32 accumulators for now
has_signed_inp = len(list(filter(lambda x: x.signed(), idtypes))) != 0
if has_signed_inp:
odtype = DataType.INT32
else:
odtype = DataType.UINT32
model.set_tensor_datatype(node.output[0], odtype)
else:
# unknown, assume node produces float32 outputs
for o in node.output:
model.set_tensor_datatype(o, DataType.FLOAT32)
# compare old and new output dtypes to see if anything changed
new_odtypes = list(map(lambda x: model.get_tensor_datatype(x), node.output))
graph_modified = new_odtypes != odtypes
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment