From 7fa34eb102818aa88fa085a315505cee3c67bb84 Mon Sep 17 00:00:00 2001 From: auphelia <jakobapk@web.de> Date: Fri, 12 Jun 2020 09:45:57 +0100 Subject: [PATCH] [Transform] Modify InferDataTypes so that if output datatype for std onnx nodes is already set, it is not changed --- src/finn/transformation/infer_datatypes.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/finn/transformation/infer_datatypes.py b/src/finn/transformation/infer_datatypes.py index 1acd4e3ab..39b7a787b 100644 --- a/src/finn/transformation/infer_datatypes.py +++ b/src/finn/transformation/infer_datatypes.py @@ -71,7 +71,13 @@ def _infer_node_datatype(model, node): else: # unknown, assume node produces float32 outputs for o in node.output: - model.set_tensor_datatype(o, DataType.FLOAT32) + # check if output datatype is already set to a value != FLOAT32 + odtype = model.get_tensor_datatype(o) + if odtype is not None and odtype != DataType.FLOAT32: + # don't change data type + model.set_tensor_datatype(o, odtype) + else: + model.set_tensor_datatype(o, DataType.FLOAT32) # compare old and new output dtypes to see if anything changed new_odtypes = list(map(lambda x: model.get_tensor_datatype(x), node.output)) graph_modified = new_odtypes != odtypes -- GitLab