Skip to content
Snippets Groups Projects
Commit 7fa34eb1 authored by auphelia's avatar auphelia
Browse files

[Transform] Modify InferDataTypes so that if output datatype for std onnx...

[Transform] Modify InferDataTypes so that if output datatype for std onnx nodes is already set, it is not changed
parent 22917141
No related branches found
No related tags found
No related merge requests found
......@@ -71,7 +71,13 @@ def _infer_node_datatype(model, node):
else:
# unknown, assume node produces float32 outputs
for o in node.output:
model.set_tensor_datatype(o, DataType.FLOAT32)
# check if output datatype is already set to a value != FLOAT32
odtype = model.get_tensor_datatype(o)
if odtype is not None and odtype != DataType.FLOAT32:
# don't change data type
model.set_tensor_datatype(o, odtype)
else:
model.set_tensor_datatype(o, DataType.FLOAT32)
# compare old and new output dtypes to see if anything changed
new_odtypes = list(map(lambda x: model.get_tensor_datatype(x), node.output))
graph_modified = new_odtypes != odtypes
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment