diff --git a/src/finn/custom_op/fpgadataflow/thresholding_batch.py b/src/finn/custom_op/fpgadataflow/thresholding_batch.py
index 03cfbe0e09be092cf9446f29110b01ef0cc1e367..d9745acf63c4685b3369ac379abde0a6c5a3f157 100644
--- a/src/finn/custom_op/fpgadataflow/thresholding_batch.py
+++ b/src/finn/custom_op/fpgadataflow/thresholding_batch.py
@@ -602,17 +602,17 @@ class Thresholding_Batch(HLSCustomOp):
 
     # TODO check and add whatever missing
     def defines(self, var):
-        if self.get_nodeattr("mem_mode") == "const":
-            numReps = 1
-        else:
-            numInputVectors = list(self.get_nodeattr("numInputVectors"))
-            numReps = int(np.prod(numInputVectors))
+        numReps = 1  # whole input is treated as one image of ImgDim1 pixels
+        numInputVectors = list(self.get_nodeattr("numInputVectors"))
+        total_spatial_size = int(np.prod(numInputVectors))
 
         self.code_gen_dict["$DEFINES$"] = [
-            """#define NumChannels1 {}\n #define PE1 {}\n #define numReps {}""".format(
+            """#define NumChannels1 {}\n #define PE1 {}\n #define numReps {}\n
+               #define ImgDim1 {}""".format(
                 self.get_nodeattr("NumChannels"),
                 self.get_nodeattr("PE"),
                 numReps,
+                total_spatial_size,
             )
         ]
         if self.get_nodeattr("mem_mode") == "decoupled":
@@ -653,7 +653,7 @@ class Thresholding_Batch(HLSCustomOp):
             npy_in = "%s/thresholds.npy" % code_gen_dir
 
             self.code_gen_dict["$READNPYDATA$"].append(
-                'npy2apintstream<%s, %s, %d, %s>("%s", weights, false, numReps);'
+                'npy2apintstream<%s, %s, %d, %s>("%s", weights, false, ImgDim1);'
                 % (packed_hls_type, elem_hls_type, elem_bits, npy_type, npy_in)
             )
 
@@ -675,18 +675,13 @@ class Thresholding_Batch(HLSCustomOp):
 
     def docompute(self):
         tmpl_args = self.get_template_param_values()
-        # TODO: why put some template parameters into defines and not others?
-        # should ImgDim be defined or just filled in here like we do now?
         node = self.onnx_node
-        inp_vecs = self.get_nodeattr("numInputVectors")
-        total_spatial_size = int(np.prod(inp_vecs))
         mem_mode = self.get_nodeattr("mem_mode")
         if mem_mode == "const":
             self.code_gen_dict["$DOCOMPUTE$"] = [
-                """{}<{}, NumChannels1, PE1, {}, {}>
+                """{}<ImgDim1, NumChannels1, PE1, {}, {}>
                 (in0, out, threshs, numReps);""".format(
                     node.op_type,
-                    total_spatial_size,
                     tmpl_args["TSrcI"],
                     tmpl_args["TDstI"],
                 )
@@ -696,10 +691,9 @@ class Thresholding_Batch(HLSCustomOp):
             # - for cppsim the repetition comes from the threshold stream reader+input
             # - for synth the unit runs continuously anyway (ap_ctrl_none)
             self.code_gen_dict["$DOCOMPUTE$"] = [
-                """{}<{}, NumChannels1, PE1, {}, {}, ActVal1, ThresType1, NumSteps1>
-                (in0, out, weights, 1);""".format(
+                """{}<ImgDim1, NumChannels1, PE1, {}, {}, ActVal1, ThresType1, NumSteps1>
+                (in0, out, weights, numReps);""".format(
                     "Thresholding_Stream_Batch",
-                    total_spatial_size,
                     tmpl_args["TSrcI"],
                     tmpl_args["TDstI"],
                 )
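
For reference only, not part of the patch: a minimal sketch, assuming hypothetical
attribute values NumChannels=64, PE=8 and numInputVectors=[1, 32, 32], of the C
defines the reworked defines() would emit. The variable names below are illustrative
and only mirror the logic of the first hunk above.

    import numpy as np

    # hypothetical node attributes, for illustration only
    num_channels = 64
    pe = 8
    num_input_vectors = [1, 32, 32]

    num_reps = 1  # repetition count stays fixed at 1
    total_spatial_size = int(np.prod(num_input_vectors))  # 1 * 32 * 32 = 1024

    # same placeholders as the template string in defines(), written on one line
    defines = "#define NumChannels1 {}\n#define PE1 {}\n#define numReps {}\n#define ImgDim1 {}".format(
        num_channels, pe, num_reps, total_spatial_size
    )
    print(defines)
    # #define NumChannels1 64
    # #define PE1 8
    # #define numReps 1
    # #define ImgDim1 1024

With ImgDim1 emitted as a define, both docompute() branches can name it directly as the
first template argument, the $READNPYDATA$ stream reader can size the threshold reads
with it, and numReps stays at 1 because the whole spatial extent is already folded into
ImgDim1.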