Skip to content
Snippets Groups Projects
Commit 915ec0d4 authored by auphelia's avatar auphelia
Browse files

[Code generation] Added some of the missing layer parameters (WMEM, TMEM and...

[Code generation] Added some of the missing layer parameters (WMEM, TMEM and API) to get_layer_parameters()
parent d757fb53
No related branches found
No related tags found
No related merge requests found
def get_layer_attributes(node):
def get_layer_parameters(model, node):
# Layer attributes
num_attr = len(node.attribute)
for k in range(num_attr):
......@@ -14,6 +14,14 @@ def get_layer_attributes(node):
L_resDataType = node.attribute[k].s
if node.attribute[k].name == "resType":
L_resType = node.attribute[k].s
# get other parameters
weights_shape = model.get_tensor_shape(node.input[1])
thresholds_shape = model.get_tensor_shape(node.input[2])
L_WMEM = weights_shape[2]
L_TMEM = thresholds_shape[0]
L_API = thresholds_shape[2]
return [
L_PE,
L_SIMD,
......@@ -21,6 +29,9 @@ def get_layer_attributes(node):
L_MW,
L_resDataType.decode("utf-8"),
L_resType.decode("utf-8"),
L_WMEM,
L_TMEM,
L_API,
]
......@@ -87,8 +98,18 @@ def computation_cmds(model, code_gen_dict):
weights = node.input[1]
thresholds = node.input[2]
outp = node.output[0]
# get layer attributes
[PE, SIMD, MH, MW, resDataType, resType] = get_layer_attributes(node)
# get layer parameters
[
PE,
SIMD,
MH,
MW,
resDataType,
resType,
WMEM,
TMEM,
API,
] = get_layer_parameters(model, node)
code_gen_dict["compute"].append(
"{}<L{}_MW, L{}_MH, L{}_SIMD, L{}_PE, {}> "
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment