diff --git a/src/finn/custom_op/fpgadataflow/thresholding_batch.py b/src/finn/custom_op/fpgadataflow/thresholding_batch.py index 707289d393e2486780aed2c4af336dd3bafd37a6..3acfc7d8b004733131ee997f69aa4ac2aac88577 100644 --- a/src/finn/custom_op/fpgadataflow/thresholding_batch.py +++ b/src/finn/custom_op/fpgadataflow/thresholding_batch.py @@ -545,12 +545,10 @@ class Thresholding_Batch(HLSCustomOp): out = context[node.output[0]] out = 2 * out - 1 context[node.output[0]] = out + oshape = self.get_normal_output_shape() assert ( - context[node.output[0]].shape == self.get_folded_output_shape() + context[node.output[0]].shape == oshape ), """Output shape is not as expected""" - # reshape output to have expected shape - oshape = self.get_normal_output_shape() - context[node.output[0]] = context[node.output[0]].reshape(*oshape) elif mode == "rtlsim": sim = self.get_rtlsim() nbits = self.get_instream_width() @@ -691,9 +689,12 @@ class Thresholding_Batch(HLSCustomOp): ) ] elif mem_mode == "decoupled": + # note that numReps is set to 1 in the invocation below, since + # - for cppsim the repetition comes from the threshold stream reader+input + # - for synth the unit runs continuously anyway (ap_ctrl_none) self.code_gen_dict["$DOCOMPUTE$"] = [ """{}<{}, NumChannels1, PE1, {}, {}, ActVal1, ThresType1, NumSteps1> - (in0, out, weights, numReps);""".format( + (in0, out, weights, 1);""".format( "Thresholding_Stream_Batch", total_spatial_size, tmpl_args["TSrcI"],