diff --git a/src/finn/transformation/general.py b/src/finn/transformation/general.py
index 412fddaed15721b36e6fbc0e778f394c6343255d..12fd14646728d3d629d7dffd87b23f4d4dd256e6 100644
--- a/src/finn/transformation/general.py
+++ b/src/finn/transformation/general.py
@@ -1,5 +1,8 @@
 import copy
 
+import numpy as np
+from onnx import TensorProto
+from onnx import helper as oh
 from onnx import numpy_helper as np_helper
 
 
@@ -77,3 +80,67 @@ def make_new_valueinfo_name(model):
     while candidate in names:
         candidate = str(int(candidate) + 1)
     return candidate
+
+
+def replace_batchnorm_with_affine(model):
+    """Replaces any test-time BatchNorm layers with Mul-Add layers."""
+    new_model = copy.deepcopy(model)
+    graph = new_model.graph
+    nodes_to_remove = []
+    for n in graph.node:
+        if n.op_type == "BatchNormalization":
+            bn_input = n.input[0]
+            bn_output = n.output[0]
+            # extract batchnorm parameters as numpy arrays
+            scale = get_initializer(new_model, n.input[1])
+            bias = get_initializer(new_model, n.input[2])
+            mean = get_initializer(new_model, n.input[3])
+            variance = get_initializer(new_model, n.input[4])
+            # use the node's epsilon attribute if set, else the ONNX default
+            epsilon = 1e-5
+            for a in n.attribute:
+                if a.name == "epsilon":
+                    epsilon = a.f
+            # find A and B to compute batchnorm as an affine transform Ax+B
+            # TODO is a division by moving avg factor needed for variance?
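+            # BN(x) = scale * (x - mean) / sqrt(variance + epsilon) + bias
+            #       = A*x + B, with A and B as computed below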
+            A = scale / np.sqrt(epsilon + variance)
+            B = bias - (A * mean)
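+            # NOTE: A and B are per-channel 1D vectors; they broadcast correctly
+            # against 2D (N, C) inputs, but 4D conv inputs would need them
+            # reshaped to (1, C, 1, 1) first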
+            nodes_to_remove += [n]
+            # see if we have surrounding Unsqueeze/Squeeze nodes we can remove
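+            # (BatchNorm applied to 2D tensors is often exported wrapped in
+            # Unsqueeze/Squeeze; elementwise Mul/Add need no such wrapping)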
+            producer = find_producer(new_model, bn_input)
+            if producer is not None:
+                if producer.op_type == "Unsqueeze":
+                    bn_input = producer.input[0]
+                    nodes_to_remove += [producer]
+            consumer = find_consumer(new_model, bn_output)
+            if consumer is not None:
+                if consumer.op_type == "Squeeze":
+                    bn_output = consumer.output[0]
+                    nodes_to_remove += [consumer]
+            # create value_info and initializers for Mul and Add constants
+            mul_const = oh.make_tensor_value_info(
+                make_new_valueinfo_name(new_model), TensorProto.FLOAT, A.shape
+            )
+            graph.value_info.append(mul_const)
+            set_initializer(new_model, mul_const.name, A)
+            mul_output = oh.make_tensor_value_info(
+                make_new_valueinfo_name(new_model), TensorProto.FLOAT, A.shape
+            )
+            graph.value_info.append(mul_output)
+            add_const = oh.make_tensor_value_info(
+                make_new_valueinfo_name(new_model), TensorProto.FLOAT, B.shape
+            )
+            graph.value_info.append(add_const)
+            set_initializer(new_model, add_const.name, B)
+            # create Mul and Add nodes to replace the batchnorm
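+            # the Mul reads the (possibly un-wrapped) original input and writes
+            # to a fresh intermediate tensor; the Add then writes to the original
+            # output name so downstream consumers stay connected unchanged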
+            mul_node = oh.make_node(
+                "Mul", [bn_input, mul_const.name], [mul_output.name]
+            )
+            add_node = oh.make_node(
+                "Add", [mul_output.name, add_const.name], [bn_output]
+            )
+            graph.node.append(mul_node)
+            graph.node.append(add_node)
+
+    # delete marked nodes
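+    # (removal is deferred until after the loop so graph.node is not mutated
+    # while it is being iterated over)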
+    for n in nodes_to_remove:
+        graph.node.remove(n)
+    # TODO topologically sort nodes
+    # TODO give new names, maybe run shape inference?
+    return new_model