From 4f363c3d30f39b49c820ab74d1aeee42f82d6a57 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Tue, 24 Mar 2020 13:49:37 +0000
Subject: [PATCH] [DataTypeInf] infer odt=idt for certain node types

---
 src/finn/transformation/infer_datatypes.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/src/finn/transformation/infer_datatypes.py b/src/finn/transformation/infer_datatypes.py
index 4c4620da4..1acd4e3ab 100644
--- a/src/finn/transformation/infer_datatypes.py
+++ b/src/finn/transformation/infer_datatypes.py
@@ -34,6 +34,7 @@ from finn.transformation import Transformation
 def _infer_node_datatype(model, node):
     """Infer output datatype(s) for a particular node. Returns True if any
     changes were made."""
+    dt_identity_optypes = ["Reshape", "Transpose"]
     idtypes = list(map(lambda x: model.get_tensor_datatype(x), node.input))
     odtypes = list(map(lambda x: model.get_tensor_datatype(x), node.output))
     op_type = node.op_type
@@ -63,6 +64,10 @@ def _infer_node_datatype(model, node):
                 else:
                     odtype = DataType.UINT32
                 model.set_tensor_datatype(node.output[0], odtype)
+        elif node.op_type in dt_identity_optypes:
+            # set output dtype = input dtype
+            idtype = model.get_tensor_datatype(node.input[0])
+            model.set_tensor_datatype(node.output[0], idtype)
         else:
             # unknown, assume node produces float32 outputs
             for o in node.output:
-- 
GitLab