Skip to content
Snippets Groups Projects
Commit 8f1f5f8f authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[ModelWrapper] introduce cleanup function, call in partitions too

parent 65d8322f
No related branches found
No related tags found
No related merge requests found
......@@ -36,6 +36,11 @@ from onnx import TensorProto
import finn.util.basic as util
import finn.util.onnx as onnxutil
from finn.core.datatype import DataType
from finn.transformation.general import (
RemoveUnusedTensors,
RemoveStaticGraphInputs,
SortGraph,
)
class ModelWrapper:
......@@ -87,7 +92,7 @@ class ModelWrapper:
"""Runs given anaylsis_fxn on this model and return resulting dict."""
return analysis_fxn(self)
def transform(self, transformation, make_deepcopy=True):
def transform(self, transformation, make_deepcopy=True, cleanup=True):
"""Applies given Transformation repeatedly until no more changes can be made
and returns a transformed ModelWrapper instance.
......@@ -101,6 +106,22 @@ class ModelWrapper:
(transformed_model, model_was_changed) = transformation.apply(
transformed_model
)
if cleanup:
transformed_model.cleanup()
return transformed_model
def cleanup(self):
"Run cleanup transformations on the model."
transformed_model = self
cleanup_transforms = [
RemoveUnusedTensors(),
RemoveStaticGraphInputs(),
SortGraph(),
]
for trn in cleanup_transforms:
transformed_model = transformed_model.transform(
trn, cleanup=False, make_deepcopy=False
)
return transformed_model
def check_compatibility(self):
......
......@@ -112,6 +112,7 @@ class CreateDataflowPartition(Transformation):
"dataflow_partition" + str(target_partition_id) + "_"
)
df_model_filename = df_model_dir + "/df_model.onnx"
df_model.cleanup()
df_model.save(df_model_filename)
# remove all dataflow nodes from the non-dataflow model
# keep track of where the dataflow part starts
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment