diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py
index 62301cee54adb7c20af901cd75ee366600c50f80..748c7420a8abf0047f1b2bd059d8ca6622e12bce 100644
--- a/src/finn/transformation/streamline/reorder.py
+++ b/src/finn/transformation/streamline/reorder.py
@@ -32,6 +32,7 @@ from onnx import helper as oh
 
 from finn.transformation import Transformation
 from finn.transformation.infer_shapes import InferShapes
+from finn.transformation.infer_data_layouts import InferDataLayouts
 from finn.core.onnx_exec import execute_node
 from finn.util.basic import get_by_name
 from finn.custom_op.registry import getCustomOp
@@ -68,7 +69,9 @@ class MoveAddPastMul(Transformation):
                     A = model.get_initializer(mul_weight_name)
                     B = model.get_initializer(add_weight_name)
                     if (A is None) or (B is None):
-                        warnings.warn("Mul or add does not have constant params, skipping")
+                        warnings.warn(
+                            "Mul or add does not have constant params, skipping"
+                        )
                         continue
                     start_name = n.input[0]
                     middle_name = n.output[0]
@@ -638,18 +641,24 @@ class MoveTransposePastScalarMul(Transformation):
                     end_name = mul_node.output[0]
                     transp_in_shape = model.get_tensor_shape(start_name)
                     transp_out_shape = model.get_tensor_shape(middle_name)
+                    transp_in_layout = model.get_tensor_layout(start_name)
+                    transp_out_layout = model.get_tensor_layout(middle_name)
                     if all(x == 1 for x in A.shape):
                         # if the mul is scalar, we can simply swap the order of ops
                         # rewire transpose input to be mul input
                         mul_node.input[0] = start_name
                         model.set_tensor_shape(start_name, transp_in_shape)
+                        model.set_tensor_layout(start_name, transp_in_layout)
                         mul_node.output[0] = middle_name
                         model.set_tensor_shape(middle_name, transp_in_shape)
+                        model.set_tensor_layout(middle_name, transp_in_layout)
                         transp_node.input[0] = middle_name
                         transp_node.output[0] = end_name
                         model.set_tensor_shape(end_name, transp_out_shape)
+                        model.set_tensor_layout(end_name, transp_out_layout)
                         graph.node.remove(transp_node)
                         graph.node.insert(node_ind, transp_node)
                         graph_modified = True
+        model = model.transform(InferDataLayouts())
         model = model.transform(InferShapes())
         return (model, graph_modified)