Skip to content

Conversation

@AnnaTrainingG
Copy link
Contributor

PR types

Others

PR changes

Others

Describe

Support MaskedSelectGrad op with Kernel Primitive API

  1. 添加单线程ReadData API
  2. 添加SelectCaller 以支持masked_select_grad,通过MaskData模板参数区分MaskData = 0 where_index; MaskData = 1 masked_select; MaskData = 2 masked_select_grad
@paddle-bot-old
Copy link

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

int VecSize,
int IsBoundary,
int IsMaskData>
int MaskData>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

建议在下个PR注释里说明下 maskdata =0,1,2 对应的情况

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的

}
}

template <typename T>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个和之前的 readdata 可以复用吗 ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不可以 这个是线程级别的API

#include "paddle/phi/kernels/funcs/select_impl.cu.h"
#include "paddle/phi/kernels/masked_select_grad_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个头文件应该没用到,可以删掉试试

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的下个PR再修改


SelectGradWithPrefixMask<T><<<grid, threads, 0, stream>>>(
mask_prefix_sum_data, mask_data, input_data, out_data, mask_size);
auto* out_data = x_grad->mutable_data<T>(dev_ctx.GetPlace());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

kernel里分配内存调用新接口:dev_ctx.template Alloc(x_grad)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的 下个PR再修改

@AnnaTrainingG AnnaTrainingG merged commit f53db25 into PaddlePaddle:develop Mar 21, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

4 participants