diff --git a/src/finn/backend/fpgadataflow/code_gen.py b/src/finn/backend/fpgadataflow/code_gen.py index f0ecd63a6d655e59760f732c327c64c283db81b7..5e3927c013dd3ca23154e58020015e3eab0105ee 100644 --- a/src/finn/backend/fpgadataflow/code_gen.py +++ b/src/finn/backend/fpgadataflow/code_gen.py @@ -53,43 +53,27 @@ def strm_prgm(model, code_gen_dict): ) -def computation_cmds(all_strmfcl, code_gen_dict): +def computation_cmds(model, all_strmfcl, code_gen_dict): code_gen_dict["compute"] = [] for i in range(len(all_strmfcl)): - if i == (len(all_strmfcl) - 1): - code_gen_dict["compute"].append( - "{}<L{}_MW, L{}_MH, L{}_SIMD, L{}_PE, {}> " - "({}, {}, {}, {}, numReps, {});".format( - all_strmfcl[i].op_type, - i, - i, - i, - i, - all_strmfcl[i].resDataType, - all_strmfcl[i].input, - all_strmfcl[i].output, - all_strmfcl[i].weights, - all_strmfcl[i].thresholds, - all_strmfcl[i].resType, - ) - ) - else: - code_gen_dict["compute"].append( - "{}<L{}_MW, L{}_MH, L{}_SIMD, L{}_PE, {}> " - "({}, {}, {}, {}, numReps, {});".format( - all_strmfcl[i].op_type, - i, - i, - i, - i, - all_strmfcl[i].resDataType, - all_strmfcl[i].input, - all_strmfcl[i + 1].input, - all_strmfcl[i].weights, - all_strmfcl[i].thresholds, - all_strmfcl[i].resType, - ) + consumer = model.find_consumer(all_strmfcl[i].output) + output_name = consumer.output[0] + code_gen_dict["compute"].append( + "{}<L{}_MW, L{}_MH, L{}_SIMD, L{}_PE, {}> " + "({}, {}, {}, {}, numReps, {});".format( + all_strmfcl[i].op_type, + i, + i, + i, + i, + all_strmfcl[i].resDataType, + all_strmfcl[i].input, + output_name, + all_strmfcl[i].weights, + all_strmfcl[i].thresholds, + all_strmfcl[i].resType, ) + ) def config_cmds(model, code_gen_dict): @@ -157,7 +141,7 @@ def code_generation(model): strm_prgm(model, code_gen_dict) # computation commands - computation_cmds(all_strmfcl, code_gen_dict) + computation_cmds(model, all_strmfcl, code_gen_dict) # print(code_gen_dict)