From 7fa34eb102818aa88fa085a315505cee3c67bb84 Mon Sep 17 00:00:00 2001
From: auphelia <jakobapk@web.de>
Date: Fri, 12 Jun 2020 09:45:57 +0100
Subject: [PATCH] [Transform] Modify InferDataTypes so that if output datatype
 for std onnx nodes is already set, it is not changed

---
 src/finn/transformation/infer_datatypes.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/src/finn/transformation/infer_datatypes.py b/src/finn/transformation/infer_datatypes.py
index 1acd4e3ab..39b7a787b 100644
--- a/src/finn/transformation/infer_datatypes.py
+++ b/src/finn/transformation/infer_datatypes.py
@@ -71,7 +71,13 @@ def _infer_node_datatype(model, node):
         else:
             # unknown, assume node produces float32 outputs
             for o in node.output:
-                model.set_tensor_datatype(o, DataType.FLOAT32)
+                # check if output datatype is already set to a value != FLOAT32
+                odtype = model.get_tensor_datatype(o)
+                if odtype is not None and odtype != DataType.FLOAT32:
+                    # don't change data type
+                    model.set_tensor_datatype(o, odtype)
+                else:
+                    model.set_tensor_datatype(o, DataType.FLOAT32)
     # compare old and new output dtypes to see if anything changed
     new_odtypes = list(map(lambda x: model.get_tensor_datatype(x), node.output))
     graph_modified = new_odtypes != odtypes
-- 
GitLab