diff --git a/src/finn/transformation/infer_datatypes.py b/src/finn/transformation/infer_datatypes.py index 4c4620da472c5d34985be4054c36099bcc6c811d..1acd4e3abe2d77248810cf15c15475e806a3bd32 100644 --- a/src/finn/transformation/infer_datatypes.py +++ b/src/finn/transformation/infer_datatypes.py @@ -34,6 +34,7 @@ from finn.transformation import Transformation def _infer_node_datatype(model, node): """Infer output datatype(s) for a particular node. Returns True if any changes were made.""" + dt_identity_optypes = ["Reshape", "Transpose"] idtypes = list(map(lambda x: model.get_tensor_datatype(x), node.input)) odtypes = list(map(lambda x: model.get_tensor_datatype(x), node.output)) op_type = node.op_type @@ -63,6 +64,10 @@ def _infer_node_datatype(model, node): else: odtype = DataType.UINT32 model.set_tensor_datatype(node.output[0], odtype) + elif node.op_type in dt_identity_optypes: + # set output dtype = input dtype + idtype = model.get_tensor_datatype(node.input[0]) + model.set_tensor_datatype(node.output[0], idtype) else: # unknown, assume node produces float32 outputs for o in node.output: