diff --git a/src/finn/transformation/batchnorm_to_affine.py b/src/finn/transformation/batchnorm_to_affine.py
new file mode 100644
index 0000000000000000000000000000000000000000..0fe3e8a131fbade8f6058193f68fe8c28fe28eec
--- /dev/null
+++ b/src/finn/transformation/batchnorm_to_affine.py
@@ -0,0 +1,75 @@
+import copy
+
+import numpy as np
+import onnx.shape_inference as si
+from onnx import TensorProto
+from onnx import helper as oh
+
+import finn.transformation.general as tg
+
+
+def batchnorm_to_affine(model):
+    """Replaces any test-time BatchNorm layers with Mul-Add layers."""
+    new_model = copy.deepcopy(model)
+    new_model = si.infer_shapes(new_model)
+    graph = new_model.graph
+    nodes_to_remove = []
+    node_ind = 0
+    for n in graph.node:
+        node_ind += 1
+        if n.op_type == "BatchNormalization":
+            bn_input = n.input[0]
+            bn_output = n.output[0]
+            # extract batchnorm parameters as numpy arrays
+            scale = tg.get_initializer(new_model, n.input[1])
+            bias = tg.get_initializer(new_model, n.input[2])
+            mean = tg.get_initializer(new_model, n.input[3])
+            variance = tg.get_initializer(new_model, n.input[4])
+            epsilon = 1e-5
+            # find A and B to compute batchnorm as affine transpose Ax+B
+            # TODO is a division by moving avg factor needed for variance?
+            A = scale / np.sqrt(epsilon + variance)
+            B = bias - (A * mean)
+            nodes_to_remove += [n]
+            # see if we have surrounding Unsqueeze/Squeeze nodes we can remove
+            producer = tg.find_producer(new_model, bn_input)
+            if producer is not None:
+                if producer.op_type == "Unsqueeze":
+                    bn_input = producer.input[0]
+                    nodes_to_remove += [producer]
+            consumer = tg.find_consumer(new_model, bn_output)
+            if consumer is not None:
+                if consumer.op_type == "Squeeze":
+                    bn_output = consumer.output[0]
+                    nodes_to_remove += [consumer]
+            data_shape = tg.get_tensor_shape(new_model, bn_input)
+            # create value_info and initializers for Mul and Add constants
+            mul_const = oh.make_tensor_value_info(
+                tg.make_new_valueinfo_name(new_model), TensorProto.FLOAT, A.shape
+            )
+            graph.value_info.append(mul_const)
+            tg.set_initializer(new_model, mul_const.name, A)
+            mul_output = oh.make_tensor_value_info(
+                tg.make_new_valueinfo_name(new_model), TensorProto.FLOAT, data_shape
+            )
+            graph.value_info.append(mul_output)
+            add_const = oh.make_tensor_value_info(
+                tg.make_new_valueinfo_name(new_model), TensorProto.FLOAT, B.shape
+            )
+            graph.value_info.append(add_const)
+            tg.set_initializer(new_model, add_const.name, B)
+            # create Mul and Add nodes to replace the batchnorm
+            mul_node = oh.make_node(
+                "Mul", [bn_input, mul_const.name], [mul_output.name]
+            )
+            add_node = oh.make_node(
+                "Add", [mul_output.name, add_const.name], [bn_output]
+            )
+            # insert where the batchnorm is to preserve topological ordering
+            graph.node.insert(node_ind, mul_node)
+            graph.node.insert(node_ind + 1, add_node)
+    # delete marked nodes (batchnorm and (un)squeezing)
+    for n in nodes_to_remove:
+        graph.node.remove(n)
+    new_model = si.infer_shapes(new_model)
+    return new_model
diff --git a/src/finn/transformation/general.py b/src/finn/transformation/general.py
index db01b00914db84cf31f5f6420c05e0032141669b..9086823113ec1cfe90431ce8ee904105f6698999 100644
--- a/src/finn/transformation/general.py
+++ b/src/finn/transformation/general.py
@@ -1,10 +1,6 @@
 import copy
 
-import numpy as np
-import onnx.shape_inference as si
-from onnx import TensorProto
-from onnx import helper as oh
-from onnx import numpy_helper as np_helper
+import onnx.numpy_helper as np_helper
 
 
 def give_unique_names(model):
@@ -93,70 +89,3 @@ def make_new_valueinfo_name(model):
     while candidate in names:
         candidate = str(int(candidate) + 1)
     return candidate
