From 8f1f5f8f9164447560c650fa0396d28ca6ecb935 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Tue, 18 Aug 2020 20:27:42 +0200
Subject: [PATCH] [ModelWrapper] introduce cleanup function, call in partitions
 too

---
 src/finn/core/modelwrapper.py                 | 23 ++++++++++++++++++-
 .../fpgadataflow/create_dataflow_partition.py |  1 +
 2 files changed, 23 insertions(+), 1 deletion(-)

diff --git a/src/finn/core/modelwrapper.py b/src/finn/core/modelwrapper.py
index 646add188..98b234592 100644
--- a/src/finn/core/modelwrapper.py
+++ b/src/finn/core/modelwrapper.py
@@ -36,6 +36,11 @@ from onnx import TensorProto
 import finn.util.basic as util
 import finn.util.onnx as onnxutil
 from finn.core.datatype import DataType
+from finn.transformation.general import (
+    RemoveUnusedTensors,
+    RemoveStaticGraphInputs,
+    SortGraph,
+)
 
 
 class ModelWrapper:
@@ -87,7 +92,7 @@ class ModelWrapper:
         """Runs given anaylsis_fxn on this model and return resulting dict."""
         return analysis_fxn(self)
 
-    def transform(self, transformation, make_deepcopy=True):
+    def transform(self, transformation, make_deepcopy=True, cleanup=True):
         """Applies given Transformation repeatedly until no more changes can be made
         and returns a transformed ModelWrapper instance.
 
@@ -101,6 +106,22 @@ class ModelWrapper:
             (transformed_model, model_was_changed) = transformation.apply(
                 transformed_model
             )
+        if cleanup:
+            transformed_model.cleanup()
+        return transformed_model
+
+    def cleanup(self):
+        "Run cleanup transformations on the model."
+        transformed_model = self
+        cleanup_transforms = [
+            RemoveUnusedTensors(),
+            RemoveStaticGraphInputs(),
+            SortGraph(),
+        ]
+        for trn in cleanup_transforms:
+            transformed_model = transformed_model.transform(
+                trn, cleanup=False, make_deepcopy=False
+            )
         return transformed_model
 
     def check_compatibility(self):
diff --git a/src/finn/transformation/fpgadataflow/create_dataflow_partition.py b/src/finn/transformation/fpgadataflow/create_dataflow_partition.py
index 5ec4ab14d..fb8b4358a 100644
--- a/src/finn/transformation/fpgadataflow/create_dataflow_partition.py
+++ b/src/finn/transformation/fpgadataflow/create_dataflow_partition.py
@@ -112,6 +112,7 @@ class CreateDataflowPartition(Transformation):
                     "dataflow_partition" + str(target_partition_id) + "_"
                 )
                 df_model_filename = df_model_dir + "/df_model.onnx"
+                df_model.cleanup()
                 df_model.save(df_model_filename)
                 # remove all dataflow nodes from the non-dataflow model
                 # keep track of where the dataflow part starts
-- 
GitLab