diff --git a/src/finn/transformation/infer_shapes.py b/src/finn/transformation/infer_shapes.py new file mode 100644 index 0000000000000000000000000000000000000000..1f215ec3ff899ead5e2b9f58adbcddb5522b2288 --- /dev/null +++ b/src/finn/transformation/infer_shapes.py @@ -0,0 +1,11 @@ +import onnx.shape_inference as si + + +def infer_shapes(model): + """Ensure every tensor in the model has a specified shape (ValueInfo).""" + # currently just calls ONNX shape inference, but in the future we will + # have to handle shape inference for custom ops ourselves + model.model = si.infer_shapes(model.model) + # single-step operation, no need to call multiple times so return + # model_was_changed = false + return (model, False)