1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
@@ -115,6 +115,7 @@
'c_identity',
'c_reduce_sum',
'c_reducescatter',
'c_softmax_with_cross_entropy',
'decayed_adagrad',
'dpsgd',
'embedding_grad_sparse',
10 changes: 10 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops.yaml
@@ -1399,6 +1399,16 @@
output : Tensor(out)
invoke : full_like(x, 0, dtype, place)

- op: c_softmax_with_cross_entropy
args: (Tensor logits, Tensor label, int64_t ignore_index=-100, int ring_id=0, int rank=0, int nranks=0)
output: Tensor(softmax), Tensor(loss)
infer_meta:
func : CSoftmaxWithCrossEntropyInferMeta
kernel:
func: c_softmax_with_cross_entropy
data_type : logits
Comment on lines +1402 to +1409
Contributor: Doesn't this op need a backward config?

backward: c_softmax_with_cross_entropy_grad
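
The yaml entry only declares the op's interface; the numerics live in the kernel. As a reader aid, here is a minimal single-process NumPy sketch of what `c_softmax_with_cross_entropy` computes when the class dimension of `logits` is sharded across `nranks` ranks (the shard layout, the natural-log loss convention, and the `ignore_index` handling are assumptions, not taken from the kernel source):

```python
import numpy as np

def c_softmax_with_cross_entropy_ref(shards, label, ignore_index=-100):
    """Reference semantics. shards[r]: [N, C_local] slice of the class
    axis owned by rank r; label: [N, 1] global class indices."""
    n, c_local = shards[0].shape

    # allreduce-max: global max over the full (sharded) class axis.
    gmax = np.max([s.max(axis=1) for s in shards], axis=0)        # [N]

    # allreduce-sum: global normalizer of the shifted exponentials.
    exps = [np.exp(s - gmax[:, None]) for s in shards]
    z = np.sum([e.sum(axis=1) for e in exps], axis=0)             # [N]

    # Each rank keeps only its own softmax shard.
    softmax = [e / z[:, None] for e in exps]

    # The rank owning the label's class contributes that logit
    # (an allreduce-sum of a masked term on the real op).
    lbl = label.reshape(-1)
    picked = np.zeros(n)
    for r, s in enumerate(shards):
        local = lbl - r * c_local
        owned = (local >= 0) & (local < c_local)
        picked[owned] = s[owned, local[owned]]

    loss = np.log(z) - (picked - gmax)          # -log softmax[label]
    loss[lbl == ignore_index] = 0.0             # assumed ignore behavior
    return softmax, loss.reshape(n, 1)
```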

- op: dpsgd
args: (Tensor param, Tensor grad, Tensor learning_rate, float clip = 10.0f, float batch_size = 16.0f, float sigma = 1.0f, int seed = 0)
output: Tensor(param_out)
Expand Down
10 changes: 10 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml
@@ -119,6 +119,16 @@
func : c_embedding_grad
no_need_buffer : weight

- backward_op : c_softmax_with_cross_entropy_grad
forward: c_softmax_with_cross_entropy (Tensor logits, Tensor label, int64_t ignore_index=-100, int ring_id=0, int rank=0, int nranks=0) -> Tensor(softmax), Tensor(loss)
args: (Tensor softmax, Tensor label, Tensor loss_grad, int64_t ignore_index=-100, int ring_id=0, int rank=0, int nranks=0)
output: Tensor(logits_grad)
infer_meta :
func: CSoftmaxWithCrossEntropyGradInferMeta
kernel:
func: c_softmax_with_cross_entropy_grad
data_type: loss_grad
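
For orientation, the backward pass this entry wires up has a simple closed form: with softmax `p` and one-hot label `y`, `d loss / d logits = p - y`, scaled by the incoming `loss_grad`. A hedged per-rank NumPy sketch (names and the `ignore_index` behaviour are assumptions):

```python
import numpy as np

def c_softmax_with_cross_entropy_grad_ref(
    softmax_shard, label, loss_grad, rank, c_local, ignore_index=-100
):
    # softmax_shard: [N, C_local]; label: [N, 1]; loss_grad: [N, 1].
    lbl = label.reshape(-1)
    grad = softmax_shard.copy()
    local = lbl - rank * c_local                 # label index inside this shard
    owned = (local >= 0) & (local < c_local)
    grad[owned, local[owned]] -= 1.0             # subtract one-hot where owned
    grad *= loss_grad.reshape(-1, 1)             # chain rule through the loss
    grad[lbl == ignore_index] = 0.0              # assumed: no grad for ignored rows
    return grad
```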

- backward_op : cast_grad
forward : cast (Tensor x, DataType dtype) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
2 changes: 2 additions & 0 deletions paddle/fluid/pir/dialect/operator/utils/utils.cc
@@ -49,6 +49,8 @@ const std::unordered_set<std::string> LegacyOpList = {
CReduceSum_Op::name(),
CAllreduceMax_Op::name(),
CAllgatherOp::name(),
CSoftmaxWithCrossEntropyOp::name(),
CSoftmaxWithCrossEntropyGradOp::name(),
SeedOp::name(),
ShareDataOp::name(),
SparseMomentumOp::name(),
7 changes: 7 additions & 0 deletions paddle/phi/api/yaml/op_compat.yaml
@@ -452,6 +452,13 @@
outputs :
out : Out

- op : c_softmax_with_cross_entropy
backward : c_softmax_with_cross_entropy_grad
inputs :
{logits : Logits, label : Label}
outputs :
{softmax : Softmax, loss : Loss}

- op : cast
inputs :
x : X
12 changes: 11 additions & 1 deletion paddle/phi/infermeta/backward.cc
@@ -201,7 +201,17 @@ void CropGradInferMeta(const MetaTensor& out_grad,
x_grad->set_dtype(x.dtype());
}
}

void CSoftmaxWithCrossEntropyGradInferMeta(const MetaTensor& softmax,
const MetaTensor& label,
const MetaTensor& loss_grad,
int64_t ignore_index,
int ring_id,
int rank,
int nranks,
MetaTensor* logits_grad,
MetaConfig config) {
logits_grad->set_dims(softmax.dims());
}
void FlashAttnGradInferMeta(const MetaTensor& q,
const MetaTensor& k,
const MetaTensor& v,
10 changes: 10 additions & 0 deletions paddle/phi/infermeta/backward.h
@@ -123,6 +123,16 @@ void CrossEntropyWithSoftmaxGradInferMeta(const MetaTensor& label,
MetaTensor* logits_grad,
MetaConfig config = MetaConfig());

void CSoftmaxWithCrossEntropyGradInferMeta(const MetaTensor& softmax,
const MetaTensor& label,
const MetaTensor& loss_grad,
int64_t ignore_index,
int ring_id,
int rank,
int nranks,
MetaTensor* logits_grad,
MetaConfig config = MetaConfig());

void DeformableConvGradInferMeta(const MetaTensor& x,
const MetaTensor& offset,
const MetaTensor& filter,
43 changes: 43 additions & 0 deletions paddle/phi/infermeta/binary.cc
@@ -1009,6 +1009,49 @@ void CrossEntropyWithSoftmaxInferMeta(const MetaTensor& logits,
loss->share_lod(logits);
}

void CSoftmaxWithCrossEntropyInferMeta(const MetaTensor& logits,
const MetaTensor& label,
int64_t ignore_index,
int ring_id,
int rank,
int nranks,
MetaTensor* softmax,
MetaTensor* loss,
MetaConfig config) {
auto logits_dims = logits.dims();
auto labels_dims = label.dims();

auto logits_rank = logits_dims.size();
auto axis = logits_rank - 1;
for (int i = 0; i < logits_rank; i++) {
if (i != axis) {
if (config.is_runtime || (logits_dims[i] > 0 && labels_dims[i] > 0)) {
PADDLE_ENFORCE_EQ(logits_dims[i],
labels_dims[i],
phi::errors::InvalidArgument(
"Input(Logits) and Input(Label) should in "
"same shape in dimensions except axis."));
}
}
}

PADDLE_ENFORCE_EQ(
labels_dims[logits_rank - 1],
1UL,
phi::errors::InvalidArgument(
"the last dimension of Input(Label) should be 1."
"But received: the last dimension of Input(Label) is [%d],"
"the last dimension is [%d]",
labels_dims[logits_rank - 1],
logits_rank - 1));

softmax->set_dims(logits_dims);
logits_dims[axis] = 1;
loss->set_dims(logits_dims);
softmax->share_lod(logits);
loss->share_lod(logits);
}
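
To make the shape rule concrete: every dimension of `logits` and `label` must match except the last, the last dimension of `label` must be 1, `softmax` inherits the full `logits` shape, and `loss` collapses the class axis to 1. With illustrative sizes:

```python
# Illustrative shapes only; the sizes are made up.
logits_shape = (8, 4096)              # [batch, local classes on this rank]
label_shape = (8, 1)                  # last dim must be 1
softmax_shape = logits_shape          # (8, 4096)
loss_shape = (*logits_shape[:-1], 1)  # (8, 1)
```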

void DepthwiseConvInferMeta(const MetaTensor& input,
const MetaTensor& filter,
const std::vector<int>& strides,
10 changes: 10 additions & 0 deletions paddle/phi/infermeta/binary.h
@@ -159,6 +159,16 @@ void CrossEntropyWithSoftmaxInferMeta(const MetaTensor& logits,
MetaTensor* loss,
MetaConfig config = MetaConfig());

void CSoftmaxWithCrossEntropyInferMeta(const MetaTensor& logits,
const MetaTensor& label,
int64_t ignore_index,
int ring_id,
int rank,
int nranks,
MetaTensor* softmax,
MetaTensor* loss,
MetaConfig config = MetaConfig());

void DepthwiseConvInferMeta(const MetaTensor& input,
const MetaTensor& filter,
const std::vector<int>& strides,
@@ -35,6 +35,11 @@ def test_mp(self):
"semi_auto_parallel_c_cross_entropy_mp.py",
)

def test_mp_pir(self):
# Re-run the mp test with the PIR executor enabled, then restore the flag.
os.environ["FLAGS_enable_pir_in_executor"] = "True"
self.test_mp()
os.environ["FLAGS_enable_pir_in_executor"] = "False"


class TestParallelCrossEntropyHybrid(test_base.CommunicationTestDistBase):
def setUp(self):