Skip to content
Snippets Groups Projects
Commit 4d0ceb1a authored by AndreaRigoni's avatar AndreaRigoni
Browse files

[fix] batchnorm transform always removes producer and consumer nodes

parent fa421928
No related branches found
No related tags found
No related merge requests found
......@@ -67,8 +67,10 @@ class BatchNormToAffine(Transformation):
# remove old nodes
graph.node.remove(n)
if consumer is not None:
graph.node.remove(consumer)
if consumer.op_type == "Squeeze":
graph.node.remove(consumer)
if producer is not None:
graph.node.remove(producer)
if producer.op_type == "Unsqueeze":
graph.node.remove(producer)
model = model.transform(InferShapes())
return (model, graph_modified)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment