diff --git a/src/finn/transformation/code_gen_transformation.py b/src/finn/transformation/code_gen_transformation.py index c34172343b2aafdd44d0a794617021e8bd45ef25..c015ca4ca6291e148410d9ca544ca904646e4ee2 100644 --- a/src/finn/transformation/code_gen_transformation.py +++ b/src/finn/transformation/code_gen_transformation.py @@ -2,9 +2,10 @@ import os import tempfile as tmp import finn.custom_op.registry as registry +from finn.transformation import Transformation -def code_gen_transformation(node, context): +def code_gen_transformation(node, context, model): """Call custom implementation to generate code for single custom node and create folder that contains all the generated files""" op_type = node.op_type @@ -28,9 +29,7 @@ def code_gen_transformation(node, context): raise Exception("Code was not generated!") else: inst.code_gen_dir = tmp_dir - for attribute in node.attribute: - if attribute.name == "code_gen_dir": - attribute.s = tmp_dir.encode("UTF-8") + model.set_attribute(node, "code_gen_dir", tmp_dir) else: raise Exception("Code was not generated!") @@ -48,17 +47,27 @@ def code_gen_transformation(node, context): raise Exception("Code was not generated!") else: inst.code_gen_dir = tmp_dir - for attribute in node.attribute: - if attribute.name == "code_gen_dir": - attribute.s = tmp_dir.encode("UTF-8") + model.set_attribute(node, "code_gen_dir", tmp_dir) else: raise Exception("Code was not generated!") else: inst.code_gen_dir = tmp_dir - for attribute in node.attribute: - if attribute.name == "code_gen_dir": - attribute.s = tmp_dir.encode("UTF-8") + model.set_attribute(node, "code_gen_dir", tmp_dir) except KeyError: # exception if op_type is not supported raise Exception("Custom op_type %s is currently not supported." % op_type) + + +class CodeGen(Transformation): + """Code generation for all nodes in model""" + + def apply(self, model): + W = model.get_initializer("weights") + T = model.get_initializer("thresh") + context = {} + context["weights"] = W + context["thresh"] = T + for node in model.graph.node: + code_gen_transformation(node, context, model) + return (model, False) diff --git a/tests/fpgadataflow/test_code_gen_trafo.py b/tests/transformation/test_code_gen_trafo.py similarity index 88% rename from tests/fpgadataflow/test_code_gen_trafo.py rename to tests/transformation/test_code_gen_trafo.py index 43de5cda6ec5318310066ef472174394653d0b74..4f050fc1f7738364fde8391fff5e0c8925835895 100644 --- a/tests/fpgadataflow/test_code_gen_trafo.py +++ b/tests/transformation/test_code_gen_trafo.py @@ -2,9 +2,9 @@ import numpy as np from onnx import TensorProto, helper import finn.core.utils as util -import finn.transformation.code_gen_transformation as cg_trafo from finn.core.datatype import DataType from finn.core.modelwrapper import ModelWrapper +from finn.transformation.code_gen_transformation import CodeGen def test_code_gen_trafo(): @@ -58,8 +58,4 @@ def test_code_gen_trafo(): T = np.zeros((1, 1)) model.set_initializer("thresh", T) - context = {} - context["weights"] = W - context["threshs"] = T - for node in model.graph.node: - cg_trafo.code_gen_transformation(node, context) + model = model.transform(CodeGen())