diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py
index 2b03532ce3ba7d5159e5ae57e61c2af9c8c37fce..b47f269dd6f2671c3d98c9316954483c0e72f14f 100644
--- a/src/finn/transformation/streamline/reorder.py
+++ b/src/finn/transformation/streamline/reorder.py
@@ -502,6 +502,73 @@ class MoveLinearPastEltwiseAdd(Transformation):
         return (model, graph_modified)
 
 
+class MoveScalarLinearPastInvariants(Transformation):
+    """Move scalar linear operations (mul, add) past functions which are invariant
+       to them. Specifically, matches and transforms the following patterns:
+       f(x*C) -> f(x) * C
+       f(x+C) -> f(x) + C
+       where x is a dynamic input, C is a constant tensor.
+       Known f which obey this property are: Reshape, Flatten, Transpose,
+       GlobalAveragePool
+    """
+
+    def apply(self, model):
+        graph = model.graph
+        node_ind = 0
+        graph_modified = False
+        nodes = [n for n in graph.node]
+        for n in nodes:
+            node_ind += 1
+            if (
+                n.op_type == "GlobalAveragePool"
+                or n.op_type == "Reshape"
+                or n.op_type == "Transpose"
+                or n.op_type == "Flatten"
+            ):
+                in0 = n.input[0]
+                if in0 is None:
+                    continue
+                # find and check producer on our input
+                prod0 = model.find_producer(in0)
+                if prod0 is None:
+                    continue
+
+                if prod0.op_type == "Mul" or prod0.op_type == "Add":
+                    # check if second input of producer is an initializer
+                    init0 = model.get_initializer(prod0.input[1])
+                    # if either initializer is None, skip
+                    if init0 is None:
+                        continue
+                    # if initializer is not scalar, skip
+                    if np.prod(init0.shape) != 1:
+                        continue
+                    # move prod0 from input to output,
+                    old_prod0_in = prod0.input[0]
+                    old_prod0_out = prod0.output[0]
+                    scalar_op_odt = model.get_tensor_datatype(old_prod0_out)
+                    old_n_out = n.output[0]
+                    in_shape = model.get_tensor_shape(n.input[0])
+                    out_shape = model.get_tensor_shape(n.output[0])
+                    n.input[0] = old_prod0_in
+                    n.output[0] = old_prod0_out
+                    prod0.input[0] = old_prod0_out
+                    prod0.output[0] = old_n_out
+                    model.set_tensor_shape(n.input[0], in_shape)
+                    model.set_tensor_shape(n.output[0], out_shape)
+                    model.set_tensor_shape(prod0.output[0], out_shape)
+                    model.set_tensor_datatype(prod0.output[0], scalar_op_odt)
+                    model.set_tensor_datatype(n.output[0], DataType.FLOAT32)
+                    graph.node.remove(prod0)
+                    graph.node.insert(node_ind - 1, prod0)
+                    graph_modified = True
+                else:
+                    continue
+        if graph_modified:
+            model = model.transform(InferShapes())
+            model = model.transform(InferDataTypes())
+        return (model, graph_modified)
+
+
 class MakeMaxPoolNHWC(Transformation):
     """Convert (MaxPool, NHWCTranpose) into (MaxPoolNHWC)."""
 
@@ -685,6 +752,7 @@ class MoveMaxPoolPastMultiThreshold(Transformation):
         model = model.transform(InferShapes())
         return (model, graph_modified)
 
+
 class MoveFlattenPastTopK(Transformation):
     """Move flatten node past a succeeding topk node, if the "axis" attribute in topk
     is set to -1 and the data layout before the flatten is NHWC with H=W=1"""
@@ -745,6 +813,7 @@ class MoveFlattenPastTopK(Transformation):
         model = model.transform(InferShapes())
         return (model, graph_modified)
 
+
 class MoveFlattenPastAffine(Transformation):
     """Moves a node that implements a (1, -1) reshape past a MatMul, Mul or Add node."""
 
@@ -831,9 +900,10 @@ class MoveFlattenPastAffine(Transformation):
 
         model = model.transform(InferShapes())
         model = model.transform(InferDataTypes())
-        model = model.transform(InferDataLayouts())                  
+        model = model.transform(InferDataLayouts())
         return (model, graph_modified)
-      
+
+
 class MoveTransposePastScalarMul(Transformation):
     """Moves a Transpose node past a scalar Mul node"""
 
@@ -895,4 +965,3 @@ class MoveTransposePastScalarMul(Transformation):
             model = model.transform(InferDataLayouts())
             model = model.transform(InferShapes())
         return (model, graph_modified)
-