diff --git a/src/finn/transformation/batchnorm_to_affine.py b/src/finn/transformation/batchnorm_to_affine.py index 655ddd9842f37d59155aa0b12edeffecd89d65c1..aea6b83c737cbd79f38c8b847d0b7598c72e4f69 100644 --- a/src/finn/transformation/batchnorm_to_affine.py +++ b/src/finn/transformation/batchnorm_to_affine.py @@ -67,8 +67,10 @@ class BatchNormToAffine(Transformation): # remove old nodes graph.node.remove(n) if consumer is not None: - graph.node.remove(consumer) + if consumer.op_type == "Squeeze": + graph.node.remove(consumer) if producer is not None: - graph.node.remove(producer) + if producer.op_type == "Unsqueeze": + graph.node.remove(producer) model = model.transform(InferShapes()) return (model, graph_modified)