diff --git a/src/finn/core/modelwrapper.py b/src/finn/core/modelwrapper.py
index 646add188c5d475cf37ccd33cf24d29d61754ae1..98b234592ebe0c704fafd1eed980325d8566e7e2 100644
--- a/src/finn/core/modelwrapper.py
+++ b/src/finn/core/modelwrapper.py
@@ -36,6 +36,11 @@ from onnx import TensorProto
 import finn.util.basic as util
 import finn.util.onnx as onnxutil
 from finn.core.datatype import DataType
+from finn.transformation.general import (
+    RemoveUnusedTensors,
+    RemoveStaticGraphInputs,
+    SortGraph,
+)
 
 
 class ModelWrapper:
@@ -87,7 +92,7 @@ class ModelWrapper:
         """Runs given anaylsis_fxn on this model and return resulting dict."""
         return analysis_fxn(self)
 
-    def transform(self, transformation, make_deepcopy=True):
+    def transform(self, transformation, make_deepcopy=True, cleanup=True):
         """Applies given Transformation repeatedly until no more changes can be made
         and returns a transformed ModelWrapper instance.
 
@@ -101,6 +106,22 @@ class ModelWrapper:
             (transformed_model, model_was_changed) = transformation.apply(
                 transformed_model
             )
+        if cleanup:
+            transformed_model.cleanup()
+        return transformed_model
+
+    def cleanup(self):
+        "Run cleanup transformations on the model."
+        transformed_model = self
+        cleanup_transforms = [
+            RemoveUnusedTensors(),
+            RemoveStaticGraphInputs(),
+            SortGraph(),
+        ]
+        for trn in cleanup_transforms:
+            transformed_model = transformed_model.transform(
+                trn, cleanup=False, make_deepcopy=False
+            )
         return transformed_model
 
     def check_compatibility(self):
diff --git a/src/finn/transformation/fpgadataflow/create_dataflow_partition.py b/src/finn/transformation/fpgadataflow/create_dataflow_partition.py
index 5ec4ab14d65d63523856a6bb107bf75c1ca5a261..fb8b4358abd772d13c355f797649dc3b51975b4d 100644
--- a/src/finn/transformation/fpgadataflow/create_dataflow_partition.py
+++ b/src/finn/transformation/fpgadataflow/create_dataflow_partition.py
@@ -112,6 +112,7 @@ class CreateDataflowPartition(Transformation):
                     "dataflow_partition" + str(target_partition_id) + "_"
                 )
                 df_model_filename = df_model_dir + "/df_model.onnx"
+                df_model.cleanup()
                 df_model.save(df_model_filename)
                 # remove all dataflow nodes from the non-dataflow model
                 # keep track of where the dataflow part starts