From 58b980379f248040fa5e9660f15b758c6ac1e379 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <yamanu@xilinx.com>
Date: Fri, 10 Jul 2020 16:13:58 +0100
Subject: [PATCH] [Transform] preserve scalar odt in
 MoveScalarLinearPastInvariants

---
 src/finn/transformation/streamline/reorder.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py
index 452366607..b47f269dd 100644
--- a/src/finn/transformation/streamline/reorder.py
+++ b/src/finn/transformation/streamline/reorder.py
@@ -545,6 +545,7 @@ class MoveScalarLinearPastInvariants(Transformation):
                     # move prod0 from input to output,
                     old_prod0_in = prod0.input[0]
                     old_prod0_out = prod0.output[0]
+                    scalar_op_odt = model.get_tensor_datatype(old_prod0_out)
                     old_n_out = n.output[0]
                     in_shape = model.get_tensor_shape(n.input[0])
                     out_shape = model.get_tensor_shape(n.output[0])
@@ -555,6 +556,8 @@ class MoveScalarLinearPastInvariants(Transformation):
                     model.set_tensor_shape(n.input[0], in_shape)
                     model.set_tensor_shape(n.output[0], out_shape)
                     model.set_tensor_shape(prod0.output[0], out_shape)
+                    model.set_tensor_datatype(prod0.output[0], scalar_op_odt)
+                    model.set_tensor_datatype(n.output[0], DataType.FLOAT32)
                     graph.node.remove(prod0)
                     graph.node.insert(node_ind - 1, prod0)
                     graph_modified = True
-- 
GitLab