-
-
-def replace_batchnorm_with_affine(model):
-    """Replaces any test-time BatchNorm layers with Mul-Add layers."""
-    new_model = copy.deepcopy(model)
-    new_model = si.infer_shapes(new_model)
-    graph = new_model.graph
-    nodes_to_remove = []
-    node_ind = 0
-    for n in graph.node:
-        node_ind += 1
-        if n.op_type == "BatchNormalization":
-            bn_input = n.input[0]
-            bn_output = n.output[0]
-            # extract batchnorm parameters as numpy arrays
-            scale = get_initializer(new_model, n.input[1])
-            bias = get_initializer(new_model, n.input[2])
-            mean = get_initializer(new_model, n.input[3])
-            variance = get_initializer(new_model, n.input[4])
-            epsilon = 1e-5
-            # find A and B to compute batchnorm as affine transpose Ax+B
-            # TODO is a division by moving avg factor needed for variance?
-            A = scale / np.sqrt(epsilon + variance)
-            B = bias - (A * mean)
-            nodes_to_remove += [n]
-            # see if we have surrounding Unsqueeze/Squeeze nodes we can remove
-            producer = find_producer(new_model, bn_input)
-            if producer is not None:
-                if producer.op_type == "Unsqueeze":
-                    bn_input = producer.input[0]
-                    nodes_to_remove += [producer]
-            consumer = find_consumer(new_model, bn_output)
-            if consumer is not None:
-                if consumer.op_type == "Squeeze":
-                    bn_output = consumer.output[0]
-                    nodes_to_remove += [consumer]
-            data_shape = get_tensor_shape(new_model, bn_input)
-            # create value_info and initializers for Mul and Add constants
-            mul_const = oh.make_tensor_value_info(
-                make_new_valueinfo_name(new_model), TensorProto.FLOAT, A.shape
-            )
-            graph.value_info.append(mul_const)
-            set_initializer(new_model, mul_const.name, A)
-            mul_output = oh.make_tensor_value_info(
-                make_new_valueinfo_name(new_model), TensorProto.FLOAT, data_shape
-            )
-            graph.value_info.append(mul_output)
-            add_const = oh.make_tensor_value_info(
-                make_new_valueinfo_name(new_model), TensorProto.FLOAT, B.shape
-            )
-            graph.value_info.append(add_const)
-            set_initializer(new_model, add_const.name, B)
-            # create Mul and Add nodes to replace the batchnorm
-            mul_node = oh.make_node(
-                "Mul", [bn_input, mul_const.name], [mul_output.name]
-            )
-            add_node = oh.make_node(
-                "Add", [mul_output.name, add_const.name], [bn_output]
-            )
-            # insert where the batchnorm is to preserve topological ordering
-            graph.node.insert(node_ind, mul_node)
-            graph.node.insert(node_ind + 1, add_node)
-    # delete marked nodes (batchnorm and (un)squeezing)
-    for n in nodes_to_remove:
-        graph.node.remove(n)
-    new_model = si.infer_shapes(new_model)
-    return new_model
diff --git a/tests/test_batchnorm_to_affine.py b/tests/test_batchnorm_to_affine.py
index babe1c2be373620fa3966b2ed923fe13bc2ecd89..79f6357bbc2127468b57e86b76f6eb28e089514e 100644
--- a/tests/test_batchnorm_to_affine.py
+++ b/tests/test_batchnorm_to_affine.py
@@ -13,7 +13,7 @@ from models.common import get_act_quant, get_quant_linear, get_quant_type, get_s
 from torch.nn import BatchNorm1d, Dropout, Module, ModuleList
 
 import finn.core.onnx_exec as oxe
-import finn.transformation.general as tx
+import finn.transformation.batchnorm_to_affine as tx
 
 FC_OUT_FEATURES = [1024, 1024, 1024]
 INTERMEDIATE_FC_PER_OUT_CH_SCALING = True
@@ -94,7 +94,7 @@ def test_batchnorm_to_affine():
     lfc.load_state_dict(checkpoint["state_dict"])
     bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
     model = onnx.load(export_onnx_path)
-    new_model = tx.replace_batchnorm_with_affine(model)
+    new_model = tx.batchnorm_to_affine(model)
     try:
         os.remove("/tmp/" + mnist_onnx_filename)
     except OSError: