diff --git a/src/finn/transformation/batchnorm_to_affine.py b/src/finn/transformation/batchnorm_to_affine.py index 4f3984cd75818d75e6809d7300383d8e4d342dc4..646c69ef066280554b74b2f56afae76a9eab4326 100644 --- a/src/finn/transformation/batchnorm_to_affine.py +++ b/src/finn/transformation/batchnorm_to_affine.py @@ -1,8 +1,9 @@ import numpy as np -import onnx.shape_inference as si from onnx import TensorProto from onnx import helper as oh +import finn.transformation.infer_shapes as si + def batchnorm_to_affine(model): """Replaces any test-time BatchNorm layers with Mul-Add layers.""" @@ -68,5 +69,5 @@ def batchnorm_to_affine(model): for n in nodes_to_remove: graph.node.remove(n) graph_modified = True - model.model = si.infer_shapes(model.model) + model = model.transform_single(si.infer_shapes) return (model, graph_modified)