diff --git a/src/finn/transformation/streamline.py b/src/finn/transformation/streamline.py
index 3d025c97c70dbbfb8eb2a4b0ce68277983a37997..39684af25ce02f46b57b3aac053a3e05b2e1ddf0 100644
--- a/src/finn/transformation/streamline.py
+++ b/src/finn/transformation/streamline.py
@@ -1,3 +1,5 @@
+from onnx import helper as oh
+
 import finn.transformation.infer_shapes as si
 
 
@@ -12,16 +14,36 @@ def move_add_past_mul(model):
         node_ind += 1
         if n.op_type == "Add":
             consumer = model.find_consumer(n.output[0])
-            if consumer.op_type == "Mul":
+            if consumer is not None and consumer.op_type == "Mul":
+                # have: (x) -> add(,B) -> (x+B) -> mul(,A) -> (xA+BA)
+                # want: (x) -> mul(,A) -> (xA) -> add(,BA) -> (xA+BA)
                 # assume input 0 is from the previous layer, input 1 is the
                 # trained (constant) parameter
-                mul_param = model.get_initializer(consumer.input[1])
-                add_param = model.get_initializer(n.input[1])
-                assert mul_param is not None
-                assert add_param is not None
-                # TODO compute new param values
-                # TODO make new nodes
-                # TODO mark nodes for removal
+                mul_weight_name = consumer.input[1]
+                add_weight_name = n.input[1]
+                A = model.get_initializer(mul_weight_name)
+                B = model.get_initializer(add_weight_name)
+                assert A is not None
+                assert B is not None
+                start_name = n.input[0]
+                middle_name = n.output[0]
+                end_name = consumer.output[0]
+                # compute new param value for add
+                BA = B * A
+                # make and insert new nodes
+                new_mul = oh.make_node(
+                    "Mul", [start_name, mul_weight_name], [middle_name]
+                )
+                new_add = oh.make_node(
+                    "Add", [middle_name, add_weight_name], [end_name]
+                )
+                graph.node.insert(node_ind, new_mul)
+                graph.node.insert(node_ind + 1, new_add)
+                # replace add value
+                model.set_initializer(add_weight_name, BA)
+                # mark old nodes for removal
+                nodes_to_remove += [n, consumer]
+                graph_modified = True
     # delete marked nodes (batchnorm and (un)squeezing)
     for n in nodes_to_remove:
         graph.node.remove(n)