Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion paddle/phi/kernels/funcs/fast_divmod.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ struct FastDivMod<int64_t> {
uint64_t q = Div(n);
return {q, n - q * divisor};
}
__device__ __forceinline__ uint64_t DivCeil(uint32_t n) const {
// Ceiling division: returns the smallest q with q * divisor >= n.
// Accepts the full 64-bit range of n so callers with large element
// counts do not silently truncate the dividend.
__device__ __forceinline__ uint64_t DivCeil(uint64_t n) const {
  const DivModT qr = Divmod(n);
  // Round the quotient up exactly when a nonzero remainder exists.
  return qr.val[0] + (qr.val[1] != 0 ? 1 : 0);
}
Expand Down
37 changes: 25 additions & 12 deletions paddle/phi/kernels/funcs/pooling.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ limitations under the License. */

#include <algorithm>
#include <vector>
#include "paddle/phi/kernels/funcs/index_elementwise.cu.h"
#ifdef __NVCC__
#include <curand_kernel.h>
#endif
Expand Down Expand Up @@ -328,15 +329,15 @@ __global__ void KernelPool2DGrad(
phend = tmp_phend.val[1] > 0 ? tmp_phend.val[0] + 1 : tmp_phend.val[0];
pwend = tmp_pwend.val[1] > 0 ? tmp_pwend.val[0] + 1 : tmp_pwend.val[0];

IndexT tmp_height, tmp_width;
IndexT pool_height, pool_width;
for (IndexT ph = phstart; ph < phend; ++ph) {
PreparationPoolSize(
ph, input_height, output_height, divmods.ksize_h, &tmp_height);
ph, input_height, output_height, divmods.ksize_h, &pool_height);

for (IndexT pw = pwstart; pw < pwend; ++pw) {
PreparationPoolSize(
pw, input_width, output_width, divmods.ksize_w, &tmp_width);
IndexT pool_size = tmp_height * tmp_width;
pw, input_width, output_width, divmods.ksize_w, &pool_width);
IndexT pool_size = pool_height * pool_width;
IndexT tmp_idx = ph * output_width + pw;
IndexT output_sub_idx =
channel_last ? tmp_idx * divmods.channel.divisor + c_offset
Expand Down Expand Up @@ -1178,6 +1179,9 @@ __global__ void KernelPool3DGrad(const IndexT nthreads,
IndexT pdstart, pdend;
IndexT phstart, phend;
IndexT pwstart, pwend;

IndexT pool_depth, pool_height, pool_width;

if (adaptive) {
pdstart = AdaptStartIndex(d_offset, output_depth, input_depth);
pdend = AdaptEndIndex(d_offset, output_depth, input_depth);
Expand Down Expand Up @@ -1208,19 +1212,28 @@ __global__ void KernelPool3DGrad(const IndexT nthreads,
output_grad += output_stride;
T input_grad_data = static_cast<T>(0.0);

IndexT pool_size;
for (IndexT pd = pdstart; pd < pdend; ++pd) {
for (IndexT ph = phstart; ph < phend; ++ph) {
for (IndexT pw = pwstart; pw < pwend; ++pw) {
// figure out the pooling size
IndexT pool_size;
if (adaptive) {
pool_size =
static_cast<IndexT>(
ceil(static_cast<double>(input_depth) / ksize_depth)) *
static_cast<IndexT>(
ceil(static_cast<double>(input_height) / ksize_height)) *
static_cast<IndexT>(
ceil(static_cast<double>(input_width) / ksize_width));
PreparationPoolSize(pd,
input_depth,
output_depth,
FastDivMod<IndexT>(output_depth),
&pool_depth);
PreparationPoolSize(pw,
input_width,
output_width,
FastDivMod<IndexT>(output_width),
&pool_width);
PreparationPoolSize(ph,
input_height,
output_height,
FastDivMod<IndexT>(output_height),
&pool_height);
pool_size = pool_depth * pool_height * pool_width;
} else {
IndexT dstart = pd * stride_depth - padding_depth;
IndexT hstart = ph * stride_height - padding_height;
Expand Down
22 changes: 22 additions & 0 deletions test/legacy_test/test_adaptive_avg_pool3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,28 @@ def test_dynamic_graph(self):
out_6.numpy(), self.res_3_np, rtol=1e-5, atol=1e-8
)

def test_grad(self):
    # Gradient check for adaptive average pooling: with an all-ones
    # upstream gradient, each output element distributes a total
    # gradient of exactly 1 over its pooling region, so the input
    # gradients must sum to the number of output elements.
    places = [False, True] if core.is_compiled_with_cuda() else [False]
    for use_cuda in places:
        place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
        paddle.disable_static(place=place)
        x = paddle.to_tensor(self.x_np)
        x.stop_gradient = False
        # Cover non-uniform, cubic, and non-divisible output shapes.
        for output_size in ([2, 3, 5], [3, 3, 3], [6, 8, 8]):
            out = paddle.nn.functional.adaptive_avg_pool3d(
                x=x, output_size=output_size
            )
            x_grad = paddle.grad(
                [out],
                [x],
                grad_outputs=paddle.ones_like(out),
                allow_unused=True,
            )
            np.testing.assert_allclose(
                paddle.sum(x_grad[0]), out.numel(), rtol=1e-5
            )


class TestAdaptiveAvgPool3DClassAPI(unittest.TestCase):
def setUp(self):
Expand Down