Skip to content

Conversation

@phlrain
Copy link
Collaborator

@phlrain phlrain commented Mar 3, 2022

PR types

Breaking changes

PR changes

OPs

Describe

move dropout to phi

@paddle-bot-old
Copy link

paddle-bot-old bot commented Mar 3, 2022

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

namespace phi {

template <typename T, typename Context>
void DropoutGradRawKernel(const Context& dev_ctx,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

grad是不是不需要raw kernel,就一个

template <typename T, typename Context>
void DropoutRawKernel(const Context& dev_ctx,
const DenseTensor& x,
paddle::optional<const DenseTensor&> seed_tensor,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seed tensor的下面的seed应该用Scalar统一表示?而不是写两个参数?

bool is_test,
const std::string& mode,
DenseTensor* x_grad) {
x_grad->mutable_data<T>(dev_ctx.GetPlace());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mutable_data -> Alloc

bool fix_seed,
DenseTensor* out,
DenseTensor* mask) {
out->mutable_data<T>(dev_ctx.GetPlace());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

KernelSignature DropoutOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"dropout",
{"X", "Seed"},
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里需要有if分支选择seed

DenseTensor* x_grad) {
auto* grad_x = x_grad;
auto* grad_y = &out_grad;
grad_x->mutable_data<T>(dev_ctx.GetPlace());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dev_ctx.Alloc


#pragma once

#include "paddle/phi/common/scalar.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

scalar没有用到

Copy link
Contributor

@chenwhql chenwhql left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM overall,细节问题后续PR再完善一下

Copy link
Contributor

@XiaoguangHu01 XiaoguangHu01 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@phlrain phlrain merged commit 99fc1b0 into PaddlePaddle:develop Mar 10, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

4 participants