Skip to content
Snippets Groups Projects
Commit a3ac2200 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Wrapper] respect init datatype while setting shape

parent cbc0fa2d
No related branches found
No related tags found
No related merge requests found
......@@ -137,9 +137,8 @@ class ModelWrapper:
except ValueError:
return None
def set_tensor_shape(self, tensor_name, tensor_shape):
def set_tensor_shape(self, tensor_name, tensor_shape, dtype=TensorProto.FLOAT):
"""Assign shape in ValueInfoProto for tensor with given name."""
dtype = TensorProto.FLOAT
new_vi = oh.make_tensor_value_info(tensor_name, dtype, tensor_shape)
# find what container tis tensor's ValueInfo lives in
# if not found anywhere, we assume it's a new value_info
......@@ -169,7 +168,8 @@ class ModelWrapper:
# create and insert new initializer
graph.initializer.append(tensor_init_proto)
# set shape
self.set_tensor_shape(tensor_name, list(tensor_value.shape))
dtype = tensor_init_proto.data_type
self.set_tensor_shape(tensor_name, list(tensor_value.shape), dtype)
def rename_tensor(self, old_name, new_name):
"""Rename a tensor from old_name to new_name."""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment