diff --git a/src/finn/transformation/infer_datatypes.py b/src/finn/transformation/infer_datatypes.py
index 4c4620da472c5d34985be4054c36099bcc6c811d..1acd4e3abe2d77248810cf15c15475e806a3bd32 100644
--- a/src/finn/transformation/infer_datatypes.py
+++ b/src/finn/transformation/infer_datatypes.py
@@ -34,6 +34,7 @@ from finn.transformation import Transformation
 def _infer_node_datatype(model, node):
     """Infer output datatype(s) for a particular node. Returns True if any
     changes were made."""
+    dt_identity_optypes = ["Reshape", "Transpose"]
     idtypes = list(map(lambda x: model.get_tensor_datatype(x), node.input))
     odtypes = list(map(lambda x: model.get_tensor_datatype(x), node.output))
     op_type = node.op_type
@@ -63,6 +64,10 @@ def _infer_node_datatype(model, node):
                 else:
                     odtype = DataType.UINT32
                 model.set_tensor_datatype(node.output[0], odtype)
+        elif node.op_type in dt_identity_optypes:
+            # set output dtype = input dtype
+            idtype = model.get_tensor_datatype(node.input[0])
+            model.set_tensor_datatype(node.output[0], idtype)
         else:
             # unknown, assume node produces float32 outputs
             for o in node.output: