Skip to content

Commit de7b288

Browse files
authored
【PIR OpTest Fix No.38】 fix test_semi_auto_parallel_c_cross_entropy (#59893)
* register c_softmax * register c_softmax * Update ops_backward.yaml * Update utils.cc * add test_semi_auto_parallel_c_cross_entropy to whitelist * Revert "add test_semi_auto_parallel_c_cross_entropy to whitelist" This reverts commit 75b3605. * add pir test * Update ops.yaml
1 parent 269dedd commit de7b288

File tree

10 files changed

+109
-1
lines changed

10 files changed

+109
-1
lines changed

paddle/fluid/pir/dialect/op_generator/ops_api_gen.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@
115115
'c_identity',
116116
'c_reduce_sum',
117117
'c_reducescatter',
118+
'c_softmax_with_cross_entropy',
118119
'decayed_adagrad',
119120
'dpsgd',
120121
'embedding_grad_sparse',

paddle/fluid/pir/dialect/operator/ir/ops.yaml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1391,6 +1391,16 @@
13911391
output : Tensor(out)
13921392
invoke : full_like(x, 0, dtype, place)
13931393

1394+
- op: c_softmax_with_cross_entropy
1395+
args: (Tensor logits, Tensor label, int64_t ignore_index=-100, int ring_id=0, int rank=0, int nranks=0)
1396+
output: Tensor(softmax), Tensor(loss)
1397+
infer_meta:
1398+
func : CSoftmaxWithCrossEntropyInferMeta
1399+
kernel:
1400+
func: c_softmax_with_cross_entropy
1401+
data_type : logits
1402+
backward: c_softmax_with_cross_entropy_grad
1403+
13941404
- op: dpsgd
13951405
args: (Tensor param, Tensor grad, Tensor learning_rate, float clip = 10.0f, float batch_size = 16.0f, float sigma = 1.0f, int seed = 0)
13961406
output: Tensor(param_out)

paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,16 @@
119119
func : c_embedding_grad
120120
no_need_buffer : weight
121121

122+
- backward_op : c_softmax_with_cross_entropy_grad
123+
forward: c_softmax_with_cross_entropy (Tensor logits, Tensor label, int64_t ignore_index=-100, int ring_id=0, int rank=0, int nranks=0) -> Tensor(softmax), Tensor(loss)
124+
args: (Tensor softmax, Tensor label, Tensor loss_grad, int64_t ignore_index=-100, int ring_id=0, int rank=0, int nranks=0)
125+
output: Tensor(logits_grad)
126+
infer_meta :
127+
func: CSoftmaxWithCrossEntropyGradInferMeta
128+
kernel:
129+
func: c_softmax_with_cross_entropy_grad
130+
data_type: loss_grad
131+
122132
- backward_op : cast_grad
123133
forward : cast (Tensor x, DataType dtype) -> Tensor(out)
124134
args : (Tensor x, Tensor out_grad)

paddle/fluid/pir/dialect/operator/utils/utils.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ const std::unordered_set<std::string> LegacyOpList = {
4949
CReduceSum_Op::name(),
5050
CAllreduceMax_Op::name(),
5151
CAllgatherOp::name(),
52+
CSoftmaxWithCrossEntropyOp::name(),
53+
CSoftmaxWithCrossEntropyGradOp::name(),
5254
SeedOp::name(),
5355
ShareDataOp::name(),
5456
SparseMomentumOp::name(),

paddle/phi/api/yaml/op_compat.yaml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,13 @@
452452
outputs :
453453
out : Out
454454

455+
- op : c_softmax_with_cross_entropy
456+
backward : c_softmax_with_cross_entropy_grad
457+
inputs :
458+
{logits : Logits, label : Label}
459+
outputs :
460+
{softmax : Softmax, loss : Loss}
461+
455462
- op : cast
456463
inputs :
457464
x : X

paddle/phi/infermeta/backward.cc

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,17 @@ void CropGradInferMeta(const MetaTensor& out_grad,
201201
x_grad->set_dtype(x.dtype());
202202
}
203203
}
204-
204+
void CSoftmaxWithCrossEntropyGradInferMeta(const MetaTensor& softmax,
205+
const MetaTensor& label,
206+
const MetaTensor& loss_grad,
207+
int64_t ignore_index,
208+
int ring_id,
209+
int rank,
210+
int nranks,
211+
MetaTensor* logits_grad,
212+
MetaConfig config) {
213+
logits_grad->set_dims(softmax.dims());
214+
}
205215
void FlashAttnGradInferMeta(const MetaTensor& q,
206216
const MetaTensor& k,
207217
const MetaTensor& v,

paddle/phi/infermeta/backward.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,16 @@ void CrossEntropyWithSoftmaxGradInferMeta(const MetaTensor& label,
123123
MetaTensor* logits_grad,
124124
MetaConfig config = MetaConfig());
125125

126+
void CSoftmaxWithCrossEntropyGradInferMeta(const MetaTensor& softmax,
127+
const MetaTensor& label,
128+
const MetaTensor& loss_grad,
129+
int64_t ignore_index,
130+
int ring_id,
131+
int rank,
132+
int nranks,
133+
MetaTensor* logits_grad,
134+
MetaConfig config = MetaConfig());
135+
126136
void DeformableConvGradInferMeta(const MetaTensor& x,
127137
const MetaTensor& offset,
128138
const MetaTensor& filter,

paddle/phi/infermeta/binary.cc

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1009,6 +1009,49 @@ void CrossEntropyWithSoftmaxInferMeta(const MetaTensor& logits,
10091009
loss->share_lod(logits);
10101010
}
10111011

1012+
void CSoftmaxWithCrossEntropyInferMeta(const MetaTensor& logits,
1013+
const MetaTensor& label,
1014+
int64_t ignore_index,
1015+
int ring_id,
1016+
int rank,
1017+
int nranks,
1018+
MetaTensor* softmax,
1019+
MetaTensor* loss,
1020+
MetaConfig config) {
1021+
auto logits_dims = logits.dims();
1022+
auto labels_dims = label.dims();
1023+
1024+
auto logits_rank = logits_dims.size();
1025+
auto axis = logits_rank - 1;
1026+
for (int i = 0; i < logits_rank; i++) {
1027+
if (i != axis) {
1028+
if (config.is_runtime || (logits_dims[i] > 0 && labels_dims[i] > 0)) {
1029+
PADDLE_ENFORCE_EQ(logits_dims[i],
1030+
labels_dims[i],
1031+
phi::errors::InvalidArgument(
1032+
"Input(Logits) and Input(Label) should in "
1033+
"same shape in dimensions except axis."));
1034+
}
1035+
}
1036+
}
1037+
1038+
PADDLE_ENFORCE_EQ(
1039+
labels_dims[logits_rank - 1],
1040+
1UL,
1041+
phi::errors::InvalidArgument(
1042+
"the last dimension of Input(Label) should be 1."
1043+
"But received: the last dimension of Input(Label) is [%d],"
1044+
"the last dimension is [%d]",
1045+
labels_dims[logits_rank - 1],
1046+
logits_rank - 1));
1047+
1048+
softmax->set_dims(logits_dims);
1049+
logits_dims[axis] = 1;
1050+
loss->set_dims(logits_dims);
1051+
softmax->share_lod(logits);
1052+
loss->share_lod(logits);
1053+
}
1054+
10121055
void DepthwiseConvInferMeta(const MetaTensor& input,
10131056
const MetaTensor& filter,
10141057
const std::vector<int>& strides,

paddle/phi/infermeta/binary.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,16 @@ void CrossEntropyWithSoftmaxInferMeta(const MetaTensor& logits,
159159
MetaTensor* loss,
160160
MetaConfig config = MetaConfig());
161161

162+
void CSoftmaxWithCrossEntropyInferMeta(const MetaTensor& logits,
163+
const MetaTensor& label,
164+
int64_t ignore_index,
165+
int ring_id,
166+
int rank,
167+
int nranks,
168+
MetaTensor* softmax,
169+
MetaTensor* loss,
170+
MetaConfig config = MetaConfig());
171+
162172
void DepthwiseConvInferMeta(const MetaTensor& input,
163173
const MetaTensor& filter,
164174
const std::vector<int>& strides,

test/auto_parallel/hybrid_strategy/test_semi_auto_parallel_c_cross_entropy.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,11 @@ def test_mp(self):
3535
"semi_auto_parallel_c_cross_entropy_mp.py",
3636
)
3737

38+
def test_mp_pir(self):
    # Re-run the MP cross-entropy test with the PIR executor enabled.
    # Use try/finally so the flag is restored even if the inner test
    # fails, preventing the env var from leaking into later tests.
    os.environ["FLAGS_enable_pir_in_executor"] = "True"
    try:
        self.test_mp()
    finally:
        os.environ["FLAGS_enable_pir_in_executor"] = "False"
42+
3843

3944
class TestParallelCrossEntropyHybrid(test_base.CommunicationTestDistBase):
4045
def setUp(self):

0 commit comments

Comments
 (0)