diff --git a/src/finn/transformation/code_gen_transformation.py b/src/finn/transformation/code_gen_transformation.py index 7e0ec8429dcfeb1478f90184e24c6b8936781607..c34172343b2aafdd44d0a794617021e8bd45ef25 100644 --- a/src/finn/transformation/code_gen_transformation.py +++ b/src/finn/transformation/code_gen_transformation.py @@ -1,7 +1,10 @@ +import os +import tempfile as tmp + import finn.custom_op.registry as registry -def code_gen_transformation(node): +def code_gen_transformation(node, context): """Call custom implementation to generate code for single custom node and create folder that contains all the generated files""" op_type = node.op_type @@ -12,17 +15,49 @@ def code_gen_transformation(node): # get the path of the code generation directory if already set # check instance and check node attributes for value code_gen_dir = inst.code_gen_dir - print(code_gen_dir) + code_gen_dir = code_gen_dir.s.decode("UTF-8") + # parameter is empty if not code_gen_dir: - print("parameter is empty") - # create new directory, set the value and generate the code + tmp_dir = tmp.mkdtemp(prefix="code_gen_" + str(node.op_type) + "_") + inst.tmp_dir = tmp_dir + inst.code_generation(context) + # check if directory exists + if os.path.isdir(tmp_dir): + if len(os.listdir(tmp_dir)) == 0: + raise Exception("Code was not generated!") + else: + inst.code_gen_dir = tmp_dir + for attribute in node.attribute: + if attribute.name == "code_gen_dir": + attribute.s = tmp_dir.encode("UTF-8") + else: + raise Exception("Code was not generated!") + # there is already a code gen directory else: - print("parameter contains value") - # check directory for files if empty, - # delete directory and create new one - # otherwise just leave it that way + # check directory for files + if os.path.isdir(code_gen_dir): + if len(os.listdir(code_gen_dir)) == 0: + os.rmdir(code_gen_dir) + tmp_dir = tmp.mkdtemp(prefix="code_gen_" + str(node.op_type) + "_") + inst.tmp_dir = tmp_dir + inst.code_generation(context) + if os.path.isdir(tmp_dir): + if len(os.listdir(tmp_dir)) == 0: + raise Exception("Code was not generated!") + else: + inst.code_gen_dir = tmp_dir + for attribute in node.attribute: + if attribute.name == "code_gen_dir": + attribute.s = tmp_dir.encode("UTF-8") + else: + raise Exception("Code was not generated!") + else: + inst.code_gen_dir = tmp_dir + for attribute in node.attribute: + if attribute.name == "code_gen_dir": + attribute.s = tmp_dir.encode("UTF-8") except KeyError: # exception if op_type is not supported diff --git a/tests/fpgadataflow/test_code_gen_trafo.py b/tests/fpgadataflow/test_code_gen_trafo.py index 39e41f3d2f75e6f420005a7ea6579bcaa9aa6752..43de5cda6ec5318310066ef472174394653d0b74 100644 --- a/tests/fpgadataflow/test_code_gen_trafo.py +++ b/tests/fpgadataflow/test_code_gen_trafo.py @@ -29,7 +29,7 @@ def test_code_gen_trafo(): ["outp"], domain="finn", backend="fpgadataflow", - code_gen_dir="dummy_directory", + code_gen_dir="", executable_path="", resType="ap_resource_lut()", MW=mw, @@ -58,5 +58,8 @@ def test_code_gen_trafo(): T = np.zeros((1, 1)) model.set_initializer("thresh", T) - for nodes in model.graph.node: - cg_trafo.code_gen_transformation(nodes) + context = {} + context["weights"] = W + context["threshs"] = T + for node in model.graph.node: + cg_trafo.code_gen_transformation(node, context)