84 changes: 84 additions & 0 deletions paddle/operators/scaling_op.cc
@@ -0,0 +1,84 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/scaling_op.h"

namespace paddle {
namespace operators {

class ScalingOp : public framework::OperatorWithKernel {
public:
ScalingOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}

protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto *in = ctx.Input<framework::Tensor>("X");
auto *weight = ctx.Input<framework::Tensor>("weight");
    PADDLE_ENFORCE_EQ(1, weight->dims().size(),
                      "The Input(weight) must be a vector.");
    PADDLE_ENFORCE_EQ(2, in->dims().size(), "The Input(X) must be a matrix.");
    PADDLE_ENFORCE_EQ(in->dims()[0], weight->dims()[0],
                      "The number of rows of Input(X) must be equal to the"
                      " size of Input(weight).");
auto *out = ctx.Output<framework::Tensor>("Out");
out->Resize(in->dims());
}
};

class ScalingOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ScalingOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input tensor of scaling operator.");
AddInput("weight", "The weight vector of scaling operator.");
AddOutput("Out", "The output tensor of scaling operator.");
AddComment(R"DOC(Scaling operator

The equation is: Out.row[i] = weight[i] * X.row[i]
)DOC");
}
};

class ScalingGradOp : public framework::OperatorWithKernel {
Review comment (Collaborator): The gradient operator could be composed of two forward operators. See minus_op.

public:
ScalingGradOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}

protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto in_dims = ctx.Input<framework::Tensor>("X")->dims();
auto weight_dims = ctx.Input<framework::Tensor>("weight")->dims();
ctx.Output<framework::Tensor>(framework::GradVarName("X"))->Resize(in_dims);
ctx.Output<framework::Tensor>(framework::GradVarName("weight"))
->Resize(weight_dims);
}
};

} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;

REGISTER_OP(scaling, ops::ScalingOp, ops::ScalingOpMaker, scaling_grad,
ops::ScalingGradOp);
REGISTER_OP_CPU_KERNEL(scaling,
ops::ScalingKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
scaling_grad, ops::ScalingGradKernel<paddle::platform::CPUPlace, float>);
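A minimal NumPy sketch (an illustration only, not part of this PR) of what the forward pass computes, matching the DOC string Out.row[i] = weight[i] * X.row[i]:

import numpy as np

# Reference forward pass: row i of X is scaled by the scalar weight[i].
# Broadcasting weight as a column vector mirrors the Eigen broadcast
# done in ScalingKernel.
def scaling_forward(x, weight):
    assert x.ndim == 2 and weight.ndim == 1
    assert x.shape[0] == weight.shape[0]
    return x * weight[:, np.newaxis]

x = np.random.random((32, 64)).astype("float32")
w = np.random.random(32).astype("float32")
# Equivalent to the diag-matrix product used in the Python test below.
assert np.allclose(scaling_forward(x, w), np.dot(np.diag(w), x))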
22 changes: 22 additions & 0 deletions paddle/operators/scaling_op.cu
@@ -0,0 +1,22 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/scaling_op.h"

REGISTER_OP_GPU_KERNEL(
scaling,
paddle::operators::ScalingKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
scaling_grad,
paddle::operators::ScalingGradKernel<paddle::platform::GPUPlace, float>);
85 changes: 85 additions & 0 deletions paddle/operators/scaling_op.h
@@ -0,0 +1,85 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename Place, typename T>
class ScalingKernel : public framework::OpKernel {
public:
  void Compute(const framework::ExecutionContext& ctx) const override {
auto* out = ctx.Output<Tensor>("Out");
auto* in = ctx.Input<Tensor>("X");
auto* weight = ctx.Input<Tensor>("weight");
    out->mutable_data<T>(ctx.GetPlace());

auto rows = in->dims()[0];
auto cols = in->dims()[1];

auto col_vec_dims = framework::make_ddim({rows, 1});
auto bd_cast_dims = framework::make_ddim({1, cols});

auto eigen_out = framework::EigenMatrix<T>::From(*out);
auto eigen_in = framework::EigenMatrix<T>::From(*in);
auto eigen_weight = framework::EigenMatrix<T>::From(*weight, col_vec_dims);

auto& dev = ctx.GetEigenDevice<Place>();
eigen_out.device(dev) = eigen_in * eigen_weight.broadcast(bd_cast_dims);
}
};

template <typename Place, typename T>
class ScalingGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* d_out = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* d_in = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* d_weight = ctx.Output<Tensor>(framework::GradVarName("weight"));

auto* in = ctx.Input<Tensor>("X");
auto* weight = ctx.Input<Tensor>("weight");

d_in->mutable_data<T>(ctx.GetPlace());
d_weight->mutable_data<T>(ctx.GetPlace());

auto rows = d_out->dims()[0];
auto cols = d_out->dims()[1];
auto col_vec_dims = framework::make_ddim({rows, 1});
auto bd_cast_dims = framework::make_ddim({1, cols});

auto eigen_in = framework::EigenMatrix<T>::From(*in);
auto eigen_weight = framework::EigenMatrix<T>::From(*weight, col_vec_dims);

auto eigen_d_out = framework::EigenMatrix<T>::From(*d_out);
auto eigen_d_in = framework::EigenMatrix<T>::From(*d_in);
auto eigen_d_weight = framework::EigenVector<T>::From(*d_weight);

auto& dev = ctx.GetEigenDevice<Place>();
    // dX[i][j] = dOut[i][j] * weight[i]: each row of dOut is scaled by weight[i].
eigen_d_in.device(dev) = eigen_d_out * eigen_weight.broadcast(bd_cast_dims);

    // dWeight[i] = sum_j dOut[i][j] * X[i][j]: reduce along dimension 1 (columns).
    Eigen::array<int, 1> dims{{1}};
    eigen_d_weight.device(dev) = (eigen_d_out * eigen_in).sum(dims);
}
};
} // namespace operators
} // namespace paddle
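For reference, the gradient formulas implemented by ScalingGradKernel, written as a NumPy sketch (an illustration only, not part of this PR). Note that dX is itself a scaling of dOut by the same weight, which is what the review comment above suggests: the backward pass could be composed from existing forward operators, as minus_op does.

import numpy as np

# Reference gradients for Out[i][j] = weight[i] * X[i][j]:
#   dX[i][j]   = dOut[i][j] * weight[i]   (a scaling of dOut by weight)
#   dWeight[i] = sum_j dOut[i][j] * X[i][j]
def scaling_backward(x, weight, d_out):
    d_x = d_out * weight[:, np.newaxis]
    d_weight = (d_out * x).sum(axis=1)  # reduce along dim 1, as in the kernel
    return d_x, d_weight

# Quick finite-difference check of dWeight[0] under the loss Out.sum():
x = np.random.random((4, 3))
w = np.random.random(4)
_, dw = scaling_backward(x, w, np.ones_like(x))
eps = 1e-6
w_pert = w.copy()
w_pert[0] += eps
numeric = ((x * w_pert[:, np.newaxis]).sum() - (x * w[:, np.newaxis]).sum()) / eps
assert np.isclose(dw[0], numeric, rtol=1e-4)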
1 change: 1 addition & 0 deletions paddle/pybind/pybind.cc
@@ -44,6 +44,7 @@ USE_OP(gaussian_random);
USE_OP(uniform_random);
USE_OP(lookup_table);
USE_OP(scale);
USE_OP(scaling);
USE_NO_KERNEL_OP(identity);
USE_OP(minus);
USE_OP(cos_sim);
1 change: 1 addition & 0 deletions python/paddle/v2/framework/tests/CMakeLists.txt
@@ -32,4 +32,5 @@ py_test(test_sgd_op SRCS test_sgd_op.py)
py_test(test_gradient_checker SRCS test_gradient_checker.py)
py_test(test_lookup_table SRCS test_lookup_table.py)
py_test(test_scale_and_identity_op SRCS test_scale_and_identity_op.py)
py_test(test_scaling_op SRCS test_scaling_op.py)
py_test(mnist SRCS mnist.py)
36 changes: 36 additions & 0 deletions python/paddle/v2/framework/tests/test_scaling_op.py
@@ -0,0 +1,36 @@
import unittest
from op_test_util import OpTestMeta
from gradient_checker import GradientChecker, create_op
import numpy as np
from paddle.v2.framework.op import Operator


class TestScalingOp(unittest.TestCase):
    __metaclass__ = OpTestMeta

    def setUp(self):
        self.type = "scaling"
        self.inputs = {
            'X': np.random.random((32, 64)).astype("float32"),
            'weight': np.random.random(32).astype("float32")
        }
        self.outputs = {
            'Out': np.dot(np.diag(self.inputs['weight']), self.inputs['X'])
        }


class ScalingGradOp(GradientChecker):
    def test_scaling(self):
        op = create_op("scaling")
        inputs = {
            'X': np.random.random((32, 64)).astype("float32"),
            'weight': np.random.random(32).astype("float32")
        }
        self.check_grad(
            op, inputs, set(['X', "weight"]), "Out", max_relative_error=0.5)


if __name__ == '__main__':
    unittest.main()
if __name__ == '__main__':
    unittest.main()

Review comment (@reyoung reyoung, Sep 7, 2017, Collaborator): Maybe __main__ should not be defined twice.