Skip to content

Commit c5285cc

Browse files
From00Shixiaowei02
andauthored
Add yaml for flatten_contiguous_range OP (#41345)
* Add yaml for flatten_contiguous_range OP * update * Fix typos Co-authored-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
1 parent 3152f3f commit c5285cc

File tree

12 files changed

+33
-86
lines changed

12 files changed

+33
-86
lines changed

paddle/phi/kernels/flatten_grad_kernel.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ namespace phi {
2121

2222
template <typename T, typename Context>
2323
void FlattenGradKernel(const Context& dev_ctx,
24-
const DenseTensor& out_grad,
2524
const DenseTensor& xshape,
25+
const DenseTensor& out_grad,
2626
DenseTensor* x_grad) {
2727
auto xshape_dims = xshape.dims();
2828
dev_ctx.Alloc(x_grad, out_grad.dtype());

paddle/phi/kernels/flatten_grad_kernel.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ namespace phi {
2020

2121
template <typename T, typename Context>
2222
void FlattenGradKernel(const Context& dev_ctx,
23-
const DenseTensor& out_grad,
2423
const DenseTensor& xshape,
24+
const DenseTensor& out_grad,
2525
DenseTensor* x_grad);
2626

2727
} // namespace phi

paddle/phi/ops/compat/flatten_sig.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ KernelSignature FlattenOpArgumentMapping(const ArgumentMappingContext& ctx) {
3131
KernelSignature FlattenGradOpArgumentMapping(
3232
const ArgumentMappingContext& ctx) {
3333
return KernelSignature(
34-
"flatten_grad", {GradVarName("Out"), "XShape"}, {}, {GradVarName("X")});
34+
"flatten_grad", {"XShape", GradVarName("Out")}, {}, {GradVarName("X")});
3535
}
3636

3737
} // namespace phi

paddle/phi/tests/api/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ cc_test(test_dot_api SRCS test_dot_api.cc DEPS ${COMMON_API_TEST_DEPS})
1212
cc_test(test_matmul_api SRCS test_matmul_api.cc DEPS ${COMMON_API_TEST_DEPS})
1313
cc_test(test_empty_api SRCS test_empty_api.cc DEPS ${COMMON_API_TEST_DEPS})
1414
cc_test(test_fill_api SRCS test_fill_api.cc DEPS ${COMMON_API_TEST_DEPS})
15-
cc_test(test_flatten_api SRCS test_flatten_api.cc DEPS ${COMMON_API_TEST_DEPS})
1615
cc_test(test_elementwise_api SRCS test_elementwise_api.cc DEPS ${COMMON_API_TEST_DEPS})
1716
cc_test(test_cast_api SRCS test_cast_api.cc DEPS ${COMMON_API_TEST_DEPS})
1817
cc_test(test_reshape_api SRCS test_reshape_api.cc DEPS ${COMMON_API_TEST_DEPS})

paddle/phi/tests/api/test_flatten_api.cc

Lines changed: 0 additions & 75 deletions
This file was deleted.

python/paddle/fluid/tests/unittests/test_flatten_contiguous_range_op.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323

2424
class TestFlattenOp(OpTest):
2525
def setUp(self):
26+
self.python_api = paddle.flatten
27+
self.python_out_sig = ["Out"]
2628
self.op_type = "flatten_contiguous_range"
2729
self.start_axis = 0
2830
self.stop_axis = -1
@@ -35,10 +37,10 @@ def setUp(self):
3537
}
3638

3739
def test_check_output(self):
38-
self.check_output(no_check_set=["XShape"])
40+
self.check_output(no_check_set=["XShape"], check_eager=True)
3941

4042
def test_check_grad(self):
41-
self.check_grad(["X"], "Out")
43+
self.check_grad(["X"], "Out", check_eager=True)
4244

4345
def init_test_case(self):
4446
self.in_shape = (3, 2, 5, 4)

python/paddle/tensor/manipulation.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -676,7 +676,11 @@ def flatten(x, start_axis=0, stop_axis=-1, name=None):
676676
if start_axis > stop_axis:
677677
raise ValueError("The stop_axis should be larger than stat_axis")
678678

679-
if paddle.in_dynamic_mode():
679+
if in_dygraph_mode():
680+
dy_out, _ = _C_ops.final_state_flatten(x, start_axis, stop_axis)
681+
return dy_out
682+
683+
if _in_legacy_dygraph():
680684
dy_out, _ = _C_ops.flatten_contiguous_range(x, 'start_axis', start_axis,
681685
'stop_axis', stop_axis)
682686
return dy_out

python/paddle/utils/code_gen/api.yaml

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -547,11 +547,15 @@
547547

548548
- api : flatten
549549
args : (Tensor x, int start_axis, int stop_axis)
550-
output : Tensor
550+
output : Tensor(out), Tensor(xshape)
551551
infer_meta :
552-
func : FlattenInferMeta
552+
func : FlattenWithXShapeInferMeta
553553
kernel :
554-
func : flatten
554+
func : flatten_with_xshape
555+
backend : x
556+
inplace : (x -> out)
557+
view : (x -> out)
558+
backward : flatten_grad
555559

556560
# flip
557561
- api : flip

0 commit comments

Comments
 (0)