diff --git a/src/finn/core/modelwrapper.py b/src/finn/core/modelwrapper.py
index 646add188c5d475cf37ccd33cf24d29d61754ae1..98b234592ebe0c704fafd1eed980325d8566e7e2 100644
--- a/src/finn/core/modelwrapper.py
+++ b/src/finn/core/modelwrapper.py
@@ -36,6 +36,11 @@ from onnx import TensorProto
 import finn.util.basic as util
 import finn.util.onnx as onnxutil
 from finn.core.datatype import DataType
+from finn.transformation.general import (
+    RemoveUnusedTensors,
+    RemoveStaticGraphInputs,
+    SortGraph,
+)
 
 
 class ModelWrapper:
@@ -87,7 +92,7 @@ class ModelWrapper:
         """Runs given anaylsis_fxn on this model and return resulting dict."""
         return analysis_fxn(self)
 
-    def transform(self, transformation, make_deepcopy=True):
+    def transform(self, transformation, make_deepcopy=True, cleanup=True):
         """Applies given Transformation repeatedly until no more changes
         can be made and returns a transformed ModelWrapper instance.
 
@@ -101,6 +106,22 @@
             (transformed_model, model_was_changed) = transformation.apply(
                 transformed_model
             )
+        if cleanup:
+            transformed_model.cleanup()
+        return transformed_model
+
+    def cleanup(self):
+        "Run cleanup transformations on the model."
+        transformed_model = self
+        cleanup_transforms = [
+            RemoveUnusedTensors(),
+            RemoveStaticGraphInputs(),
+            SortGraph(),
+        ]
+        for trn in cleanup_transforms:
+            transformed_model = transformed_model.transform(
+                trn, cleanup=False, make_deepcopy=False
+            )
         return transformed_model
 
     def check_compatibility(self):
diff --git a/src/finn/transformation/fpgadataflow/create_dataflow_partition.py b/src/finn/transformation/fpgadataflow/create_dataflow_partition.py
index 5ec4ab14d65d63523856a6bb107bf75c1ca5a261..fb8b4358abd772d13c355f797649dc3b51975b4d 100644
--- a/src/finn/transformation/fpgadataflow/create_dataflow_partition.py
+++ b/src/finn/transformation/fpgadataflow/create_dataflow_partition.py
@@ -112,6 +112,7 @@ class CreateDataflowPartition(Transformation):
                 "dataflow_partition" + str(target_partition_id) + "_"
             )
             df_model_filename = df_model_dir + "/df_model.onnx"
+            df_model.cleanup()
             df_model.save(df_model_filename)
             # remove all dataflow nodes from the non-dataflow model
             # keep track of where the dataflow part starts
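
For reference, a minimal usage sketch of the new API (the model filename and the InferShapes transform below are illustrative only, not part of this change):

    from finn.core.modelwrapper import ModelWrapper
    from finn.transformation.infer_shapes import InferShapes

    model = ModelWrapper("model.onnx")  # hypothetical input model

    # transform() now runs the cleanup transforms after it converges (default)
    model = model.transform(InferShapes())

    # pass cleanup=False to keep the previous behaviour, e.g. when a caller
    # wants to manage graph order and unused tensors itself
    model = model.transform(InferShapes(), cleanup=False)

    # the cleanup pass can also be invoked on its own, as done here before
    # saving the partitioned model in CreateDataflowPartition
    model = model.cleanup()
    model.save("model_clean.onnx")

Note that cleanup() calls transform() with cleanup=False and make_deepcopy=False for each of its passes, so it cannot recurse and works on the model in place.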