diff --git a/src/finn/transformation/batchnorm_to_affine.py b/src/finn/transformation/batchnorm_to_affine.py
index 4f3984cd75818d75e6809d7300383d8e4d342dc4..646c69ef066280554b74b2f56afae76a9eab4326 100644
--- a/src/finn/transformation/batchnorm_to_affine.py
+++ b/src/finn/transformation/batchnorm_to_affine.py
@@ -1,8 +1,9 @@
 import numpy as np
-import onnx.shape_inference as si
 from onnx import TensorProto
 from onnx import helper as oh
 
+import finn.transformation.infer_shapes as si
+
 
 def batchnorm_to_affine(model):
     """Replaces any test-time BatchNorm layers with Mul-Add layers."""
@@ -68,5 +69,5 @@ def batchnorm_to_affine(model):
     for n in nodes_to_remove:
         graph.node.remove(n)
         graph_modified = True
-    model.model = si.infer_shapes(model.model)
+    model = model.transform_single(si.infer_shapes)
     return (model, graph_modified)