@@ -1009,6 +1009,49 @@ void CrossEntropyWithSoftmaxInferMeta(const MetaTensor& logits,
10091009 loss->share_lod (logits);
10101010}
10111011
1012+ void CSoftmaxWithCrossEntropyInferMeta (const MetaTensor& logits,
1013+ const MetaTensor& label,
1014+ int64_t ignore_index,
1015+ int ring_id,
1016+ int rank,
1017+ int nranks,
1018+ MetaTensor* softmax,
1019+ MetaTensor* loss,
1020+ MetaConfig config) {
1021+ auto logits_dims = logits.dims ();
1022+ auto labels_dims = label.dims ();
1023+
1024+ auto logits_rank = logits_dims.size ();
1025+ auto axis = logits_rank - 1 ;
1026+ for (int i = 0 ; i < logits_rank; i++) {
1027+ if (i != axis) {
1028+ if (config.is_runtime || (logits_dims[i] > 0 && labels_dims[i] > 0 )) {
1029+ PADDLE_ENFORCE_EQ (logits_dims[i],
1030+ labels_dims[i],
1031+ phi::errors::InvalidArgument (
1032+ " Input(Logits) and Input(Label) should in "
1033+ " same shape in dimensions except axis." ));
1034+ }
1035+ }
1036+ }
1037+
1038+ PADDLE_ENFORCE_EQ (
1039+ labels_dims[logits_rank - 1 ],
1040+ 1UL ,
1041+ phi::errors::InvalidArgument (
1042+ " the last dimension of Input(Label) should be 1."
1043+ " But received: the last dimension of Input(Label) is [%d],"
1044+ " the last dimension is [%d]" ,
1045+ labels_dims[logits_rank - 1 ],
1046+ logits_rank - 1 ));
1047+
1048+ softmax->set_dims (logits_dims);
1049+ logits_dims[axis] = 1 ;
1050+ loss->set_dims (logits_dims);
1051+ softmax->share_lod (logits);
1052+ loss->share_lod (logits);
1053+ }
1054+
10121055void DepthwiseConvInferMeta (const MetaTensor& input,
10131056 const MetaTensor& filter,
10141057 const std::vector<int >& strides,
0 commit comments