Conversation

@DesmonDay (Contributor)

PR types

New features

PR changes

Others

Description

Add embedding trainer.

@paddle-bot (bot) commented on Dec 11, 2024

Thanks for your contribution!

@ZHUI (Contributor) left a comment


LGTM

from paddle.distributed import fleet


def dist_gather_tensor_with_gradient(tensor):
Contributor

Could this live in one of the other utils directories? It feels fairly general-purpose: it gathers across DP (data parallel).

Inside the function, where does the gradient in _gather_tensor_with_gradient actually come into play?

@DesmonDay (Contributor, Author)

It comes in right here: the local tensor is assigned back into the gathered list, all_tensors[sharding_rank] = tensor, which keeps this rank's slice attached to the autograd graph.
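
For context, a minimal sketch of that pattern, assuming Paddle's collective API with the default group and the global rank (the PR's helper resolves the sharding rank through fleet's hybrid-parallel groups, so names and indexing here are illustrative):

import paddle
import paddle.distributed as dist


def dist_gather_tensor_with_gradient(tensor):
    # Gather `tensor` from every rank while preserving this rank's gradient.
    # Sketch only: the merged helper uses fleet's communicate groups instead.
    world_size = dist.get_world_size()
    if world_size <= 1:
        return tensor

    # all_gather fills the list with copies detached from the autograd graph.
    all_tensors = []
    dist.all_gather(all_tensors, tensor)

    # Re-insert the local tensor so gradients flow back through this
    # rank's slice of the concatenated result.
    all_tensors[dist.get_rank()] = tensor
    return paddle.concat(all_tensors, axis=0)

In a contrastive setting, each rank then backpropagates through its own slice while still seeing the embeddings gathered from all other ranks.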

def __init__(self, model_args, **kwargs):
    super().__init__(**kwargs)

    self.model_args = model_args
Contributor

It would be best to add a field like loss_type. As I understand it, zhangjie's side will also need to add the Inf-CL loss.
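
A hypothetical sketch of that suggestion, assuming the trainer subclasses PaddleNLP's Trainer; the field name loss_type, its default value, and the comment's Inf-CL hook-up are illustrative assumptions rather than the merged API:

from paddlenlp.trainer import Trainer


class EmbeddingTrainer(Trainer):
    def __init__(self, model_args, loss_type="contrastive", **kwargs):
        super().__init__(**kwargs)
        self.model_args = model_args
        # Reviewer-suggested field: lets additional losses (e.g. Inf-CL)
        # be selected without changing the constructor again later.
        self.loss_type = loss_type

Loss computation could then dispatch on self.loss_type, so adding the Inf-CL variant would only touch the loss-selection branch.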

@wawltor merged commit b205922 into PaddlePaddle:develop on Dec 11, 2024
9 of 12 checks passed
DesmonDay added a commit to DesmonDay/PaddleNLP that referenced this pull request Dec 12, 2024
DesmonDay added a commit that referenced this pull request Dec 13, 2024
* Fix multi-threading load_state_dict (#9464)
* Update model_utils.py
* Update model_utils.py
* [Unified Checkpoint] fix single card loading without master weights (#9540)
* update embedding trainer (#9608)