diff --git a/src/finn/custom_op/fpgadataflow/thresholding_batch.py b/src/finn/custom_op/fpgadataflow/thresholding_batch.py index 3bcc5c05cf5c40c3b1e1e73b7aab4f607ee1f04c..72ee2f7af68d93e81aad419eb479de872330cb76 100644 --- a/src/finn/custom_op/fpgadataflow/thresholding_batch.py +++ b/src/finn/custom_op/fpgadataflow/thresholding_batch.py @@ -211,6 +211,8 @@ class Thresholding_Batch(HLSCustomOp): threshold_tensor ).all(), "Thresholds can't be expressed with type %s" % str(tdt) self.set_nodeattr("weightDataType", tdt.name) + # Update QONNX DataType of tensor for consistency + model.set_tensor_datatype(self.onnx_node.input[1], tdt) return DataType[self.get_nodeattr("weightDataType")] def get_instream_width(self, ind=0):