diff --git a/src/finn/custom_op/fpgadataflow/thresholding_batch.py b/src/finn/custom_op/fpgadataflow/thresholding_batch.py
index 3bcc5c05cf5c40c3b1e1e73b7aab4f607ee1f04c..72ee2f7af68d93e81aad419eb479de872330cb76 100644
--- a/src/finn/custom_op/fpgadataflow/thresholding_batch.py
+++ b/src/finn/custom_op/fpgadataflow/thresholding_batch.py
@@ -211,6 +211,8 @@ class Thresholding_Batch(HLSCustomOp):
             threshold_tensor
         ).all(), "Thresholds can't be expressed with type %s" % str(tdt)
         self.set_nodeattr("weightDataType", tdt.name)
+        # Update QONNX DataType of tensor for consistency
+        model.set_tensor_datatype(self.onnx_node.input[1], tdt)
         return DataType[self.get_nodeattr("weightDataType")]
 
     def get_instream_width(self, ind=0):