diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py index 452366607943db77d7ef3764c67b0fdd4b0fbb40..b47f269dd6f2671c3d98c9316954483c0e72f14f 100644 --- a/src/finn/transformation/streamline/reorder.py +++ b/src/finn/transformation/streamline/reorder.py @@ -545,6 +545,7 @@ class MoveScalarLinearPastInvariants(Transformation): # move prod0 from input to output, old_prod0_in = prod0.input[0] old_prod0_out = prod0.output[0] + scalar_op_odt = model.get_tensor_datatype(old_prod0_out) old_n_out = n.output[0] in_shape = model.get_tensor_shape(n.input[0]) out_shape = model.get_tensor_shape(n.output[0]) @@ -555,6 +556,8 @@ class MoveScalarLinearPastInvariants(Transformation): model.set_tensor_shape(n.input[0], in_shape) model.set_tensor_shape(n.output[0], out_shape) model.set_tensor_shape(prod0.output[0], out_shape) + model.set_tensor_datatype(prod0.output[0], scalar_op_odt) + model.set_tensor_datatype(n.output[0], DataType.FLOAT32) graph.node.remove(prod0) graph.node.insert(node_ind - 1, prod0) graph_modified = True