diff --git a/src/finn/transformation/infer_datatypes.py b/src/finn/transformation/infer_datatypes.py
index 1acd4e3abe2d77248810cf15c15475e806a3bd32..39b7a787be8c725e7b6d474757dd96fc4848dfe0 100644
--- a/src/finn/transformation/infer_datatypes.py
+++ b/src/finn/transformation/infer_datatypes.py
@@ -71,7 +71,13 @@ def _infer_node_datatype(model, node):
         else:
             # unknown, assume node produces float32 outputs
             for o in node.output:
-                model.set_tensor_datatype(o, DataType.FLOAT32)
+                # check if output datatype is already set to a value != FLOAT32
+                odtype = model.get_tensor_datatype(o)
+                if odtype is not None and odtype != DataType.FLOAT32:
+                    # don't change data type
+                    model.set_tensor_datatype(o, odtype)
+                else:
+                    model.set_tensor_datatype(o, DataType.FLOAT32)
     # compare old and new output dtypes to see if anything changed
     new_odtypes = list(map(lambda x: model.get_tensor_datatype(x), node.output))
     graph_modified = new_odtypes != odtypes