Commit 6f4bd0e

[Phi]Add graph_send_recv yaml file (#41206)

* add graph_send_recv yaml
* deal with conflict
* fix compile bugs

1 parent 0c968b9 commit 6f4bd0e

9 files changed (+93, -21 lines)

paddle/phi/kernels/cpu/graph_send_recv_grad_kernel.cc

Lines changed: 2 additions & 2 deletions

@@ -118,12 +118,12 @@ void GraphSendRecvGradOpKernelLaunchHelper(
 
 template <typename T, typename Context>
 void GraphSendRecvGradKernel(const Context& ctx,
-                             const DenseTensor& out_grad,
                              const DenseTensor& x,
-                             paddle::optional<const DenseTensor&> out,
                              const DenseTensor& src_index,
                              const DenseTensor& dst_index,
+                             paddle::optional<const DenseTensor&> out,
                              paddle::optional<const DenseTensor&> dst_count,
+                             const DenseTensor& out_grad,
                              const std::string& pool_type,
                              DenseTensor* x_grad) {
   auto index_type = src_index.dtype();

paddle/phi/kernels/gpu/graph_send_recv_grad_kernel.cu

Lines changed: 2 additions & 2 deletions

@@ -102,12 +102,12 @@ void GraphSendRecvGradOpCUDAKernelLaunchHelper(
 
 template <typename T, typename Context>
 void GraphSendRecvGradKernel(const Context& ctx,
-                             const DenseTensor& out_grad,
                              const DenseTensor& x,
-                             paddle::optional<const DenseTensor&> out,
                              const DenseTensor& src_index,
                              const DenseTensor& dst_index,
+                             paddle::optional<const DenseTensor&> out,
                              paddle::optional<const DenseTensor&> dst_count,
+                             const DenseTensor& out_grad,
                              const std::string& pool_type,
                              DenseTensor* x_grad) {
   auto index_type = src_index.dtype();

paddle/phi/kernels/graph_send_recv_grad_kernel.h

Lines changed: 2 additions & 2 deletions

@@ -22,12 +22,12 @@ namespace phi {
 
 template <typename T, typename Context>
 void GraphSendRecvGradKernel(const Context& ctx,
-                             const DenseTensor& out_grad,
                              const DenseTensor& x,
-                             paddle::optional<const DenseTensor&> out,
                              const DenseTensor& src_index,
                              const DenseTensor& dst_index,
+                             paddle::optional<const DenseTensor&> out,
                              paddle::optional<const DenseTensor&> dst_count,
+                             const DenseTensor& out_grad,
                              const std::string& pool_type,
                              DenseTensor* x_grad);
 }  // namespace phi

paddle/phi/ops/compat/graph_send_recv_sig.cc

Lines changed: 1 addition & 1 deletion

@@ -28,7 +28,7 @@ KernelSignature GraphSendRecvGradOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
   return KernelSignature(
       "graph_send_recv_grad",
-      {GradVarName("Out"), "X", "Out", "Src_index", "Dst_index", "Dst_count"},
+      {"X", "Src_index", "Dst_index", "Out", "Dst_count", GradVarName("Out")},
       {"pool_type"},
       {GradVarName("X")});
 }

python/paddle/fluid/dygraph/tracer.py

Lines changed: 8 additions & 0 deletions

@@ -22,6 +22,14 @@
 from paddle import _C_ops
 
 final_state_name_mapping = {
+    "graph_send_recv": {
+        "final_op_name": "final_state_graph_send_recv",
+        "x": "X",
+        "src_index": "Src_index",
+        "dst_index": "Dst_index",
+        "out": "Out",
+        "dst_count": "Dst_count"
+    },
     "matmul_v2": {
         "final_op_name": "final_state_matmul",
         "transpose_x": "trans_x",

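The entry above only declares a name mapping. As a purely illustrative sketch (not the tracer's actual code), the idea is that the legacy op's capitalized argument names on the right-hand side can be translated to the lower-case argument names of the generated final_state_graph_send_recv API; the helper below is hypothetical and exists only to show that translation.

# Hypothetical illustration only -- not code from tracer.py.
mapping = {
    "final_op_name": "final_state_graph_send_recv",
    "x": "X",
    "src_index": "Src_index",
    "dst_index": "Dst_index",
    "out": "Out",
    "dst_count": "Dst_count",
}


def to_final_state_names(legacy_kwargs, mapping):
    # Reverse the mapping (legacy name -> final-state name), skipping the
    # special "final_op_name" key, then rename the keyword arguments.
    reverse = {v: k for k, v in mapping.items() if k != "final_op_name"}
    return {reverse[name]: value for name, value in legacy_kwargs.items()}


print(to_final_state_names({"X": 1, "Src_index": 2, "Dst_index": 3}, mapping))
# {'x': 1, 'src_index': 2, 'dst_index': 3}
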
python/paddle/fluid/tests/unittests/test_graph_send_recv_op.py

Lines changed: 35 additions & 8 deletions

@@ -17,13 +17,26 @@
 import numpy as np
 import paddle
 import paddle.fluid as fluid
+from paddle.fluid.framework import _test_eager_guard
 
 from op_test import OpTest
 
 
+def graph_send_recv_wrapper(x,
+                            src_index,
+                            dst_index,
+                            pool_type="sum",
+                            out_size=None,
+                            name=None):
+    return paddle.incubate.graph_send_recv(x, src_index, dst_index,
+                                           pool_type.lower(), out_size, name)
+
+
 class TestGraphSendRecvMaxOp(OpTest):
     def setUp(self):
         paddle.enable_static()
+        self.python_api = graph_send_recv_wrapper
+        self.python_out_sig = ["Out"]
         self.op_type = "graph_send_recv"
         x = np.random.random((10, 20)).astype("float64")
         index = np.random.randint(0, 10, (15, 2)).astype(np.int64)

@@ -39,15 +52,18 @@ def setUp(self):
         self.outputs = {'Out': out}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)
 
     def test_check_grad(self):
-        self.check_grad(['X'], 'Out', user_defined_grads=[self.gradient])
+        self.check_grad(
+            ['X'], 'Out', user_defined_grads=[self.gradient], check_eager=True)
 
 
 class TestGraphSendRecvMinOp(OpTest):
     def setUp(self):
         paddle.enable_static()
+        self.python_api = graph_send_recv_wrapper
+        self.python_out_sig = ["Out"]
         self.op_type = "graph_send_recv"
         x = np.random.random((10, 20)).astype("float64")
         index = np.random.randint(0, 10, (15, 2)).astype(np.int64)

@@ -64,15 +80,18 @@ def setUp(self):
         self.outputs = {'Out': out}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)
 
     def test_check_grad(self):
-        self.check_grad(['X'], 'Out', user_defined_grads=[self.gradient])
+        self.check_grad(
+            ['X'], 'Out', user_defined_grads=[self.gradient], check_eager=True)
 
 
 class TestGraphSendRecvSumOp(OpTest):
     def setUp(self):
         paddle.enable_static()
+        self.python_api = graph_send_recv_wrapper
+        self.python_out_sig = ["Out"]
         self.op_type = "graph_send_recv"
         x = np.random.random((10, 20)).astype("float64")
         index = np.random.randint(0, 10, (15, 2)).astype(np.int64)

@@ -88,15 +107,17 @@ def setUp(self):
         self.outputs = {'Out': out}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)
 
     def test_check_grad(self):
-        self.check_grad(['X'], 'Out')
+        self.check_grad(['X'], 'Out', check_eager=True)
 
 
 class TestGraphSendRecvMeanOp(OpTest):
     def setUp(self):
         paddle.enable_static()
+        self.python_api = graph_send_recv_wrapper
+        self.python_out_sig = ["Out"]
         self.op_type = "graph_send_recv"
         x = np.random.random((10, 20)).astype("float64")
         index = np.random.randint(0, 10, (15, 2)).astype(np.int64)

@@ -113,10 +134,10 @@ def setUp(self):
         self.outputs = {'Out': out, 'Dst_count': dst_count}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)
 
     def test_check_grad(self):
-        self.check_grad(['X'], 'Out')
+        self.check_grad(['X'], 'Out', check_eager=True)
 
 
 def compute_graph_send_recv_for_sum_mean(inputs, attributes):

@@ -333,6 +354,12 @@ def test_set_outsize_gpu(self):
                          {}\n{}, check diff!"
                          .format(np_res_set_outsize, res_set_outsize))
 
+    def test_api_eager_dygraph(self):
+        with _test_eager_guard():
+            self.test_dygraph()
+            self.test_int32_input()
+            self.test_set_outsize_gpu()
+
 
 if __name__ == '__main__':
     unittest.main()

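For orientation, here is a minimal NumPy sketch of the "SUM" semantics these unit tests exercise. It is not the file's own compute_graph_send_recv_for_sum_mean helper (whose body lies outside this diff); the function name and the assumption that the default output keeps the same leading dimension as x are illustrative.

import numpy as np


def send_recv_sum_reference(x, src_index, dst_index, num_rows):
    # Gather rows of x at src_index and scatter-add them into the rows named
    # by dst_index; rows that receive nothing stay zero.
    out = np.zeros((num_rows,) + x.shape[1:], dtype=x.dtype)
    for s, d in zip(src_index, dst_index):
        out[d] += x[s]
    return out


# Same shapes as the tests above: 10 feature rows, 15 (src, dst) index pairs.
x = np.random.random((10, 20)).astype("float64")
index = np.random.randint(0, 10, (15, 2)).astype(np.int64)
ref = send_recv_sum_reference(x, index[:, 0], index[:, 1], x.shape[0])
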
python/paddle/incubate/operators/graph_send_recv.py

Lines changed: 20 additions & 5 deletions

@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from paddle.fluid.layer_helper import LayerHelper
-from paddle.fluid.framework import _non_static_mode
+from paddle.fluid.framework import _non_static_mode, _in_legacy_dygraph, in_dygraph_mode
 from paddle.fluid.data_feeder import check_variable_and_dtype
 from paddle.fluid import core
 from paddle import _C_ops

@@ -109,15 +109,30 @@ def graph_send_recv(x,
 
     # TODO(daisiming): Should we add judgement for out_size: max(dst_index) + 1.
 
-    if _non_static_mode():
-        if out_size is None or out_size <= 0:
+    if out_size is None or out_size <= 0:
+        if _in_legacy_dygraph():
             out, tmp = _C_ops.graph_send_recv(x, src_index, dst_index,
                                               'pool_type', pool_type.upper())
-        else:
+            return out
+        if in_dygraph_mode():
+            return _C_ops.final_state_graph_send_recv(x, src_index, dst_index,
+                                                      pool_type.upper(), 0)
+    else:
+        if _in_legacy_dygraph():
             out, tmp = _C_ops.graph_send_recv(
                 x, src_index, dst_index, 'pool_type',
                 pool_type.upper(), 'out_size', out_size)
-        return out
+            return out
+        if in_dygraph_mode():
+            if isinstance(out_size, core.eager.Tensor):
+                if (out_size.size < 1):
+                    raise ValueError(
+                        "out_size should be long type, but received Tensor type."
+                    )
+                out_size = out_size.numpy()[0]
+            return _C_ops.final_state_graph_send_recv(x, src_index, dst_index,
                                                      pool_type.upper(),
                                                      out_size)
 
     check_variable_and_dtype(x, "X", ("float32", "float64", "int32", "int64"),
                              "graph_send_recv")

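As a usage sketch (not part of this commit), the call below goes through the dispatch rewritten above: in the new eager mode it reaches _C_ops.final_state_graph_send_recv, while legacy dygraph still uses _C_ops.graph_send_recv. The tensor values are illustrative.

import paddle

# Three "node" feature rows and four edges src_index[i] -> dst_index[i].
x = paddle.to_tensor([[0., 2., 3.], [1., 4., 5.], [2., 6., 7.]])
src_index = paddle.to_tensor([0, 1, 2, 0], dtype="int32")
dst_index = paddle.to_tensor([1, 2, 1, 0], dtype="int32")

# Gather rows of x at src_index and sum-pool them at dst_index.
out = paddle.incubate.graph_send_recv(x, src_index, dst_index, pool_type="sum")
# out[0] == x[0], out[1] == x[0] + x[2], out[2] == x[1]
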
python/paddle/utils/code_gen/api.yaml

Lines changed: 12 additions & 1 deletion

@@ -756,6 +756,17 @@
     func : gelu
   backward : gelu_grad
 
+- api : graph_send_recv
+  args : (Tensor x, Tensor src_index, Tensor dst_index, str pool_type = "SUM", int64_t out_size = 0)
+  output : Tensor(out), Tensor(dst_count)
+  infer_meta :
+    func : GraphSendRecvInferMeta
+  kernel :
+    func : graph_send_recv
+    data_type : x
+  intermediate : dst_count
+  backward : graph_send_recv_grad
+
 - api : greater_equal
   args : (Tensor x, Tensor y, int axis = -1)
   output : Tensor

The second hunk is a whitespace-only change to the blank line before the meshgrid entry:

@@ -1162,7 +1173,7 @@
   kernel :
     func : mean_all
   backward : mean_all_grad
-
+
 - api : meshgrid
   args : (Tensor[] inputs)
   output : Tensor[]

python/paddle/utils/code_gen/backward.yaml

Lines changed: 11 additions & 0 deletions

@@ -537,6 +537,17 @@
   kernel :
     func : gelu_grad
 
+- backward_api : graph_send_recv_grad
+  forward : graph_send_recv (Tensor x, Tensor src_index, Tensor dst_index, str pool_type = "SUM", int64_t out_size = 0) -> Tensor(out), Tensor(dst_count)
+  args : (Tensor x, Tensor src_index, Tensor dst_index, Tensor out, Tensor dst_count, Tensor out_grad, str pool_type = "SUM")
+  output : Tensor(x_grad)
+  infer_meta :
+    func : GeneralUnaryGradInferMeta
+    param : [x]
+  kernel :
+    func : graph_send_recv_grad
+  optional: out, dst_count
+
 - backward_api : hard_shrink_grad
   forward : hard_shrink (Tensor x, float threshold) -> Tensor(out)
   args : (Tensor x, Tensor out_grad, float threshold)

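A small, hedged sketch (again not from the commit) of the gradient path this backward entry registers; for the "SUM" pool_type, each input row's gradient accumulates the upstream gradient once per outgoing edge.

import paddle

x = paddle.to_tensor(
    [[0., 2., 3.], [1., 4., 5.], [2., 6., 7.]], stop_gradient=False)
src_index = paddle.to_tensor([0, 1, 2, 0], dtype="int32")
dst_index = paddle.to_tensor([1, 2, 1, 0], dtype="int32")

out = paddle.incubate.graph_send_recv(x, src_index, dst_index, pool_type="sum")
out.sum().backward()

# Row 0 is sent along two edges, rows 1 and 2 along one edge each, so with an
# all-ones upstream gradient x.grad is [[2, 2, 2], [1, 1, 1], [1, 1, 1]].
print(x.grad)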