diff --git a/tests/test_batchnorm_to_affine.py b/tests/test_batchnorm_to_affine.py
new file mode 100644
index 0000000000000000000000000000000000000000..babe1c2be373620fa3966b2ed923fe13bc2ecd89
--- /dev/null
+++ b/tests/test_batchnorm_to_affine.py
@@ -0,0 +1,115 @@
+import os
+import shutil
+from functools import reduce
+from operator import mul
+
+import brevitas.onnx as bo
+import numpy as np
+import onnx
+import onnx.numpy_helper as nph
+import torch
+import wget
+from models.common import get_act_quant, get_quant_linear, get_quant_type, get_stats_op
+from torch.nn import BatchNorm1d, Dropout, Module, ModuleList
+
+import finn.core.onnx_exec as oxe
+import finn.transformation.general as tx
+
+FC_OUT_FEATURES = [1024, 1024, 1024]
+INTERMEDIATE_FC_PER_OUT_CH_SCALING = True
+LAST_FC_PER_OUT_CH_SCALING = False
+IN_DROPOUT = 0.2
+HIDDEN_DROPOUT = 0.2
+
+mnist_onnx_url_base = "https://onnxzoo.blob.core.windows.net/models/opset_8/mnist"
+mnist_onnx_filename = "mnist.tar.gz"
+mnist_onnx_local_dir = "/tmp/mnist_onnx"
+export_onnx_path = "test_output_lfc.onnx"
+transformed_onnx_path = "test_output_lfc_transformed.onnx"
+# TODO get from config instead, hardcoded to Docker path for now
+trained_lfc_checkpoint = (
+    "/workspace/brevitas_cnv_lfc/pretrained_models/LFC_1W1A/checkpoints/best.tar"
+)
+
+
+class LFC(Module):
+    def __init__(
+        self,
+        num_classes=10,
+        weight_bit_width=None,
+        act_bit_width=None,
+        in_bit_width=None,
+        in_ch=1,
+        in_features=(28, 28),
+    ):
+        super(LFC, self).__init__()
+
+        weight_quant_type = get_quant_type(weight_bit_width)
+        act_quant_type = get_quant_type(act_bit_width)
+        in_quant_type = get_quant_type(in_bit_width)
+        stats_op = get_stats_op(weight_quant_type)
+
+        self.features = ModuleList()
+        self.features.append(get_act_quant(in_bit_width, in_quant_type))
+        self.features.append(Dropout(p=IN_DROPOUT))
+        in_features = reduce(mul, in_features)
+        for out_features in FC_OUT_FEATURES:
+            self.features.append(
+                get_quant_linear(
+                    in_features=in_features,
+                    out_features=out_features,
+                    per_out_ch_scaling=INTERMEDIATE_FC_PER_OUT_CH_SCALING,
+                    bit_width=weight_bit_width,
+                    quant_type=weight_quant_type,
+                    stats_op=stats_op,
+                )
+            )
+            in_features = out_features
+            self.features.append(BatchNorm1d(num_features=in_features))
+            self.features.append(get_act_quant(act_bit_width, act_quant_type))
+            self.features.append(Dropout(p=HIDDEN_DROPOUT))
+        self.fc = get_quant_linear(
+            in_features=in_features,
+            out_features=num_classes,
+            per_out_ch_scaling=LAST_FC_PER_OUT_CH_SCALING,
+            bit_width=weight_bit_width,
+            quant_type=weight_quant_type,
+            stats_op=stats_op,
+        )
+
+    def forward(self, x):
+        x = x.view(1, 784)
+        # removing the torch.tensor here creates a float64 op for some reason..
+        # so explicitly wrapped with torch.tensor to make a float32 one instead
+        x = 2.0 * x - torch.tensor([1.0])
+        for mod in self.features:
+            x = mod(x)
+        out = self.fc(x)
+        return out
+
+
+def test_batchnorm_to_affine():
+    lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1)
+    checkpoint = torch.load(trained_lfc_checkpoint, map_location="cpu")
+    lfc.load_state_dict(checkpoint["state_dict"])
+    bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
+    model = onnx.load(export_onnx_path)
+    new_model = tx.replace_batchnorm_with_affine(model)
+    try:
+        os.remove("/tmp/" + mnist_onnx_filename)
+    except OSError:
+        pass
+    dl_ret = wget.download(mnist_onnx_url_base + "/" + mnist_onnx_filename, out="/tmp")
+    shutil.unpack_archive(dl_ret, mnist_onnx_local_dir)
+    # load one of the test vectors
+    input_tensor = onnx.TensorProto()
+    with open(mnist_onnx_local_dir + "/mnist/test_data_set_0/input_0.pb", "rb") as f:
+        input_tensor.ParseFromString(f.read())
+    input_dict = {"0": nph.to_array(input_tensor)}
+    output_original = oxe.execute_onnx(model, input_dict)["53"]
+    output_transformed = oxe.execute_onnx(new_model, input_dict)["53"]
+    assert np.isclose(output_transformed, output_original, atol=1e-3).all()
+    # remove the downloaded model and extracted files
+    os.remove(dl_ret)
+    shutil.rmtree(mnist_onnx_local_dir)
+    os.remove(export_onnx_path)