diff --git a/src/finn/custom_op/fpgadataflow/pool_batch.py b/src/finn/custom_op/fpgadataflow/pool_batch.py
index 1aedf7056b1381a8904ac63b3a64f86b1243cb36..c7edc24d0e24eef1154293caca2519ab3aa68358 100644
--- a/src/finn/custom_op/fpgadataflow/pool_batch.py
+++ b/src/finn/custom_op/fpgadataflow/pool_batch.py
@@ -45,7 +45,7 @@ class Pool_Batch(HLSCustomOp):
 
     # note: the actual data layout produced by the hlslib kernels is different
     # for depthwise ops.
-    # * depthwise SWG: (1, OFMDim, OFMDim, IFMChannels/SIMD, K, K, SIMD)
+    # * depthwise SWG: (1, OFMDim, OFMDim, IFMChannels/PE, K, K, PE)
 
     Channels can be folded using PE (SIMD from the input perspective)
     TODO: doc
@@ -159,7 +159,7 @@ class Pool_Batch(HLSCustomOp):
     def infer_node_datatype(self, model):
         node = self.onnx_node
         # data type stays the same
-        dtype = model.get_tensor_datatype(node.input[0])
+        dtype = self.get_output_datatype()
         model.set_tensor_datatype(node.output[0], dtype)
 
     def verify_node(self):
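
The updated comment in the first hunk describes a channel-folded layout. Below is a minimal
sketch of the shape arithmetic it implies; it is not part of the patch, and the variable
names are illustrative assumptions rather than FINN identifiers.

import numpy as np

ifm_ch = 8   # IFMChannels
pe = 4       # channels handled in parallel (PE)
k = 3        # pooling kernel size K
ofm_dim = 5  # OFMDim

# layout produced by the depthwise sliding-window generator, per the comment:
# (1, OFMDim, OFMDim, IFMChannels/PE, K, K, PE)
folded = np.zeros((1, ofm_dim, ofm_dim, ifm_ch // pe, k, k, pe))

# flattened view of the same data, with K*K*IFMChannels innermost
flat = folded.reshape(1, ofm_dim, ofm_dim, -1)
print(flat.shape)  # (1, 5, 5, 72)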