From 508dd37a49d691e2851358b2e7f82d19152b1c34 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <yamanu@xilinx.com>
Date: Tue, 24 May 2022 15:56:40 +0200
Subject: [PATCH] [Threshold] bugfixes in standalone Thresholding op

- wrong number of reps causes cppsim bugs and reading from empty stream
- adjusted cppsim I/O wrappers to reflect changes in HLSCustomOp
---
 src/finn/custom_op/fpgadataflow/thresholding_batch.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/src/finn/custom_op/fpgadataflow/thresholding_batch.py b/src/finn/custom_op/fpgadataflow/thresholding_batch.py
index 707289d39..3acfc7d8b 100644
--- a/src/finn/custom_op/fpgadataflow/thresholding_batch.py
+++ b/src/finn/custom_op/fpgadataflow/thresholding_batch.py
@@ -545,12 +545,10 @@ class Thresholding_Batch(HLSCustomOp):
                 out = context[node.output[0]]
                 out = 2 * out - 1
                 context[node.output[0]] = out
+            oshape = self.get_normal_output_shape()
             assert (
-                context[node.output[0]].shape == self.get_folded_output_shape()
+                context[node.output[0]].shape == oshape
             ), """Output shape is not as expected"""
-            # reshape output to have expected shape
-            oshape = self.get_normal_output_shape()
-            context[node.output[0]] = context[node.output[0]].reshape(*oshape)
         elif mode == "rtlsim":
             sim = self.get_rtlsim()
             nbits = self.get_instream_width()
@@ -691,9 +689,12 @@ class Thresholding_Batch(HLSCustomOp):
                 )
             ]
         elif mem_mode == "decoupled":
+            # note that numReps is set to 1 in the invocation below, since
+            # - for cppsim the repetition comes from the threshold stream reader+input
+            # - for synth the unit runs continuously anyway (ap_ctrl_none)
             self.code_gen_dict["$DOCOMPUTE$"] = [
                 """{}<{}, NumChannels1, PE1, {}, {}, ActVal1, ThresType1, NumSteps1>
-                (in0, out, weights, numReps);""".format(
+                (in0, out, weights, 1);""".format(
                     "Thresholding_Stream_Batch",
                     total_spatial_size,
                     tmpl_args["TSrcI"],
-- 
GitLab