@@ -40,12 +40,12 @@ def backward(ctx, dy):
         input_, weight = ctx.saved_tensor()
         out_grad = dy
         sub_grad = out_grad.reshape([-1, out_grad.shape[-1]])
-        input_grad = paddle.matmul(sub_grad, weight.t())
+        input_grad = paddle.matmul(sub_grad, weight, transpose_y=True)
         if weight.stop_gradient:
             return input_grad.reshape(input_.shape)
         else:
             input_reshape = input_.reshape([-1, input_.shape[-1]])
-            weight_grad = input_reshape.t().matmul(sub_grad)
+            weight_grad = paddle.matmul(input_reshape, sub_grad, transpose_x=True)
             return input_grad.reshape(input_.shape), weight_grad


@@ -65,7 +65,11 @@ def backward(ctx, dy):
         rank = paddle.distributed.get_rank()
         hcom_name = ctx.group.process_group.get_comm_name(rank)

-        d_weight = input_.reshape([-1, input_.shape[-1]]).t().matmul(sub_grad) if not weight.stop_gradient else None
+        d_weight = (
+            paddle.matmul(input_.reshape([-1, input_.shape[-1]]), sub_grad, transpose_x=True)
+            if not weight.stop_gradient
+            else None
+        )
         d_input = paddle_custom_device.npu.fused_mm_allreduce(
             sub_grad, weight.t(), bias=None, hcom=hcom_name, reduce_op="sum", comm_turn=0
         )
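
Note: the change swaps explicit `.t()` transposes for the `transpose_x` / `transpose_y` flags of `paddle.matmul`, which compute the same product while letting the matmul kernel handle the transpose instead of materializing a transposed tensor first. A minimal sketch of the equivalence (shapes here are made up for illustration, not taken from the patch):

```python
import paddle

x = paddle.randn([4, 8])    # e.g. flattened activations
w = paddle.randn([16, 8])   # weight stored as [out_features, in_features]

a = paddle.matmul(x, w.t())                 # explicit transpose of the weight
b = paddle.matmul(x, w, transpose_y=True)   # same product via the matmul flag

print(paddle.allclose(a, b))  # expected: True
```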