Skip to content
Snippets Groups Projects
Commit 4186f203 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Transform] add a cleanup transform for generated files

parent 1c79b9f9
No related branches found
No related tags found
No related merge requests found
import os
import shutil
import finn.core.utils as util
import finn.custom_op.registry as registry
from finn.transformation import Transformation
class CleanUp(Transformation):
"""Remove any generated files for fpgadataflow nodes."""
def __init__(self):
super().__init__()
def apply(self, model):
for node in model.graph.node:
op_type = node.op_type
if node.domain == "finn":
backend_attribute = util.get_by_name(node.attribute, "backend")
backend_value = backend_attribute.s.decode("UTF-8")
if backend_value == "fpgadataflow":
try:
# lookup op_type in registry of CustomOps
inst = registry.custom_op[op_type](node)
code_gen_dir = inst.get_nodeattr("code_gen_dir")
if os.path.isdir(code_gen_dir):
shutil.rmtree(code_gen_dir)
inst.set_nodeattr("code_gen_dir", "")
inst.set_nodeattr("executable_path", "")
except KeyError:
# exception if op_type is not supported
raise Exception(
"Custom op_type %s is currently not supported." % op_type
)
return (model, False)
......@@ -5,6 +5,7 @@ from onnx import TensorProto, helper
import finn.core.utils as util
from finn.core.datatype import DataType
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.fpgadataflow.cleanup import CleanUp
from finn.transformation.fpgadataflow.codegen import CodeGen
......@@ -69,3 +70,4 @@ def test_code_gen_trafo():
op type {} is empty!""".format(
node.op_type
)
model = model.transform(CleanUp())
......@@ -5,6 +5,7 @@ from onnx import TensorProto, helper
import finn.core.utils as util
from finn.core.datatype import DataType
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.fpgadataflow.cleanup import CleanUp
from finn.transformation.fpgadataflow.codegen import CodeGen
from finn.transformation.fpgadataflow.compile import Compile
......@@ -66,3 +67,4 @@ def test_compilation_trafo():
op type {} does not exist!""".format(
node.op_type
)
model = model.transform(CleanUp())
......@@ -8,6 +8,7 @@ import finn.custom_op.xnorpopcount as xp
from finn.core.datatype import DataType
from finn.core.modelwrapper import ModelWrapper
from finn.core.utils import gen_finn_dt_tensor
from finn.transformation.fpgadataflow.cleanup import CleanUp
from finn.transformation.fpgadataflow.codegen import CodeGen
from finn.transformation.fpgadataflow.compile import Compile
......@@ -120,3 +121,4 @@ def test_fpgadataflow_fclayer_noact(idt, wdt, nf, sf, mw, mh):
# execute model
y_produced = oxe.execute_onnx(model, input_dict)["outp"]
assert (y_produced.reshape(y_expected.shape) == y_expected).all()
model = model.transform(CleanUp())
......@@ -4,6 +4,7 @@ from onnx import TensorProto, helper
import finn.core.onnx_exec as oxe
from finn.core.datatype import DataType
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.fpgadataflow.cleanup import CleanUp
from finn.transformation.fpgadataflow.codegen import CodeGen
from finn.transformation.fpgadataflow.compile import Compile
......@@ -117,3 +118,4 @@ def test_layer_streaming_maxpool_batch():
input_dict = {"in": input_tensor}
output_dict = oxe.execute_onnx(model, input_dict)
print(output_dict)
model = model.transform(CleanUp())
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment