Skip to content
Snippets Groups Projects
Commit c3f0e0c9 authored by auphelia's avatar auphelia
Browse files

[Code gen trafo] Restructured test and function to fit the transformation into...

[Code gen trafo] Restructured test and function to fit the transformation into finn transformation layout
parent 38cd3d1f
No related branches found
No related tags found
No related merge requests found
...@@ -2,9 +2,10 @@ import os ...@@ -2,9 +2,10 @@ import os
import tempfile as tmp import tempfile as tmp
import finn.custom_op.registry as registry import finn.custom_op.registry as registry
from finn.transformation import Transformation
def code_gen_transformation(node, context): def code_gen_transformation(node, context, model):
"""Call custom implementation to generate code for single custom node """Call custom implementation to generate code for single custom node
and create folder that contains all the generated files""" and create folder that contains all the generated files"""
op_type = node.op_type op_type = node.op_type
...@@ -28,9 +29,7 @@ def code_gen_transformation(node, context): ...@@ -28,9 +29,7 @@ def code_gen_transformation(node, context):
raise Exception("Code was not generated!") raise Exception("Code was not generated!")
else: else:
inst.code_gen_dir = tmp_dir inst.code_gen_dir = tmp_dir
for attribute in node.attribute: model.set_attribute(node, "code_gen_dir", tmp_dir)
if attribute.name == "code_gen_dir":
attribute.s = tmp_dir.encode("UTF-8")
else: else:
raise Exception("Code was not generated!") raise Exception("Code was not generated!")
...@@ -48,17 +47,27 @@ def code_gen_transformation(node, context): ...@@ -48,17 +47,27 @@ def code_gen_transformation(node, context):
raise Exception("Code was not generated!") raise Exception("Code was not generated!")
else: else:
inst.code_gen_dir = tmp_dir inst.code_gen_dir = tmp_dir
for attribute in node.attribute: model.set_attribute(node, "code_gen_dir", tmp_dir)
if attribute.name == "code_gen_dir":
attribute.s = tmp_dir.encode("UTF-8")
else: else:
raise Exception("Code was not generated!") raise Exception("Code was not generated!")
else: else:
inst.code_gen_dir = tmp_dir inst.code_gen_dir = tmp_dir
for attribute in node.attribute: model.set_attribute(node, "code_gen_dir", tmp_dir)
if attribute.name == "code_gen_dir":
attribute.s = tmp_dir.encode("UTF-8")
except KeyError: except KeyError:
# exception if op_type is not supported # exception if op_type is not supported
raise Exception("Custom op_type %s is currently not supported." % op_type) raise Exception("Custom op_type %s is currently not supported." % op_type)
class CodeGen(Transformation):
"""Code generation for all nodes in model"""
def apply(self, model):
W = model.get_initializer("weights")
T = model.get_initializer("thresh")
context = {}
context["weights"] = W
context["thresh"] = T
for node in model.graph.node:
code_gen_transformation(node, context, model)
return (model, False)
...@@ -2,9 +2,9 @@ import numpy as np ...@@ -2,9 +2,9 @@ import numpy as np
from onnx import TensorProto, helper from onnx import TensorProto, helper
import finn.core.utils as util import finn.core.utils as util
import finn.transformation.code_gen_transformation as cg_trafo
from finn.core.datatype import DataType from finn.core.datatype import DataType
from finn.core.modelwrapper import ModelWrapper from finn.core.modelwrapper import ModelWrapper
from finn.transformation.code_gen_transformation import CodeGen
def test_code_gen_trafo(): def test_code_gen_trafo():
...@@ -58,8 +58,4 @@ def test_code_gen_trafo(): ...@@ -58,8 +58,4 @@ def test_code_gen_trafo():
T = np.zeros((1, 1)) T = np.zeros((1, 1))
model.set_initializer("thresh", T) model.set_initializer("thresh", T)
context = {} model = model.transform(CodeGen())
context["weights"] = W
context["threshs"] = T
for node in model.graph.node:
cg_trafo.code_gen_transformation(node, context)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment