From 4f363c3d30f39b49c820ab74d1aeee42f82d6a57 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <maltanar@gmail.com> Date: Tue, 24 Mar 2020 13:49:37 +0000 Subject: [PATCH] [DataTypeInf] infer odt=idt for certain node types --- src/finn/transformation/infer_datatypes.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/finn/transformation/infer_datatypes.py b/src/finn/transformation/infer_datatypes.py index 4c4620da4..1acd4e3ab 100644 --- a/src/finn/transformation/infer_datatypes.py +++ b/src/finn/transformation/infer_datatypes.py @@ -34,6 +34,7 @@ from finn.transformation import Transformation def _infer_node_datatype(model, node): """Infer output datatype(s) for a particular node. Returns True if any changes were made.""" + dt_identity_optypes = ["Reshape", "Transpose"] idtypes = list(map(lambda x: model.get_tensor_datatype(x), node.input)) odtypes = list(map(lambda x: model.get_tensor_datatype(x), node.output)) op_type = node.op_type @@ -63,6 +64,10 @@ def _infer_node_datatype(model, node): else: odtype = DataType.UINT32 model.set_tensor_datatype(node.output[0], odtype) + elif node.op_type in dt_identity_optypes: + # set output dtype = input dtype + idtype = model.get_tensor_datatype(node.input[0]) + model.set_tensor_datatype(node.output[0], idtype) else: # unknown, assume node produces float32 outputs for o in node.output: -- GitLab