diff --git a/src/finn/transformation/batchnorm_to_affine.py b/src/finn/transformation/batchnorm_to_affine.py
index 655ddd9842f37d59155aa0b12edeffecd89d65c1..aea6b83c737cbd79f38c8b847d0b7598c72e4f69 100644
--- a/src/finn/transformation/batchnorm_to_affine.py
+++ b/src/finn/transformation/batchnorm_to_affine.py
@@ -67,8 +67,10 @@ class BatchNormToAffine(Transformation):
                 # remove old nodes
                 graph.node.remove(n)
                 if consumer is not None:
-                    graph.node.remove(consumer)
+                    if consumer.op_type == "Squeeze":
+                        graph.node.remove(consumer)
                 if producer is not None:
-                    graph.node.remove(producer)
+                    if producer.op_type == "Unsqueeze":
+                        graph.node.remove(producer)
         model = model.transform(InferShapes())
         return (model, graph_modified)