Skip to content
Snippets Groups Projects
Commit 03265c8c authored by auphelia's avatar auphelia
Browse files

[Code generation] Changed way of reading the output name of the fc layers needed for c++ hls code

parent dde0bead
No related branches found
No related tags found
No related merge requests found
...@@ -53,43 +53,27 @@ def strm_prgm(model, code_gen_dict): ...@@ -53,43 +53,27 @@ def strm_prgm(model, code_gen_dict):
) )
def computation_cmds(all_strmfcl, code_gen_dict): def computation_cmds(model, all_strmfcl, code_gen_dict):
code_gen_dict["compute"] = [] code_gen_dict["compute"] = []
for i in range(len(all_strmfcl)): for i in range(len(all_strmfcl)):
if i == (len(all_strmfcl) - 1): consumer = model.find_consumer(all_strmfcl[i].output)
code_gen_dict["compute"].append( output_name = consumer.output[0]
"{}<L{}_MW, L{}_MH, L{}_SIMD, L{}_PE, {}> " code_gen_dict["compute"].append(
"({}, {}, {}, {}, numReps, {});".format( "{}<L{}_MW, L{}_MH, L{}_SIMD, L{}_PE, {}> "
all_strmfcl[i].op_type, "({}, {}, {}, {}, numReps, {});".format(
i, all_strmfcl[i].op_type,
i, i,
i, i,
i, i,
all_strmfcl[i].resDataType, i,
all_strmfcl[i].input, all_strmfcl[i].resDataType,
all_strmfcl[i].output, all_strmfcl[i].input,
all_strmfcl[i].weights, output_name,
all_strmfcl[i].thresholds, all_strmfcl[i].weights,
all_strmfcl[i].resType, all_strmfcl[i].thresholds,
) all_strmfcl[i].resType,
)
else:
code_gen_dict["compute"].append(
"{}<L{}_MW, L{}_MH, L{}_SIMD, L{}_PE, {}> "
"({}, {}, {}, {}, numReps, {});".format(
all_strmfcl[i].op_type,
i,
i,
i,
i,
all_strmfcl[i].resDataType,
all_strmfcl[i].input,
all_strmfcl[i + 1].input,
all_strmfcl[i].weights,
all_strmfcl[i].thresholds,
all_strmfcl[i].resType,
)
) )
)
def config_cmds(model, code_gen_dict): def config_cmds(model, code_gen_dict):
...@@ -157,7 +141,7 @@ def code_generation(model): ...@@ -157,7 +141,7 @@ def code_generation(model):
strm_prgm(model, code_gen_dict) strm_prgm(model, code_gen_dict)
# computation commands # computation commands
computation_cmds(all_strmfcl, code_gen_dict) computation_cmds(model, all_strmfcl, code_gen_dict)
# print(code_gen_dict) # print(code_gen_dict)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment