diff --git a/src/finn/backend/fpgadataflow/code_gen.py b/src/finn/backend/fpgadataflow/code_gen.py
index f0ecd63a6d655e59760f732c327c64c283db81b7..5e3927c013dd3ca23154e58020015e3eab0105ee 100644
--- a/src/finn/backend/fpgadataflow/code_gen.py
+++ b/src/finn/backend/fpgadataflow/code_gen.py
@@ -53,43 +53,27 @@ def strm_prgm(model, code_gen_dict):
             )
 
 
-def computation_cmds(all_strmfcl, code_gen_dict):
+def computation_cmds(model, all_strmfcl, code_gen_dict):
     code_gen_dict["compute"] = []
     for i in range(len(all_strmfcl)):
-        if i == (len(all_strmfcl) - 1):
-            code_gen_dict["compute"].append(
-                "{}<L{}_MW, L{}_MH, L{}_SIMD, L{}_PE, {}> "
-                "({}, {}, {}, {}, numReps, {});".format(
-                    all_strmfcl[i].op_type,
-                    i,
-                    i,
-                    i,
-                    i,
-                    all_strmfcl[i].resDataType,
-                    all_strmfcl[i].input,
-                    all_strmfcl[i].output,
-                    all_strmfcl[i].weights,
-                    all_strmfcl[i].thresholds,
-                    all_strmfcl[i].resType,
-                )
-            )
-        else:
-            code_gen_dict["compute"].append(
-                "{}<L{}_MW, L{}_MH, L{}_SIMD, L{}_PE, {}> "
-                "({}, {}, {}, {}, numReps, {});".format(
-                    all_strmfcl[i].op_type,
-                    i,
-                    i,
-                    i,
-                    i,
-                    all_strmfcl[i].resDataType,
-                    all_strmfcl[i].input,
-                    all_strmfcl[i + 1].input,
-                    all_strmfcl[i].weights,
-                    all_strmfcl[i].thresholds,
-                    all_strmfcl[i].resType,
-                )
+        consumer = model.find_consumer(all_strmfcl[i].output)
+        output_name = consumer.output[0]
+        code_gen_dict["compute"].append(
+            "{}<L{}_MW, L{}_MH, L{}_SIMD, L{}_PE, {}> "
+            "({}, {}, {}, {}, numReps, {});".format(
+                all_strmfcl[i].op_type,
+                i,
+                i,
+                i,
+                i,
+                all_strmfcl[i].resDataType,
+                all_strmfcl[i].input,
+                output_name,
+                all_strmfcl[i].weights,
+                all_strmfcl[i].thresholds,
+                all_strmfcl[i].resType,
             )
+        )
 
 
 def config_cmds(model, code_gen_dict):
@@ -157,7 +141,7 @@ def code_generation(model):
     strm_prgm(model, code_gen_dict)
 
     # computation commands
-    computation_cmds(all_strmfcl, code_gen_dict)
+    computation_cmds(model, all_strmfcl, code_gen_dict)
 
     # print(code_gen_dict)