Commit 50b4491

update for npu. (#8210)
Parent: 57b22e7

File tree: 1 file changed, +7 -3 lines

paddlenlp/peft/lora/mc2_lora_npu.py

Lines changed: 7 additions & 3 deletions
@@ -40,12 +40,12 @@ def backward(ctx, dy):
         input_, weight = ctx.saved_tensor()
         out_grad = dy
         sub_grad = out_grad.reshape([-1, out_grad.shape[-1]])
-        input_grad = paddle.matmul(sub_grad, weight.t())
+        input_grad = paddle.matmul(sub_grad, weight, transpose_y=True)
         if weight.stop_gradient:
             return input_grad.reshape(input_.shape)
         else:
             input_reshape = input_.reshape([-1, input_.shape[-1]])
-            weight_grad = input_reshape.t().matmul(sub_grad)
+            weight_grad = paddle.matmul(input_reshape, sub_grad, transpose_x=True)
             return input_grad.reshape(input_.shape), weight_grad


@@ -65,7 +65,11 @@ def backward(ctx, dy):
         rank = paddle.distributed.get_rank()
         hcom_name = ctx.group.process_group.get_comm_name(rank)

-        d_weight = input_.reshape([-1, input_.shape[-1]]).t().matmul(sub_grad) if not weight.stop_gradient else None
+        d_weight = (
+            paddle.matmul(input_.reshape([-1, input_.shape[-1]]), sub_grad, transpose_x=True)
+            if not weight.stop_gradient
+            else None
+        )
         d_input = paddle_custom_device.npu.fused_mm_allreduce(
             sub_grad, weight.t(), bias=None, hcom=hcom_name, reduce_op="sum", comm_turn=0
         )
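The substantive change in both hunks is the same: explicit `.t()` transposes are replaced by the `transpose_x`/`transpose_y` flags of `paddle.matmul`, which let the kernel read the operand transposed rather than materializing a transposed tensor first, presumably the motivation on the NPU path. A minimal sketch (not part of the commit) checking that the two spellings agree numerically:

    import paddle

    x = paddle.randn([4, 8])
    w = paddle.randn([16, 8])

    # Old spelling: builds w.t() as an intermediate tensor, then multiplies.
    a = paddle.matmul(x, w.t())
    # New spelling: the transpose is folded into the matmul kernel.
    b = paddle.matmul(x, w, transpose_y=True)

    assert paddle.allclose(a, b).item()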

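In the second hunk, `paddle_custom_device.npu.fused_mm_allreduce` comes from the NPU plugin rather than core Paddle, so its exact contract is not visible in this diff. Reading the call site, a rough non-fused reference for what the backward relies on (a local matmul whose partial result is summed across the process group) might look like the sketch below; the function name and the assumption that `reduce_op="sum"` maps to a sum all-reduce are inferences from the call, not the plugin's documented API:

    import paddle
    import paddle.distributed as dist

    def mm_allreduce_reference(x, w, group=None):
        # Local partial product on this rank...
        out = paddle.matmul(x, w)
        # ...then sum the partial results across ranks. The fused NPU
        # kernel performs both steps in a single call.
        dist.all_reduce(out, op=dist.ReduceOp.SUM, group=group)
        return out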