From 58b980379f248040fa5e9660f15b758c6ac1e379 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <yamanu@xilinx.com> Date: Fri, 10 Jul 2020 16:13:58 +0100 Subject: [PATCH] [Transform] preserve scalar odt in MoveScalarLinearPastInvariants --- src/finn/transformation/streamline/reorder.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py index 452366607..b47f269dd 100644 --- a/src/finn/transformation/streamline/reorder.py +++ b/src/finn/transformation/streamline/reorder.py @@ -545,6 +545,7 @@ class MoveScalarLinearPastInvariants(Transformation): # move prod0 from input to output, old_prod0_in = prod0.input[0] old_prod0_out = prod0.output[0] + scalar_op_odt = model.get_tensor_datatype(old_prod0_out) old_n_out = n.output[0] in_shape = model.get_tensor_shape(n.input[0]) out_shape = model.get_tensor_shape(n.output[0]) @@ -555,6 +556,8 @@ class MoveScalarLinearPastInvariants(Transformation): model.set_tensor_shape(n.input[0], in_shape) model.set_tensor_shape(n.output[0], out_shape) model.set_tensor_shape(prod0.output[0], out_shape) + model.set_tensor_datatype(prod0.output[0], scalar_op_odt) + model.set_tensor_datatype(n.output[0], DataType.FLOAT32) graph.node.remove(prod0) graph.node.insert(node_ind - 1, prod0) graph_modified = True -- GitLab