diff --git a/src/finn/core/modelwrapper.py b/src/finn/core/modelwrapper.py index 7c78acbd3cd463f7e375b65646e864ca11f14d22..393061219d60e46b5cce5cc017d0f77f454ee7aa 100644 --- a/src/finn/core/modelwrapper.py +++ b/src/finn/core/modelwrapper.py @@ -33,6 +33,23 @@ class ModelWrapper: def graph(self, value): self._model_proto.graph = value + def transform_repeated(self, transform): + """Applies given transform repeatedly until no more changes can be made + and returns a transformed ModelWrapper instance. + Transform must return (transformed_model, model_was_changed).""" + transformed_model = self + model_was_changed = True + while model_was_changed: + (transformed_model, model_was_changed) = transform(transformed_model) + return transformed_model + + def transform_single(self, transform): + """Applies given transform once and returns transformed ModelWrapper + instance. Transform must return (transformed_model, model_was_changed), + although model_was_changed is ignored (see also apply_repeated).""" + (transformed_model, model_was_changed) = transform(self) + return transformed_model + def check_compatibility(self): """Checks this model for FINN compatibility: * no embedded subgraphs