diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py
index a1bd16f6d0b70193122d5d067ccdee395260c7b1..cc95d34b784b47c9baeb6c1076915db8b1d09d57 100644
--- a/src/finn/transformation/streamline/reorder.py
+++ b/src/finn/transformation/streamline/reorder.py
@@ -32,6 +32,7 @@ from onnx import helper as oh
 
 from finn.transformation import Transformation
 from finn.transformation.infer_shapes import InferShapes
+from finn.transformation.infer_data_layouts import InferDataLayouts
 from finn.core.datatype import DataType
 from finn.core.onnx_exec import execute_node
 from finn.util.basic import get_by_name
@@ -68,8 +69,11 @@ class MoveAddPastMul(Transformation):
                     add_weight_name = n.input[1]
                     A = model.get_initializer(mul_weight_name)
                     B = model.get_initializer(add_weight_name)
-                    assert A is not None, "Initializer for mul weights is not set."
-                    assert B is not None, "Initializer for add weights is not set."
+                    if (A is None) or (B is None):
+                        warnings.warn(
+                            "Mul or add does not have constant params, skipping"
+                        )
+                        continue
                     start_name = n.input[0]
                     middle_name = n.output[0]
                     end_name = consumer.output[0]
@@ -124,8 +128,9 @@ class MoveScalarMulPastMatMul(Transformation):
                     matmul_weight_name = consumer.input[1]
                     A = model.get_initializer(mul_weight_name)
                     W = model.get_initializer(matmul_weight_name)
-                    assert A is not None, "Initializer for mul weights is not set."
-                    assert W is not None, "Initializer for matmul weights is not set."
+                    if (A is None) or (W is None):
+                        warnings.warn("MatMul or Mul params are not constant, skipping")
+                        continue
                     start_name = n.input[0]
                     middle_name = n.output[0]
                     end_name = consumer.output[0]
@@ -181,8 +186,9 @@ class MoveScalarAddPastMatMul(Transformation):
                     matmul_weight_name = consumer.input[1]
                     A = model.get_initializer(add_weight_name)
                     W = model.get_initializer(matmul_weight_name)
-                    assert A is not None, "Initializer for add weights is not set."
-                    assert W is not None, "Initializer for matmul weights is not set."
+                    if (A is None) or (W is None):
+                        warnings.warn("MatMul or Add params are not constant, skipping")
+                        continue
                     start_name = n.input[0]
                     middle_name = n.output[0]
                     end_name = consumer.output[0]
@@ -243,7 +249,9 @@ class MoveScalarAddPastConv(Transformation):
                     conv_in_name = consumer.input[0]
                     conv_in_shape = model.get_tensor_shape(conv_in_name)
                     A = model.get_initializer(add_weight_name)
-                    assert A is not None, "Initializer for add weights is not set."
+                    if A is None:
+                        warnings.warn("Add param is not constant, skipping")
+                        continue
                     start_name = n.input[0]
                     end_name = consumer.output[0]
                     conv_out_shape = model.get_tensor_shape(end_name)
@@ -311,7 +319,9 @@ class MoveScalarMulPastConv(Transformation):
                 ):
                     mul_weight_name = n.input[1]
                     A = model.get_initializer(mul_weight_name)
-                    assert A is not None, "Initializer for mul weights is not set."
+                    if A is None:
+                        warnings.warn("Mul param is not constant, skipping")
+                        continue
                     conv_node = consumer
                     mul_node = n
                     start_name = mul_node.input[0]
@@ -663,3 +673,71 @@
 
         model = model.transform(InferShapes())
         return (model, graph_modified)
+
+
+class MoveTransposePastScalarMul(Transformation):
+    """Moves a Transpose node past a scalar Mul node"""
+
+    def apply(self, model):
+        graph = model.graph
+        node_ind = 0
+        graph_modified = False
+        for n in graph.node:
+            node_ind += 1
+            if (
+                n.op_type == "Transpose"
+                and not model.is_fork_node(n)
+                and not model.is_join_node(n)
+            ):
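+                # look for a Mul directly consuming the Transpose output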
+                consumer = model.find_consumer(n.output[0])
+                if (
+                    consumer is not None
+                    and consumer.op_type == "Mul"
+                    and not model.is_join_node(consumer)
+                ):
+                    mul_weight_name = consumer.input[1]
+                    A = model.get_initializer(mul_weight_name)
+                    if A is None:
+                        warnings.warn("Mul param is not constant, skipping")
+                        continue
+                    transp_node = n
+                    mul_node = consumer
+                    start_name = transp_node.input[0]
+                    middle_name = transp_node.output[0]
+                    end_name = mul_node.output[0]
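+                    # record shapes and layouts so they can be reassigned after the swap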
+                    transp_in_shape = model.get_tensor_shape(start_name)
+                    transp_out_shape = model.get_tensor_shape(middle_name)
+                    transp_in_layout = model.get_tensor_layout(start_name)
+                    transp_out_layout = model.get_tensor_layout(middle_name)
+                    if transp_in_layout is None or transp_out_layout is None:
+                        warnings.warn(
+                            """Datalayout is not set for tensors.
+                            Transformation can't be applied."""
+                        )
+                        continue
+                    if all(x == 1 for x in A.shape):
+                        # if the mul is scalar, we can simply swap the order of ops
+                        # rewire transpose input to be mul input
+                        mul_node.input[0] = start_name
+                        model.set_tensor_shape(start_name, transp_in_shape)
+                        model.set_tensor_layout(start_name, transp_in_layout)
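+                        # the middle tensor becomes the Mul output, keeping the pre-transpose shape and layout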
+                        mul_node.output[0] = middle_name
+                        model.set_tensor_shape(middle_name, transp_in_shape)
+                        model.set_tensor_layout(middle_name, transp_in_layout)
+                        transp_node.input[0] = middle_name
+                        transp_node.output[0] = end_name
+                        model.set_tensor_shape(end_name, transp_out_shape)
+                        model.set_tensor_layout(end_name, transp_out_layout)
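+                        # move the Transpose down the node list; it now consumes the Mul output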
+                        graph.node.remove(transp_node)
+                        graph.node.insert(node_ind, transp_node)
+                        graph_modified = True
+
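+        # tensor shapes and layouts changed, so re-run the inference passes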
+        if graph_modified:
+            model = model.transform(InferDataLayouts())
+            model = model.transform(InferShapes())
+        return (model, graph_modified)
diff --git a/tests/transformation/test_move_transpose_past_scalar_mul.py b/tests/transformation/test_move_transpose_past_scalar_mul.py
new file mode 100644
index 0000000000000000000000000000000000000000..e434fc7d4f683120176e18a2bfa9da99d9ee0b0e
--- /dev/null
+++ b/tests/transformation/test_move_transpose_past_scalar_mul.py
@@ -0,0 +1,85 @@
+import pytest
+
+import numpy as np
+from onnx import TensorProto, helper
+
+from finn.core.modelwrapper import ModelWrapper
+import finn.core.data_layout as DataLayout
+from finn.transformation.infer_shapes import InferShapes
+from finn.transformation.infer_datatypes import InferDataTypes
+from finn.transformation.infer_data_layouts import InferDataLayouts
+from finn.transformation.general import GiveUniqueNodeNames, GiveReadableTensorNames
+from finn.transformation.streamline.reorder import MoveTransposePastScalarMul
+import finn.core.onnx_exec as oxe
+
+# permutation of transpose node
+@pytest.mark.parametrize("perm", [[0, 2, 3, 1], [0, 1, 3, 2], [3, 2, 0, 1]])
+# scalar mul
+@pytest.mark.parametrize("scalar", [True, False])
+# data layout
+@pytest.mark.parametrize("data_layout", [None, DataLayout.NHWC, DataLayout.NCHW])
+def test_move_transpose_past_scalar_mul(perm, scalar, data_layout):
+    inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, [1, 2, 3, 4])
+    # the expected output shape is the input shape permuted by "perm"
+    dummy_in = np.random.uniform(low=0, high=1, size=(1, 2, 3, 4)).astype(np.float32)
+    out_size = dummy_in.transpose(tuple(perm)).shape
+
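+    # a scalar Mul param has an empty shape; a non-scalar one gets the full output shape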
+    if scalar:
+        a0_size = []
+    else:
+        a0_size = out_size
+    a0 = helper.make_tensor_value_info("a0", TensorProto.FLOAT, a0_size)
+    outp = helper.make_tensor_value_info("outp", TensorProto.FLOAT, out_size)
+    transp_node = helper.make_node("Transpose", ["inp"], ["transp_out"], perm=perm)
+    mul_node = helper.make_node("Mul", ["transp_out", "a0"], ["outp"])
+
+    graph = helper.make_graph(
+        nodes=[transp_node, mul_node],
+        name="mv-transpose-graph",
+        inputs=[inp],
+        outputs=[outp],
+        value_info=[a0],
+    )
+
+    model = helper.make_model(graph, producer_name="mv_transpose_model")
+    model = ModelWrapper(model)
+
+    # initialize values
+    a0_values = np.random.uniform(low=0, high=1, size=tuple(a0_size)).astype(np.float32)
+    model.set_initializer("a0", a0_values)
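+    # annotate the input layout so InferDataLayouts can propagate it to all tensors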
+    if data_layout is not None:
+        model.set_tensor_layout("inp", data_layout)
+        model = model.transform(InferDataLayouts())
+
+    model = model.transform(InferShapes())
+    model = model.transform(InferDataTypes())
+    model = model.transform(GiveUniqueNodeNames())
+    model = model.transform(GiveReadableTensorNames())
+
+    # compare execution before and after transformation
+    inp_values = np.random.uniform(low=0, high=1, size=(1, 2, 3, 4)).astype(np.float32)
+    idict = {model.graph.input[0].name: inp_values}
+    model_transformed = model.transform(MoveTransposePastScalarMul())
+    assert oxe.compare_execution(model, model_transformed, idict)
+
+    # check if order changed
+    if scalar and data_layout is not None:
+        assert model_transformed.graph.node[0] != model.graph.node[0]
+        assert model_transformed.graph.node[1] != model.graph.node[1]
+        assert model_transformed.graph.node[0].op_type == "Mul"
+        assert model_transformed.graph.node[1].op_type == "Transpose"
+        mul_input = model_transformed.graph.node[0].input[0]
+        mul_output = model_transformed.graph.node[0].output[0]
+        assert model_transformed.get_tensor_layout(mul_input) == data_layout
+        assert model_transformed.get_tensor_layout(mul_output) == data_layout
+    else:
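+        # non-scalar Mul param or unknown layout: the graph must be left unchanged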
+        assert model_transformed.graph.node[0] == model.graph.node[0]
+        assert model_transformed.graph.node[1] == model.graph.node[1]
+        if data_layout is not None:
+            mul_input = model_transformed.graph.node[1].input[0]
+            mul_output = model_transformed.graph.node[1].output[0]
+            assert model_transformed.get_tensor_layout(mul_input) != data_layout
+            assert model_transformed.get_tensor_layout(mul_output) != data_layout