Skip to content
Snippets Groups Projects
Commit 9214eed2 authored by auphelia's avatar auphelia
Browse files

[Transformations] Refactoring of code

parent c6fc9d55
No related branches found
No related tags found
No related merge requests found
......@@ -41,12 +41,14 @@ class HLSCustomOp(CustomOp):
self.code_gen_dir = util.get_by_name(onnx_node.attribute, "code_gen_dir")
self.executable_path = ""
def code_generation(self, context):
def code_generation(self, model):
node = self.onnx_node
if "weights" in context:
self.generate_weights(context)
if "thresh" in context:
self.generate_thresholds(context)
if node.op_type == "StreamingFCLayer_Batch":
self.generate_weights(model)
try:
self.generate_thresholds(model)
except:
pass
self.global_includes()
self.defines()
self.read_npy_data()
......
......@@ -148,9 +148,9 @@ class StreamingFCLayer_Batch(HLSCustomOp):
assert ret.shape[2] == n_thres_steps
return ret
def generate_weights(self, context):
def generate_weights(self, model):
weights = context["weights"]
weights = model.get_initializer(self.onnx_node.input[1])
# convert weights into hlslib-compatible format
weight_tensor = self.get_hls_compatible_weight_tensor(weights)
export_wdt = self.get_weight_datatype()
......@@ -184,8 +184,8 @@ class StreamingFCLayer_Batch(HLSCustomOp):
f_weights.write(weight_hls_code)
f_weights.close()
def generate_thresholds(self, context):
thresholds = context["thresh"]
def generate_thresholds(self, model):
thresholds = model.get_initializer(self.onnx_node.input[2])
threshold_tensor = self.get_hls_compatible_threshold_tensor(thresholds)
tdt = DataType.INT32
# use UINT32 threshold export for bipolar times bipolar
......
......@@ -5,7 +5,7 @@ import finn.custom_op.registry as registry
from finn.transformation import Transformation
def code_gen_transformation(node, context, model):
def code_gen_transformation(node, model):
"""Call custom implementation to generate code for single custom node
and create folder that contains all the generated files"""
op_type = node.op_type
......@@ -21,7 +21,7 @@ def code_gen_transformation(node, context, model):
if not code_gen_dir:
tmp_dir = tmp.mkdtemp(prefix="code_gen_" + str(node.op_type) + "_")
inst.tmp_dir = tmp_dir
inst.code_generation(context)
inst.code_generation(model)
# check if directory exists
if os.path.isdir(tmp_dir):
if len(os.listdir(tmp_dir)) == 0:
......@@ -40,7 +40,7 @@ def code_gen_transformation(node, context, model):
os.rmdir(code_gen_dir)
tmp_dir = tmp.mkdtemp(prefix="code_gen_" + str(node.op_type) + "_")
inst.tmp_dir = tmp_dir
inst.code_generation(context)
inst.code_generation(model)
if os.path.isdir(tmp_dir):
if len(os.listdir(tmp_dir)) == 0:
raise Exception("Code was not generated!")
class CodeGen(Transformation):
    """Transformation that performs code generation for every node in the model.

    Iterates over all nodes in the model graph and delegates per-node code
    generation to ``code_gen_transformation``, which looks up the node's
    custom-op implementation and writes the generated sources into a
    temporary ``code_gen_<op_type>_*`` directory recorded on the node.

    Weights and thresholds are no longer gathered into a shared ``context``
    dict here; each custom op now fetches its own initializers directly from
    the model (via ``model.get_initializer``) inside its ``code_generation``
    implementation.
    """

    def apply(self, model):
        """Generate code for each node; returns (model, graph_modified).

        graph_modified is always False: code generation only creates files
        and node attributes, it never rewrites the graph structure, so the
        transformation framework must not re-run it.
        """
        for node in model.graph.node:
            code_gen_transformation(node, model)
        return (model, False)
......@@ -6,21 +6,18 @@ from onnx import TensorProto, helper
import finn.core.utils as util
from finn.core.datatype import DataType
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.code_gen_transformation import CodeGen
from finn.transformation.fpgadataflow.code_gen_transformation import CodeGen
def test_code_gen_trafo():
idt = wdt = odt = DataType.BIPOLAR
tdt = DataType.UINT32
mw = 8
mh = 8
pe = 4
simd = 4
wmem = mw * mh // (pe * simd)
assert mw * mh == wmem * pe * simd
nf = mh // pe
sf = mw // simd
tmem = nf
inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, [1, sf, simd])
outp = helper.make_tensor_value_info("outp", TensorProto.FLOAT, [1, nf, pe])
......@@ -39,7 +36,7 @@ def test_code_gen_trafo():
SIMD=simd,
PE=pe,
WMEM=wmem,
TMEM=tmem,
TMEM=0,
inputDataType=idt.name,
weightDataType=wdt.name,
outputDataType=odt.name,
......@@ -54,11 +51,8 @@ def test_code_gen_trafo():
model.set_tensor_datatype("inp", idt)
model.set_tensor_datatype("outp", odt)
model.set_tensor_datatype("weights", wdt)
W = util.gen_finn_dt_tensor(wdt, (mh, mw))
W = util.gen_finn_dt_tensor(wdt, (mw, mh))
model.set_initializer("weights", W)
model.set_tensor_datatype("thresh", tdt)
T = np.zeros((1, 1))
model.set_initializer("thresh", T)
model = model.transform(CodeGen())
for node in model.graph.node:
......
......@@ -6,19 +6,17 @@ from onnx import TensorProto, helper
import finn.core.utils as util
from finn.core.datatype import DataType
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.code_gen_transformation import CodeGen
from finn.transformation.compilation_transformation import Compilation
from finn.transformation.fpgadataflow.code_gen_transformation import CodeGen
from finn.transformation.fpgadataflow.compilation_transformation import Compilation
def test_compilation_trafo():
idt = wdt = odt = DataType.BIPOLAR
tdt = DataType.UINT32
mw = 8
mh = 8
pe = 4
simd = 4
wmem = mw * mh // (pe * simd)
assert mw * mh == wmem * pe * simd
nf = mh // pe
sf = mw // simd
tmem = nf
......@@ -40,7 +38,7 @@ def test_compilation_trafo():
SIMD=simd,
PE=pe,
WMEM=wmem,
TMEM=tmem,
TMEM=0,
inputDataType=idt.name,
weightDataType=wdt.name,
outputDataType=odt.name,
......@@ -55,11 +53,8 @@ def test_compilation_trafo():
model.set_tensor_datatype("inp", idt)
model.set_tensor_datatype("outp", odt)
model.set_tensor_datatype("weights", wdt)
W = util.gen_finn_dt_tensor(wdt, (mh, mw))
W = util.gen_finn_dt_tensor(wdt, (mw, mh))
model.set_initializer("weights", W)
model.set_tensor_datatype("thresh", tdt)
T = np.zeros((1, 1))
model.set_initializer("thresh", T)
model = model.transform(CodeGen())
model = model.transform(Compilation())
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment