2 files changed: +33 -1

python/paddle/nn/functional

@@ -3071,6 +3071,13 @@ def cross_entropy(
             # so, reduce_sum all directly is ok
             return _C_ops.sum(out, [], None, False)
         elif reduction == "mean":
+            # when reduction is mean and the input is zero-size, the mean loss is NaN
+            def _replace_nan(out):
+                return out + paddle.nan
+
+            if 0 in input.shape:
+                out = _replace_nan(out)
+                return _C_ops.mean_all(out)
             # 1. if weight==none,
             #    numerator: reduce_sum all loss directly is ok because of base_softmax_with_cross_entropy's inner logic
             #    denominator: count sample num with class_index != ignore_index
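
For context, a minimal dygraph sketch of the behavior this new branch targets, assuming the patch above is applied; the (16, 0) shape, soft_label=True, and the NaN expectation mirror the test added further down, and the F alias is just local shorthand:

import numpy as np
import paddle
import paddle.nn.functional as F

# 16 rows but 0 classes: with reduction="mean" there is nothing to average over,
# so the patched branch is expected to return NaN for the reduced loss.
x = paddle.to_tensor(np.random.random((16, 0)).astype(np.float64))
label = paddle.to_tensor(np.random.random((16, 0)).astype(np.float64))

loss = F.cross_entropy(x, label, soft_label=True, reduction="mean")
print(loss)  # expected: a 0-D float64 tensor holding nan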
Second changed file:

@@ -15,7 +15,12 @@
 import unittest

 import numpy as np
-from op_test import OpTest, paddle_static_guard, randomize_probability
+from op_test import (
+    OpTest,
+    get_places,
+    paddle_static_guard,
+    randomize_probability,
+)

 import paddle
 from paddle import base
@@ -497,5 +502,25 @@ def get_cross_entropy(self):
         self.cross_entropy = np.random.random([0, 1]).astype(np.float64)


+class TestCrossEntropyOp_ZeroSize2(unittest.TestCase):
+    def test_dygraph_api(self):
+        for place in get_places():
+            paddle.disable_static(place)
+            x_np = np.random.random((16, 0)).astype(np.float64)
+            label_np = np.random.random((16, 0)).astype(np.float64)
+            x = paddle.to_tensor(x_np)
+            x.stop_gradient = False
+            label = paddle.to_tensor(label_np)
+            label.stop_gradient = False
+            out1 = paddle.nn.functional.cross_entropy(
+                x, label, soft_label=True, reduction="mean"
+            )
+            out2 = np.array(np.nan).astype(np.float64)
+            np.testing.assert_allclose(out1.numpy(), out2)
+            paddle.sum(out1).backward()
+            np.testing.assert_allclose(x.grad.shape, x.shape)
+            paddle.enable_static()
+
+
 if __name__ == "__main__":
     unittest.main()
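
As a quick sanity check on the expected value (out2) in the test above: NumPy likewise defines the mean over a zero-size array as NaN, which is the convention the new reduction branch follows. A minimal sketch:

import numpy as np

empty = np.random.random((16, 0)).astype(np.float64)
print(empty.size)      # 0, no elements to average
print(np.mean(empty))  # nan (NumPy also emits a RuntimeWarning about the empty mean)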