Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions paddle/fluid/operators/roll_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class RollOp : public framework::OperatorWithKernel {
platform::errors::InvalidArgument(
"Output(Out) of RollOp should not be null."));

auto dims = ctx->Attrs().Get<std::vector<int64_t>>("dims");
auto dims = ctx->Attrs().Get<std::vector<int64_t>>("axis");
auto shifts = ctx->Attrs().Get<std::vector<int64_t>>("shifts");

PADDLE_ENFORCE_EQ(dims.size(), shifts.size(),
Expand Down Expand Up @@ -92,7 +92,7 @@ class RollOpMaker : public framework::OpProtoAndCheckerMaker {
"of the tensor are shifted.")
.SetDefault({});
AddAttr<std::vector<int64_t>>(
"dims",
"axis",
"Axis along which to roll. It must have the same size "
"with shifts.")
.SetDefault({});
Expand Down
8 changes: 4 additions & 4 deletions paddle/fluid/operators/roll_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ class RollKernel : public framework::OpKernel<T> {
auto& input = input_var->Get<LoDTensor>();
auto* output = output_var->GetMutable<LoDTensor>();
std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts");
std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("dims");
std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("axis");

std::vector<T> out_vec;
TensorToVector(input, context.device_context(), &out_vec);
Expand All @@ -94,8 +94,8 @@ class RollKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(
dims[i] < input_dim.size() && dims[i] >= (0 - input_dim.size()), true,
platform::errors::OutOfRange(
"Attr(dims[%d]) is out of range, It's expected "
"to be in range of [-%d, %d]. But received Attr(dims[%d]) = %d.",
"Attr(axis[%d]) is out of range, It's expected "
"to be in range of [-%d, %d]. But received Attr(axis[%d]) = %d.",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Python端也增加对axis范围的检查,防止打印出来call stack。

i, input_dim.size(), input_dim.size() - 1, i, dims[i]));
shift_along_dim(out_vec.data(), input_dim, dims[i], shifts[i]);
}
Expand All @@ -114,7 +114,7 @@ class RollGradKernel : public framework::OpKernel<T> {
auto& input = input_var->Get<LoDTensor>();
auto* output = output_var->GetMutable<LoDTensor>();
std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts");
std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("dims");
std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("axis");

std::vector<T> out_vec;
TensorToVector(input, context.device_context(), &out_vec);
Expand Down
12 changes: 6 additions & 6 deletions python/paddle/fluid/tests/unittests/test_roll_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,17 @@ def setUp(self):
self.op_type = "roll"
self.init_dtype_type()
self.inputs = {'X': np.random.random(self.x_shape).astype(self.dtype)}
self.attrs = {'shifts': self.shifts, 'dims': self.dims}
self.attrs = {'shifts': self.shifts, 'axis': self.axis}
self.outputs = {
'Out': np.roll(self.inputs['X'], self.attrs['shifts'],
self.attrs['dims'])
self.attrs['axis'])
}

def init_dtype_type(self):
self.dtype = np.float64
self.x_shape = (100, 4, 5)
self.shifts = [101, -1]
self.dims = [0, -2]
self.axis = [0, -2]

def test_check_output(self):
self.check_output()
Expand All @@ -52,7 +52,7 @@ def init_dtype_type(self):
self.dtype = np.float32
self.x_shape = (100, 10, 5)
self.shifts = [8, -1]
self.dims = [-1, -2]
self.axis = [-1, -2]


class TestRollAPI(unittest.TestCase):
Expand All @@ -78,7 +78,7 @@ def test_roll_op_api(self):
# case 2:
with program_guard(Program(), Program()):
x = fluid.layers.data(name='x', shape=[-1, 3])
z = paddle.roll(x, shifts=1, dims=0)
z = paddle.roll(x, shifts=1, axis=0)
exe = fluid.Executor(fluid.CPUPlace())
res, = exe.run(feed={'x': self.data_x},
fetch_list=[z.name],
Expand All @@ -101,7 +101,7 @@ def test_dygraph_api(self):
# case 2:
with fluid.dygraph.guard():
x = fluid.dygraph.to_variable(self.data_x)
z = paddle.roll(x, shifts=1, dims=0)
z = paddle.roll(x, shifts=1, axis=0)
np_z = z.numpy()
expect_out = np.array([[7.0, 8.0, 9.0], [1.0, 2.0, 3.0],
[4.0, 5.0, 6.0]])
Expand Down
26 changes: 13 additions & 13 deletions python/paddle/tensor/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def flip(input, dims, name=None):
return out


def roll(input, shifts, dims=None):
def roll(input, shifts, axis=None, name=None):
"""
Copy link
Contributor

@jzhang533 jzhang533 Jul 8, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it should be : paddle.tensor.roll(x, shifts, axis=None, name=None)
according to the latest argument convention.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok

:alias_main: paddle.roll
:alias: paddle.roll,paddle.tensor.roll,paddle.tensor.manipulation.roll
Expand All @@ -117,7 +117,7 @@ def roll(input, shifts, dims=None):
input (Variable): The input tensor variable.
shifts (int|list|tuple): The number of places by which the elements
of the `input` tensor are shifted.
dims (int|list|tuple|None): Dimentions along which to roll.
axis (int|list|tuple|None): Dimentions along which to roll.
Copy link
Contributor

@jzhang533 jzhang533 Jul 8, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dimensions -> axis
上面的简介部分:``Roll the input tensor along ...'' 需要换一种说法, character level copy是不可以的。
示例代码中用paddle.enable_imperative()开启动态图模式,方便未来默认动态图的时候统一调整示例代码。


Returns:
Variable: A Tensor with same data type as `input`.
Expand All @@ -138,7 +138,7 @@ def roll(input, shifts, dims=None):
#[[9. 1. 2.]
# [3. 4. 5.]
# [6. 7. 8.]]
out_z2 = paddle.roll(x, shifts=1, dims=0)
out_z2 = paddle.roll(x, shifts=1, axis=0)
print(out_z2.numpy())
#[[7. 8. 9.]
# [1. 2. 3.]
Expand All @@ -148,31 +148,31 @@ def roll(input, shifts, dims=None):
origin_shape = input.shape
if type(shifts) == int:
shifts = [shifts]
if type(dims) == int:
dims = [dims]
if type(axis) == int:
axis = [axis]

if dims:
check_type(dims, 'dims', (list, tuple), 'roll')
if axis:
check_type(axis, 'axis', (list, tuple), 'roll')
check_type(shifts, 'shifts', (list, tuple), 'roll')

if in_dygraph_mode():
if dims is None:
if axis is None:
input = core.ops.reshape(input, 'shape', [-1, 1])
dims = [0]
out = core.ops.roll(input, 'dims', dims, 'shifts', shifts)
axis = [0]
out = core.ops.roll(input, 'axis', axis, 'shifts', shifts)
return core.ops.reshape(out, 'shape', origin_shape)

out = helper.create_variable_for_type_inference(input.dtype)

if dims is None:
if axis is None:
input = reshape(input, shape=[-1, 1])
dims = [0]
axis = [0]

helper.append_op(
type='roll',
inputs={'X': input},
outputs={'Out': out},
attrs={'dims': dims,
attrs={'axis': axis,
'shifts': shifts})
out = reshape(out, shape=origin_shape, inplace=True)
return out
Expand Down