From 03265c8c3ee6802d53592ec2c3eec63e975e305d Mon Sep 17 00:00:00 2001
From: auphelia <jakobapk@web.de>
Date: Mon, 11 Nov 2019 06:23:26 -0500
Subject: [PATCH] [Code generation] Changed way of reading the output name of
 the fc layers needed for c++ hls code

---
 src/finn/backend/fpgadataflow/code_gen.py | 54 ++++++++---------------
 1 file changed, 19 insertions(+), 35 deletions(-)

diff --git a/src/finn/backend/fpgadataflow/code_gen.py b/src/finn/backend/fpgadataflow/code_gen.py
index f0ecd63a6..5e3927c01 100644
--- a/src/finn/backend/fpgadataflow/code_gen.py
+++ b/src/finn/backend/fpgadataflow/code_gen.py
@@ -53,43 +53,27 @@ def strm_prgm(model, code_gen_dict):
             )
 
 
-def computation_cmds(all_strmfcl, code_gen_dict):
+def computation_cmds(model, all_strmfcl, code_gen_dict):
     code_gen_dict["compute"] = []
     for i in range(len(all_strmfcl)):
-        if i == (len(all_strmfcl) - 1):
-            code_gen_dict["compute"].append(
-                "{}<L{}_MW, L{}_MH, L{}_SIMD, L{}_PE, {}> "
-                "({}, {}, {}, {}, numReps, {});".format(
-                    all_strmfcl[i].op_type,
-                    i,
-                    i,
-                    i,
-                    i,
-                    all_strmfcl[i].resDataType,
-                    all_strmfcl[i].input,
-                    all_strmfcl[i].output,
-                    all_strmfcl[i].weights,
-                    all_strmfcl[i].thresholds,
-                    all_strmfcl[i].resType,
-                )
-            )
-        else:
-            code_gen_dict["compute"].append(
-                "{}<L{}_MW, L{}_MH, L{}_SIMD, L{}_PE, {}> "
-                "({}, {}, {}, {}, numReps, {});".format(
-                    all_strmfcl[i].op_type,
-                    i,
-                    i,
-                    i,
-                    i,
-                    all_strmfcl[i].resDataType,
-                    all_strmfcl[i].input,
-                    all_strmfcl[i + 1].input,
-                    all_strmfcl[i].weights,
-                    all_strmfcl[i].thresholds,
-                    all_strmfcl[i].resType,
-                )
+        consumer = model.find_consumer(all_strmfcl[i].output)
+        output_name = consumer.output[0]
+        code_gen_dict["compute"].append(
+            "{}<L{}_MW, L{}_MH, L{}_SIMD, L{}_PE, {}> "
+            "({}, {}, {}, {}, numReps, {});".format(
+                all_strmfcl[i].op_type,
+                i,
+                i,
+                i,
+                i,
+                all_strmfcl[i].resDataType,
+                all_strmfcl[i].input,
+                output_name,
+                all_strmfcl[i].weights,
+                all_strmfcl[i].thresholds,
+                all_strmfcl[i].resType,
             )
+        )
 
 
 def config_cmds(model, code_gen_dict):
@@ -157,7 +141,7 @@ def code_generation(model):
     strm_prgm(model, code_gen_dict)
 
     # computation commands
-    computation_cmds(all_strmfcl, code_gen_dict)
+    computation_cmds(model, all_strmfcl, code_gen_dict)
 
     # print(code_gen_dict)
 
-- 
GitLab