diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py
index 1fd9ce5108bbe9c317f180680febfc088072b98c..96046602efb32a9262a4cf0bbb21a8367d719910 100644
--- a/src/finn/transformation/streamline/reorder.py
+++ b/src/finn/transformation/streamline/reorder.py
@@ -34,6 +34,8 @@ from finn.transformation.infer_shapes import InferShapes
 from finn.core.onnx_exec import execute_node
 from finn.util.basic import get_by_name
 
+def is_scalar(x):
+    return np.prod(x.shape) == 1
 
 class MoveAddPastMul(Transformation):
     """Move add operations past multiply operations. The aim is to have them
@@ -271,6 +273,83 @@ class MoveScalarMulPastConv(Transformation):
         return (model, graph_modified)
 
 
+class MoveScalarLinearPastEltwiseAdd(Transformation):
+    """Move scalar linear operations (mul, add) past elementwise add operations where possible. Specifically,
+       matches and transforms the following patterns:
+       (x*C) + (y*C) -> (x + y) * C
+       (x+A) + (y+B) -> (x + y) + (A + B)
+       where x and y are dynamic inputs, A, B, C are constants.
+    """
+
+    def move_node(self, graph, n, prod0, prod1, node_ind):
+        # found! move one of the muls to output, remove the other one
+        lin0_in0 = prod0.input[0]
+        lin1_in0 = prod1.input[0]
+        in0 = n.input[0]
+        out = n.output[0]
+        # TODO: check shapes don't change through scalar mul or add
+        # connect the eltwise add inputs to mul inputs
+        n.input[0] = lin0_in0
+        n.input[1] = lin1_in0
+        # connect mul0 output to eltwise add output
+        prod0.output[0] = out
+        # connect the input of mul0 and output of eltwise add together
+        n.output[0] = in0
+        prod0.input[0] = in0
+        # move prod0 node past eltwise add node, and remove prod1
+        graph.node.remove(prod1)
+        graph.node.remove(prod0)
+        graph.node.insert(node_ind - 2, prod0)
+
+    def apply(self, model):
+        graph = model.graph
+        node_ind = 0
+        graph_modified = False
+        for n in graph.node:
+            node_ind += 1
+            if n.op_type == "Add":
+                # check for tensors on both inputs (eltwise add)
+                # scalar add has an initializer on one input
+                in0 = n.input[0]
+                in1 = n.input[1]
+                if in0 is None or in1 is None:
+                    continue
+                A = model.get_initializer(in0)
+                B = model.get_initializer(in1)
+                if A is not None or B is not None:
+                    continue
+                # check for mul with same initializer on both inputs
+                prod0 = model.find_producer(in0)
+                prod1 = model.find_producer(in1)
+                if prod0 is None or prod1 is None:
+                    continue
+                init0 = model.get_initializer(prod0.input[1])
+                init1 = model.get_initializer(prod1.input[1])
+                # if either initializer is None, skip
+                if init0 is None or init1 is None:
+                    continue
+                # if either initializer is non-scalar, skip
+                # TODO relax this to 1D tensors?
+                if (not is_scalar(init0)) or (not is_scalar(init1)):
+                    continue
+                if prod0.op_type == "Mul" and prod1.op_type == "Mul":
+                    if np.array_equal(init0, init1):
+                        self.move_node(graph, n, prod0, prod1, node_ind)
+                        node_ind -= 1
+                        graph_modified = True
+                elif prod0.op_type == "Add" and prod1.op_type == "Add":
+                    init = init0 + init1
+                    # update initializer of prod0, which we'll move
+                    model.set_initializer(prod0.input[1], init)
+                    self.move_node(graph, n, prod0, prod1, node_ind)
+                    node_ind -= 1
+                    graph_modified = True
+                else:
+                    continue
+        model = model.transform(InferShapes())
+        return (model, graph_modified)
+
+
 class MakeMaxPoolNHWC(Transformation):
     """Convert (MaxPool, NHWCTranpose) into (MaxPoolNHWC)."""
 
diff --git a/tests/transformation/test_scalar_past_eltwise.py b/tests/transformation/test_scalar_past_eltwise.py
new file mode 100644
index 0000000000000000000000000000000000000000..e845f32176a9293046b297b7d9e2ab64fabc1791
--- /dev/null
+++ b/tests/transformation/test_scalar_past_eltwise.py
@@ -0,0 +1,136 @@
+# Copyright (c) 2020, Xilinx
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of FINN nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import os
+import numpy as np
+
+from onnx import TensorProto, helper
+
+import finn.core.onnx_exec as oxe
+from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.fold_constants import FoldConstants
+from finn.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames
+from finn.transformation.streamline.reorder import MoveScalarLinearPastEltwiseAdd
+from finn.transformation.infer_shapes import InferShapes
+from finn.transformation.double_to_single_float import DoubleToSingleFloat
+
+import pytest
+
+export_onnx_path = "test_scalar_past_eltwise.onnx"
+
+# construct a synthetic graph to test:
+# topk insertion, topk conversion to hls, add conversion to hls
+# graph should just be a sum
+
+
+def make_model(shape):
+    inp1 = helper.make_tensor_value_info("inp1", TensorProto.FLOAT, shape)
+    inp2 = helper.make_tensor_value_info("inp2", TensorProto.FLOAT, shape)
+    inp1_add = helper.make_tensor_value_info("inp1_add", TensorProto.FLOAT, shape)
+    inp1_add_ct = helper.make_tensor_value_info("inp1_add_ct", TensorProto.FLOAT, [1])
+    inp2_add = helper.make_tensor_value_info("inp2_add", TensorProto.FLOAT, shape)
+    inp2_add_ct = helper.make_tensor_value_info("inp2_add_ct", TensorProto.FLOAT, [1])
+    inp1_mul = helper.make_tensor_value_info("inp1_mul", TensorProto.FLOAT, shape)
+    inp1_mul_ct = helper.make_tensor_value_info("inp1_mul_ct", TensorProto.FLOAT, [1])
+    inp2_mul = helper.make_tensor_value_info("inp2_mul", TensorProto.FLOAT, shape)
+    inp2_mul_ct = helper.make_tensor_value_info("inp2_mul_ct", TensorProto.FLOAT, [1])
+    outp = helper.make_tensor_value_info("outp", TensorProto.FLOAT, shape)
+
+    add1_node = helper.make_node("Add", [inp1.name, inp1_add_ct.name], [inp1_add.name])
+    add2_node = helper.make_node("Add", [inp2.name, inp2_add_ct.name], [inp2_add.name])
+    mul1_node = helper.make_node(
+        "Mul", [inp1_add.name, inp1_mul_ct.name], [inp1_mul.name]
+    )
+    mul2_node = helper.make_node(
+        "Mul", [inp2_add.name, inp2_mul_ct.name], [inp2_mul.name]
+    )
+    eltwise_add_node = helper.make_node(
+        "Add", [inp1_mul.name, inp2_mul.name], [outp.name]
+    )
+    graph = helper.make_graph(
+        nodes=[add1_node, add2_node, mul1_node, mul2_node, eltwise_add_node],
+        name="graph",
+        inputs=[inp1, inp2],
+        outputs=[outp],
+    )
+
+    model = helper.make_model(graph, producer_name="add-model")
+    model = ModelWrapper(model)
+
+    # set initializers for scalar add/mul nodes
+    model.set_initializer(add1_node.input[1], np.array([7.0]))
+    model.set_initializer(add2_node.input[1], np.array([8.0]))
+    model.set_initializer(mul1_node.input[1], np.array([3.0]))
+    model.set_initializer(mul2_node.input[1], np.array([3.0]))
+
+    return model
+
+
+# channels
+@pytest.mark.parametrize("ch", [64])
+# ifmdim
+@pytest.mark.parametrize("ifmdim", [-1, 7])
+def test_scalar_past_eltwise(ch, ifmdim):
+    # generate test vectors of correct shape
+    if ifmdim == -1:
+        input_tensor_shape = (1, ch)
+    else:
+        input_tensor_shape = (1, ch, ifmdim, ifmdim)
+
+    model = make_model(input_tensor_shape)
+    model.save(export_onnx_path)
+    model = ModelWrapper(export_onnx_path)
+    model = model.transform(DoubleToSingleFloat())
+    model = model.transform(InferShapes())
+    model = model.transform(FoldConstants())
+    model = model.transform(GiveUniqueNodeNames())
+    model = model.transform(GiveReadableTensorNames())
+
+    x1 = np.random.randn(*input_tensor_shape).astype(np.float32)
+    x2 = np.random.randn(*input_tensor_shape).astype(np.float32)
+
+    # generate expected value from streamlined net
+    input_dict = {model.graph.input[0].name: x1, model.graph.input[1].name: x2}
+
+    output_dict = oxe.execute_onnx(model, input_dict, True)
+    produced_sum = output_dict[model.graph.output[0].name]
+    expected_sum = 3.0 * ((x1 + x2) + 15.0)
+    assert np.isclose(expected_sum, produced_sum, atol=1e-3).all()
+    assert len(model.get_nodes_by_op_type("Add")) == 3
+    assert len(model.get_nodes_by_op_type("Mul")) == 2
+
+    model = model.transform(MoveScalarLinearPastEltwiseAdd())
+
+    # verify again, to check we didnt break anything
+    output_dict = oxe.execute_onnx(model, input_dict, True)
+    produced_sum = output_dict[model.graph.output[0].name]
+    assert np.isclose(expected_sum, produced_sum, atol=1e-3).all()
+    assert len(model.get_nodes_by_op_type("Add")) == 2
+    assert len(model.get_nodes_by_op_type("Mul")) == 1
+
+    os.remove(export_onnx_path)