diff --git a/src/finn/transformation/streamline.py b/src/finn/transformation/streamline.py
index 23f844cd9d794d6b9aa4d8fb926f2f812791a9ba..7d90df60382791435d310d7683a0e5e3f33ef63c 100644
--- a/src/finn/transformation/streamline.py
+++ b/src/finn/transformation/streamline.py
@@ -1,3 +1,4 @@
+import numpy as np
 from onnx import helper as oh
 
 import finn.transformation.infer_shapes as si
@@ -114,10 +115,6 @@ def move_scalar_mul_past_matmul(model):
         if n.op_type == "Mul":
             consumer = model.find_consumer(n.output[0])
             if consumer is not None and consumer.op_type == "MatMul":
-                # have: (x) -> add(,B) -> (x+B) -> mul(,A) -> (xA+BA)
-                # want: (x) -> mul(,A) -> (xA) -> add(,BA) -> (xA+BA)
-                # assume input 0 is from the previous layer, input 1 is the
-                # trained (constant) parameter
                 mul_weight_name = n.input[1]
                 matmul_weight_name = consumer.input[1]
                 A = model.get_initializer(mul_weight_name)
@@ -147,3 +144,45 @@ def move_scalar_mul_past_matmul(model):
         graph_modified = True
     model = model.transform_single(si.infer_shapes)
     return (model, graph_modified)
+
+
+def move_scalar_add_past_matmul(model):
+    """Move scalar add operations past matmul operations. We want to have adds
+    next to each other such that they can be collapsed into a single add."""
+    graph = model.graph
+    node_ind = 0
+    graph_modified = False
+    for n in graph.node:
+        node_ind += 1
+        if n.op_type == "Add":
+            consumer = model.find_consumer(n.output[0])
+            if consumer is not None and consumer.op_type == "MatMul":
+                add_weight_name = n.input[1]
+                matmul_weight_name = consumer.input[1]
+                A = model.get_initializer(add_weight_name)
+                W = model.get_initializer(matmul_weight_name)
+                assert A is not None
+                assert W is not None
+                start_name = n.input[0]
+                middle_name = n.output[0]
+                end_name = consumer.output[0]
+                if all(x == 1 for x in A.shape):
+                    # if the add is scalar, we can move it past the matmul
+                    # by taking it past the matmul with a dot product
+                    Anew = np.dot(A * np.ones(W.shape[0], dtype=np.float32), W)
+                    # update the add weight
+                    model.set_initializer(add_weight_name, Anew)
+                    new_matmul = oh.make_node(
+                        "MatMul", [start_name, matmul_weight_name], [middle_name]
+                    )
+                    new_add = oh.make_node(
+                        "Add", [middle_name, add_weight_name], [end_name]
+                    )
+                    graph.node.insert(node_ind, new_matmul)
+                    graph.node.insert(node_ind + 1, new_add)
+                    # remove old nodes
+                    graph.node.remove(n)
+                    graph.node.remove(consumer)
+                    graph_modified = True
+    model = model.transform_single(si.infer_shapes)
+    return (model, graph_modified)