3 changes: 3 additions & 0 deletions docs/trainer.md
@@ -691,6 +691,9 @@ Trainer is a simple yet feature-complete Paddle training and evaluation module, and
--optim
Optimizer name, defaults to adamw. (`str`, optional, defaults to `adamw`)
The optimizer to use. (default: adamw)
Possible values (see the usage sketch below):
- `"adamw"`
- `"adamw_mini"`

--report_to
Where training logs are reported for visualization; visualdl is used by default. (optional, defaults to None, which reports to all available integrations)
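As a minimal usage sketch for the `--optim` flag documented above (not part of the doc file itself, and assuming a PaddleNLP install that includes this PR), the flag maps onto the `optim` field of `TrainingArguments`, so "adamw_mini" can be selected from code as well as from the command line:

from paddlenlp.trainer import TrainingArguments

# "adamw_mini" is resolved to OptimizerNames.ADAMW_MINI in __post_init__;
# output_dir is required, everything else keeps its defaults.
args = TrainingArguments(output_dir="./checkpoints", optim="adamw_mini")
print(args.optim)  # OptimizerNames.ADAMW_MINI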
1 change: 1 addition & 0 deletions llm/docs/dpo.md
@@ -119,6 +119,7 @@ python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" ./alignment/dpo
- `unified_checkpoint`: Whether to use the unified checkpoint format. Defaults to `True`.
- `autotuner_benchmark`: Whether to enable the autotuner benchmark. Defaults to `False`.
- `benchmark`: Whether to enable benchmarking. Defaults to `False`.
- `optim`: Defaults to `adamw`; supported values are `adamw` and `adamw_mini`.
### DPO Parameters (DPOArguments)
- `beta`: The beta parameter of the DPO loss. Defaults to 0.1.
- `simpo_gamma`: The gamma parameter of the SimPO loss. Defaults to 0.5.
1 change: 1 addition & 0 deletions llm/docs/finetune.md
@@ -184,6 +184,7 @@ python merge_lora_params.py \
- `pipeline_parallel_degree`: The pipeline parallelism degree. (If this is 4 and the model has 12 layers, each pipeline stage holds 3 layers.) Defaults to -1, meaning pipeline parallelism is disabled.
- `sharding_parallel_degree`: The data-parallel degree for grouped parameter sharding. Defaults to 1, meaning sharded data parallelism is disabled.
- `sharding`: Whether to use Paddle's sharded data parallelism; a user-supplied option. Supports sharding `stage1`, `stage2`, or `stage3`; `stage2` and `stage3` can be combined with `offload`.
- `optim`: Defaults to `adamw`; supported values are `adamw` and `adamw_mini`.
</div>


5 changes: 5 additions & 0 deletions paddlenlp/trainer/trainer.py
@@ -1897,6 +1897,11 @@

optimizer_cls = AdamW
optimizer_kwargs.update(adam_kwargs)
elif args.optim == OptimizerNames.ADAMW_MINI:
from ..utils import AdamWMini

optimizer_cls = AdamWMini
optimizer_kwargs.update(adam_kwargs)

Contributor:

Could we add a restriction or at least a warning here, e.g. that AdamWMini cannot be enabled when tensor parallelism or sharding is used?

Contributor Author:

done

else:
raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
return optimizer_cls, optimizer_kwargs
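For context, a self-contained sketch of the dispatch pattern the hunk above follows: a string optimizer name is mapped to an optimizer class plus shared kwargs, which the caller then instantiates with the model parameters. The helper name `pick_optimizer_cls` is illustrative, not the Trainer's actual API; the `AdamWMini` import assumes the export added in this PR.

import paddle
from paddlenlp.utils import AdamWMini  # export added in paddlenlp/utils/__init__.py below

def pick_optimizer_cls(optim_name: str):
    # Shared Adam-style hyperparameters, analogous to adam_kwargs above.
    shared_kwargs = {"beta1": 0.9, "beta2": 0.999, "epsilon": 1e-8}
    if optim_name == "adamw":
        return paddle.optimizer.AdamW, shared_kwargs
    if optim_name == "adamw_mini":
        return AdamWMini, shared_kwargs
    raise ValueError(f"unsupported optimizer: {optim_name}")

model = paddle.nn.Linear(4, 4)
cls, kwargs = pick_optimizer_cls("adamw_mini")
opt = cls(learning_rate=1e-4, parameters=model.parameters(), **kwargs)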
1 change: 1 addition & 0 deletions paddlenlp/trainer/trainer_utils.py
@@ -317,6 +317,7 @@ class OptimizerNames(ExplicitEnum):

ADAMW = "adamw"
ADAFACTOR = "adafactor"
ADAMW_MINI = "adamw_mini"


class ShardingOption(ExplicitEnum):
2 changes: 2 additions & 0 deletions paddlenlp/trainer/training_args.py
@@ -1014,6 +1014,8 @@
raise ValueError("At most one of fp16 and bf16 can be True for full eval, but not both")

self.optim = OptimizerNames(self.optim)
if self.optim == OptimizerNames.ADAMW_MINI and self.tensor_parallel_degree > 1:
raise ValueError("AdamW Mini currently doesn't support tensor parallelism.")

self.use_hybrid_parallel = False
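Following up on the review comment above, a hedged sketch (not part of this diff) of how the guard could also cover sharded data parallelism; the exact fields and condition are assumptions about TrainingArguments, not the merged behavior:

if self.optim == OptimizerNames.ADAMW_MINI:
    if self.tensor_parallel_degree > 1:
        raise ValueError("AdamW Mini currently doesn't support tensor parallelism.")
    # Assumed extension: also reject sharded training (any sharding stage, or sharding_parallel_degree > 1).
    if self.sharding_parallel_degree > 1 or self.sharding:
        raise ValueError("AdamW Mini currently doesn't support sharding.")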

1 change: 1 addition & 0 deletions paddlenlp/utils/__init__.py
@@ -21,6 +21,7 @@
from .import_utils import *
from .infohub import infohub
from .initializer import to
from .optimizer import *
from .serialization import load_torch

# hack impl for EagerParamBase to function
151 changes: 151 additions & 0 deletions paddlenlp/utils/optimizer.py
@@ -0,0 +1,151 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle import pir
from paddle.base import core, framework
from paddle.base.framework import Variable, in_dynamic_or_pir_mode, in_pir_mode
from paddle.base.libpaddle import DataType
from paddle.optimizer.adamw import AdamW
from paddle.pir import Value


class AdamWMini(AdamW):
def _add_moments_pows(self, p):
acc_dtype = p.dtype
if self._is_dtype_fp16_or_bf16(acc_dtype):
acc_dtype = DataType.FLOAT32 if in_pir_mode() else paddle.float32

self._add_accumulator(self._moment1_acc_str, p, dtype=acc_dtype)

# change moment2
self._add_accumulator(self._moment2_acc_str, p, dtype=acc_dtype, shape=[1])
try:
type = core.VarDesc.VarType.DENSE_TENSOR
except:
type = core.VarDesc.VarType.LOD_TENSOR
self._add_accumulator(

name=self._beta1_pow_acc_str,
param=p,
dtype=acc_dtype,
fill_value=0.9 if isinstance(self._beta1, (Variable, Value)) else self._beta1,
shape=[1],
type=type,
device="cpu",
)
self._add_accumulator(

name=self._beta2_pow_acc_str,
param=p,
dtype=acc_dtype,
fill_value=0.999 if isinstance(self._beta2, (Variable, Value)) else self._beta2,
shape=[1],
type=type,
device="cpu",
)

def _append_optimize_op(self, block, param_and_grad):
assert isinstance(block, (framework.Block, pir.Block))
if isinstance(param_and_grad, dict):
param_and_grad = self._update_param_group(param_and_grad)
param = param_and_grad[0]

# Whether we should do weight decay for the parameter.
with_decay = True
if self._apply_decay_param_fun is not None and not self._apply_decay_param_fun(param.name):
with_decay = False

moment1 = self._get_accumulator_master(self._moment1_acc_str, param_and_grad[0])
moment2 = self._get_accumulator_master(self._moment2_acc_str, param_and_grad[0])
beta1_pow_acc = self._get_accumulator_master(self._beta1_pow_acc_str, param_and_grad[0])
beta2_pow_acc = self._get_accumulator_master(self._beta2_pow_acc_str, param_and_grad[0])
find_master = self._multi_precision and self._is_dtype_fp16_or_bf16(param_and_grad[0].dtype)
master_weight = self._master_weights[param_and_grad[0].name] if find_master else None
lr = self._create_param_lr(param_and_grad)

# create the adamw optimize op
if in_dynamic_or_pir_mode():
lr_ratio_ = 1.0 if self._lr_ratio is None else self._lr_ratio(param_and_grad[0])

_beta1 = self._beta1 if not isinstance(self._beta1, Variable) else self._beta1.item(0)
_beta2 = self._beta2 if not isinstance(self._beta2, Variable) else self._beta2.item(0)

found_inf = self._get_auxiliary_var("found_inf") if in_pir_mode() else None
self.adamw_python(

param_and_grad[0],
param_and_grad[1],
lr,
moment1,
moment2,
beta1_pow_acc,
beta2_pow_acc,
master_weight,
found_inf,
_beta1,
_beta2,
self._epsilon,
lr_ratio_,
self._weight_decay,
with_decay,
find_master,
)
return None

else:
raise NotImplementedError("Not implemented yet.")


def adamw_python(
self,
param,
grad,
learning_rate,
moment1,
moment2,
beta1_pow,
beta2_pow,
master_weight,
skip_update,
beta1,
beta2,
epsilon,
lr_ratio,
coeff,
with_decay,
multi_precision,
):
if skip_update:
return
if not with_decay:
coeff = 0.0
if not multi_precision:
master_weight = None
lr = learning_rate * lr_ratio
if master_weight is not None:
p = master_weight

else:
p = param
p *= 1.0 - lr * coeff
mom1 = moment1
mom2 = moment2

mom1 = beta1 * mom1 + (1.0 - beta1) * grad
mom2 = beta2 * mom2 + (1.0 - beta2) * (grad * grad).mean()
denom = mom2.sqrt() / (1.0 - beta2_pow).sqrt() + epsilon
p += (mom1 / denom) * (-(lr / (1.0 - beta1_pow)))  # use the freshly updated first moment
if master_weight is not None:
master_weight[:] = p
param[:] = p.astype(param.dtype)

else:
param[:] = p
moment1[:] = mom1
moment2[:] = mom2
beta1_pow[:], beta2_pow[:] = beta1 * beta1_pow[:], beta2 * beta2_pow[:]

# TODO: check how this update should be handled
return

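To make the update rule in adamw_python above concrete, here is a minimal NumPy sketch of one AdamW-Mini style step: the first moment stays per-element, while the second moment is a single scalar per parameter, updated with the mean of the squared gradient, which is what shrinks the optimizer state. Hyperparameter values are illustrative only.

import numpy as np

def adamw_mini_step(p, g, m1, m2, b1_pow, b2_pow,
                    lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.01):
    p = p * (1.0 - lr * weight_decay)                   # decoupled weight decay
    m1 = beta1 * m1 + (1.0 - beta1) * g                 # per-element first moment
    m2 = beta2 * m2 + (1.0 - beta2) * (g * g).mean()    # scalar second moment
    denom = np.sqrt(m2) / np.sqrt(1.0 - b2_pow) + eps
    p = p + (m1 / denom) * (-(lr / (1.0 - b1_pow)))     # bias-corrected update
    return p, m1, m2, b1_pow * beta1, b2_pow * beta2

p, g = np.ones(4), np.full(4, 0.1)
state = (np.zeros(4), np.zeros(1), 0.9, 0.999)          # m1, m2, beta1_pow, beta2_pow
p, *state = adamw_mini_step(p, g, *state)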
35 changes: 35 additions & 0 deletions tests/fixtures/llm/adamw_mini.yaml
@@ -0,0 +1,35 @@
finetune:
base:
dataset_name_or_path: "./data"
per_device_train_batch_size: 4
gradient_accumulation_steps: 4
per_device_eval_batch_size: 8
eval_accumulation_steps: 16
num_train_epochs: 3
learning_rate: 3e-05
warmup_steps: 30
logging_steps: 1
evaluation_strategy: "epoch"
save_strategy: "epoch"
src_length: 1024
max_length: 2048
fp16: true
fp16_opt_level: "O2"
do_train: true
do_eval: true
use_flash_attention: true
disable_tqdm: true
load_best_model_at_end: true
eval_with_do_generation: false
metric_for_best_model: "accuracy"
recompute: true
refined_recompute: "flash_attn:-1"
save_total_limit: 1
tensor_parallel_degree: 1
pipeline_parallel_degree: 1
ignore_save_lr_and_optim: 1
optim: "adamw_mini"

default:
llama:
model_name_or_path: __internal_testing__/tiny-random-llama
53 changes: 53 additions & 0 deletions tests/llm/test_adamw_mini.py
@@ -0,0 +1,53 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
Contributor:

2022 -> 2024

Contributor Author:

done

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import sys
import unittest

from parameterized import parameterized_class

from tests.testing_utils import argv_context_guard, load_test_config

from .testing_utils import LLMTest


@parameterized_class(
["model_dir"],
[
["llama"],
],
)
class FinetuneTest(LLMTest, unittest.TestCase):
config_path: str = "./tests/fixtures/llm/adamw_mini.yaml"
model_dir: str = None

def setUp(self) -> None:
LLMTest.setUp(self)

sys.path.insert(0, self.model_dir)

def tearDown(self) -> None:
LLMTest.tearDown(self)

def test_finetune(self):
finetune_config = load_test_config(self.config_path, "finetune", self.model_dir)

finetune_config["dataset_name_or_path"] = self.data_dir
finetune_config["output_dir"] = self.output_dir

with argv_context_guard(finetune_config):
from run_finetune import main

main()