Skip to content
Snippets Groups Projects
Unverified Commit 58b98037 authored by Yaman Umuroglu's avatar Yaman Umuroglu Committed by GitHub
Browse files

[Transform] preserve scalar odt in MoveScalarLinearPastInvariants

parent 73087901
No related branches found
No related tags found
No related merge requests found
......@@ -545,6 +545,7 @@ class MoveScalarLinearPastInvariants(Transformation):
# move prod0 from input to output,
old_prod0_in = prod0.input[0]
old_prod0_out = prod0.output[0]
scalar_op_odt = model.get_tensor_datatype(old_prod0_out)
old_n_out = n.output[0]
in_shape = model.get_tensor_shape(n.input[0])
out_shape = model.get_tensor_shape(n.output[0])
......@@ -555,6 +556,8 @@ class MoveScalarLinearPastInvariants(Transformation):
model.set_tensor_shape(n.input[0], in_shape)
model.set_tensor_shape(n.output[0], out_shape)
model.set_tensor_shape(prod0.output[0], out_shape)
model.set_tensor_datatype(prod0.output[0], scalar_op_odt)
model.set_tensor_datatype(n.output[0], DataType.FLOAT32)
graph.node.remove(prod0)
graph.node.insert(node_ind - 1, prod0)
graph_modified = True
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment