/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#include "paddle/operators/dropout_op.h"

namespace paddle {
namespace operators {

using framework::Tensor;
using framework::LoDTensor;

class DropoutOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(const framework::InferShapeContext &ctx) const override {
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
    PADDLE_ENFORCE_GE(ctx.Attr<float>("dropout_prob"), 0,
                      "'dropout_prob' must be between 0 and 1.");
    PADDLE_ENFORCE_LE(ctx.Attr<float>("dropout_prob"), 1,
                      "'dropout_prob' must be between 0 and 1.");
    // TODO(xinghai-sun): remove this check after switching to bool.
    PADDLE_ENFORCE(ctx.Attr<int>("is_training") == 0 ||
                       ctx.Attr<int>("is_training") == 1,
                   "'is_training' must be either 0 or 1.");

    auto dims = ctx.Input<Tensor>("X")->dims();
    ctx.Output<LoDTensor>("Out")->Resize(dims);
    // The mask is only produced (and needed) during training.
    if (ctx.Attr<int>("is_training") == 1) {
      ctx.Output<LoDTensor>("Mask")->Resize(dims);
    }
  }
};

template <typename AttrType>
class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  DropoutOpMaker(framework::OpProto *proto,
                 framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddAttr<AttrType>("dropout_prob", "Probability of setting units to zero.")
        .SetDefault(.5f);
    // TODO(xinghai-sun): use bool for is_training after bool is supported.
    AddAttr<int>("is_training", "Whether in training phase.").SetDefault(1);
    AddAttr<int>("seed", "Dropout random seed.").SetDefault(0);
    AddInput("X", "The input of dropout op.");
    AddOutput("Out", "The output of dropout op.");
    AddOutput("Mask", "The random sampled dropout mask.").AsIntermediate();

    AddComment(R"DOC(
Dropout Operator.

"Dropout" refers to randomly dropping out units in a neural network. It is a
regularization technique for reducing overfitting by preventing neuron
co-adaptation during training. The dropout operator randomly sets (according
to the given dropout probability) the outputs of some units to zero, while
the others are set equal to their inputs.
)DOC");
  }
};

template <typename AttrType>
class DropoutOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(const framework::InferShapeContext &ctx) const override {
    PADDLE_ENFORCE_EQ(ctx.Attr<int>("is_training"), 1,
                      "GradOp is only callable when is_training is true.");

    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Mask"), "Mask must not be null.");
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
                            "Input(Out@GRAD) must not be null.");

    PADDLE_ENFORCE_GE(ctx.Attr<AttrType>("dropout_prob"), 0,
                      "'dropout_prob' must be between 0 and 1.");
    PADDLE_ENFORCE_LE(ctx.Attr<AttrType>("dropout_prob"), 1,
                      "'dropout_prob' must be between 0 and 1.");
    // TODO(xinghai-sun): remove this check after switching to bool.
    PADDLE_ENFORCE(ctx.Attr<int>("is_training") == 0 ||
                       ctx.Attr<int>("is_training") == 1,
                   "'is_training' must be either 0 or 1.");
    auto x_dims = ctx.Input<Tensor>("X")->dims();
    auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
    PADDLE_ENFORCE_EQ(x_dims, out_dims,
                      "Dimensions of Input(X) and Out@GRAD must be the same.");
    auto mask_dims = ctx.Input<Tensor>("Mask")->dims();
    PADDLE_ENFORCE_EQ(x_dims, mask_dims,
                      "Dimensions of Input(X) and Mask must be the same.");

    auto *x_grad = ctx.Output<LoDTensor>(framework::GradVarName("X"));
    x_grad->Resize(x_dims);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP(dropout, ops::DropoutOp, ops::DropoutOpMaker<float>, dropout_grad,
            ops::DropoutOpGrad<float>);
REGISTER_OP_CPU_KERNEL(
    dropout, ops::CPUDropoutKernel<paddle::platform::CPUPlace, float, float>);
REGISTER_OP_CPU_KERNEL(
    dropout_grad, ops::DropoutGradKernel<paddle::platform::CPUPlace, float>);
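
The registered kernels themselves (CPUDropoutKernel and DropoutGradKernel) are defined in dropout_op.h and are not part of this file. As a rough illustration of the semantics described in the DOC string above, the training-time computation amounts to the following minimal standalone sketch; the DropoutForward/DropoutBackward functions and their std::vector signatures are illustrative stand-ins, not the operator's actual API.

#include <cstddef>
#include <cstdint>
#include <random>
#include <vector>

// Forward (training): sample a Bernoulli mask that keeps each unit with
// probability (1 - dropout_prob) and zeroes the dropped units. The mask is
// saved so the backward pass can reuse the same drop pattern.
void DropoutForward(const std::vector<float> &x, float dropout_prob,
                    std::uint64_t seed, std::vector<float> *out,
                    std::vector<float> *mask) {
  std::mt19937_64 rng(seed);
  std::uniform_real_distribution<float> uniform(0.0f, 1.0f);
  out->resize(x.size());
  mask->resize(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    (*mask)[i] = uniform(rng) < dropout_prob ? 0.0f : 1.0f;
    (*out)[i] = x[i] * (*mask)[i];
  }
}

// Backward: gradients flow only through the kept units,
// i.e. X@GRAD = Out@GRAD * Mask, elementwise.
void DropoutBackward(const std::vector<float> &dout,
                     const std::vector<float> &mask,
                     std::vector<float> *dx) {
  dx->resize(dout.size());
  for (std::size_t i = 0; i < dout.size(); ++i) {
    (*dx)[i] = dout[i] * mask[i];
  }
}

int main() {
  std::vector<float> x = {1.0f, 2.0f, 3.0f, 4.0f}, out, mask, dx;
  DropoutForward(x, /*dropout_prob=*/0.5f, /*seed=*/0, &out, &mask);
  DropoutBackward(/*dout=*/{1.0f, 1.0f, 1.0f, 1.0f}, mask, &dx);
  return 0;
}

Note that kept units pass through unscaled, matching the DOC string above; at inference time (is_training == 0) such an operator typically compensates by scaling the output by (1 - dropout_prob) instead of sampling a mask.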