diff --git a/src/finn/custom_op/fpgadataflow/pool_batch.py b/src/finn/custom_op/fpgadataflow/pool_batch.py
index 1aedf7056b1381a8904ac63b3a64f86b1243cb36..c7edc24d0e24eef1154293caca2519ab3aa68358 100644
--- a/src/finn/custom_op/fpgadataflow/pool_batch.py
+++ b/src/finn/custom_op/fpgadataflow/pool_batch.py
@@ -45,7 +45,7 @@ class Pool_Batch(HLSCustomOp):
 
     # note: the actual data layout produced by the hlslib kernels is different
     # for depthwise ops.
-    # * depthwise SWG: (1, OFMDim, OFMDim, IFMChannels/SIMD, K, K, SIMD)
+    # * depthwise SWG: (1, OFMDim, OFMDim, IFMChannels/PE, K, K, PE)
 
     Channels can be folded using PE (SIMD from the input perspective)
     TODO: doc
@@ -159,7 +159,7 @@ class Pool_Batch(HLSCustomOp):
     def infer_node_datatype(self, model):
         node = self.onnx_node
-        # data type stays the same
-        dtype = model.get_tensor_datatype(node.input[0])
+        # output data type is determined by the op itself, not copied from the input
+        dtype = self.get_output_datatype()
         model.set_tensor_datatype(node.output[0], dtype)
 
     def verify_node(self):
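
Illustrative sketch (not part of the patch): the docstring comment in the first hunk describes the folded data layout produced by the depthwise sliding-window generator, (1, OFMDim, OFMDim, IFMChannels/PE, K, K, PE). The helper below, with hypothetical names chosen only for this example, shows how that folded shape follows from the unfolded parameters, assuming IFMChannels is divisible by PE.

# Hypothetical helper, for illustration only; not part of pool_batch.py.
# Derives the folded depthwise SWG output layout
# (1, OFMDim, OFMDim, IFMChannels/PE, K, K, PE) from the unfolded parameters.
def folded_depthwise_swg_shape(ofm_dim, ifm_channels, k, pe):
    # Channels are folded by PE: PE channels are processed per transaction,
    # so the channel axis splits into (IFMChannels // PE, ..., PE).
    assert ifm_channels % pe == 0, "IFMChannels must be divisible by PE"
    return (1, ofm_dim, ofm_dim, ifm_channels // pe, k, k, pe)

# Example: 4x4 output, 16 channels, 3x3 kernel, PE=4 -> (1, 4, 4, 4, 3, 3, 4)
print(folded_depthwise_swg_shape(4, 16, 3, 4))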