From 8f1f5f8f9164447560c650fa0396d28ca6ecb935 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <maltanar@gmail.com> Date: Tue, 18 Aug 2020 20:27:42 +0200 Subject: [PATCH] [ModelWrapper] introduce cleanup function, call in partitions too --- src/finn/core/modelwrapper.py | 23 ++++++++++++++++++- .../fpgadataflow/create_dataflow_partition.py | 1 + 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/src/finn/core/modelwrapper.py b/src/finn/core/modelwrapper.py index 646add188..98b234592 100644 --- a/src/finn/core/modelwrapper.py +++ b/src/finn/core/modelwrapper.py @@ -36,6 +36,11 @@ from onnx import TensorProto import finn.util.basic as util import finn.util.onnx as onnxutil from finn.core.datatype import DataType +from finn.transformation.general import ( + RemoveUnusedTensors, + RemoveStaticGraphInputs, + SortGraph, +) class ModelWrapper: @@ -87,7 +92,7 @@ class ModelWrapper: """Runs given anaylsis_fxn on this model and return resulting dict.""" return analysis_fxn(self) - def transform(self, transformation, make_deepcopy=True): + def transform(self, transformation, make_deepcopy=True, cleanup=True): """Applies given Transformation repeatedly until no more changes can be made and returns a transformed ModelWrapper instance. @@ -101,6 +106,22 @@ class ModelWrapper: (transformed_model, model_was_changed) = transformation.apply( transformed_model ) + if cleanup: + transformed_model.cleanup() + return transformed_model + + def cleanup(self): + "Run cleanup transformations on the model." + transformed_model = self + cleanup_transforms = [ + RemoveUnusedTensors(), + RemoveStaticGraphInputs(), + SortGraph(), + ] + for trn in cleanup_transforms: + transformed_model = transformed_model.transform( + trn, cleanup=False, make_deepcopy=False + ) return transformed_model def check_compatibility(self): diff --git a/src/finn/transformation/fpgadataflow/create_dataflow_partition.py b/src/finn/transformation/fpgadataflow/create_dataflow_partition.py index 5ec4ab14d..fb8b4358a 100644 --- a/src/finn/transformation/fpgadataflow/create_dataflow_partition.py +++ b/src/finn/transformation/fpgadataflow/create_dataflow_partition.py @@ -112,6 +112,7 @@ class CreateDataflowPartition(Transformation): "dataflow_partition" + str(target_partition_id) + "_" ) df_model_filename = df_model_dir + "/df_model.onnx" + df_model.cleanup() df_model.save(df_model_filename) # remove all dataflow nodes from the non-dataflow model # keep track of where the dataflow part starts -- GitLab