@AndSonder
Contributor

@AndSonder AndSonder commented Apr 9, 2024

PR Category

Auto Parallel

PR Types

Improvements

Description

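Currently, when use_reentrant == True, recompute is implemented with PyLayer. However, PyLayer does not yet support Tensor arguments passed in dict form (Tensors passed in a dict do not get backward nodes or backward edges created).

To make distributed training easier to use, this PR lets recompute accept Tensor arguments in dict form (keyword arguments) when use_reentrant == True. The main idea is to rearrange position-args + keyword-args into position-args.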
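As a rough illustration of the rearrangement idea, here is a minimal standard-library sketch. The names `to_positional` and `block_forward` are hypothetical and are not Paddle APIs. It binds the call against the function signature and emits the values in declaration order, and it only covers plain positional-or-keyword parameters:

```python
import inspect


def to_positional(fn, *args, **kwargs):
    """Rearrange positional + keyword arguments into a positional-only list.

    Hypothetical helper for illustration only; it assumes `fn` declares
    nothing but ordinary positional-or-keyword parameters.
    """
    sig = inspect.signature(fn)
    bound = sig.bind(*args, **kwargs)
    bound.apply_defaults()  # fill in defaults for parameters that were omitted
    # `bound.arguments` is ordered by parameter declaration order.
    return list(bound.arguments.values())


def block_forward(hidden_states, attention_mask=None, use_cache=False):
    # Stand-in for a layer's forward function (assumed signature).
    return hidden_states, attention_mask, use_cache


positional = to_positional(block_forward, "h", use_cache=True, attention_mask="mask")
print(positional)  # ['h', 'mask', True]
```

Calling `block_forward(*positional)` then behaves the same as the original keyword call, which is what allows every Tensor to reach a PyLayer-style `apply` as a positional input.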

Performance test data:

Test environment: 4 × RTX 3090 GPUs; Llama2 with num_hidden_layer hacked to 4

Performance data collected at step 30:

| Case | interval_runtime | interval_samples_per_second | interval_steps_per_second | Loss (step 30) |
| --- | --- | --- | --- | --- |
| Llama2 (without kwargs) | 10.3995 | 1.5394 | 0.0962 | 7.09293509 |
| Llama2 (with kwargs) | 10.4043 | 1.5378 | 0.0961 | 7.09293509 |
| GPT3 (without kwargs) | 2.4434 | 1.6371 | 0.4093 | 10.3778286 |
| GPT3 (with kwargs) | 2.4371 | 1.6413 | 0.4103 | 10.3778286 |
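To put the runtime figures in proportion, the short sketch below computes the relative change in interval_runtime between the two variants. The numbers are copied from the table above; the dictionary layout and variable names are illustrative assumptions, not part of the PR:

```python
# Values copied from the table above: (without kwargs, with kwargs).
interval_runtime = {
    "Llama2": (10.3995, 10.4043),
    "GPT3": (2.4434, 2.4371),
}

# Relative change of the kwargs run with respect to the non-kwargs run.
relative_change = {
    model: (with_kwargs - without_kwargs) / without_kwargs
    for model, (without_kwargs, with_kwargs) in interval_runtime.items()
}

for model, change in relative_change.items():
    print(f"{model}: {change:+.3%}")
```

Both differences are well under 1% and point in opposite directions, while the step-30 loss is identical with and without kwargs.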

The Llama2 test script is as follows:

    # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.

    # just for debug

    set -x
    unset CUDA_VISIBLE_DEVICES

    task_name="llama_auto_dp2mp2pp2"
    rm -rf output/$task_name/
    rm -rf "output/$task_name""_log"

    export SOT_LOG_LEVEL=4
    export PYTHONPATH=../../../:$PYTHONPATH
    # ulimit -c unlimited
    # export GLOG_v=4
    # export FLAGS_call_stack_level=3
    # export FLAGS_use_cuda_managed_memory=true
    # export FLAGS_embedding_deterministic=1
    # export FLAGS_cudnn_deterministic=1
    # export NVIDIA_TF32_OVERRIDE=0

    to_static=0  # whether to enable dynamic-to-static training

    python -u -m paddle.distributed.launch \
        --gpus "0,1,2,3" \
        --log_dir "auto_3d" \
        run_pretrain_auto.py \
        --model_type "llama" \
        --model_name_or_path "facebook/llama-7b" \
        --tokenizer_name_or_path "facebook/llama-7b" \
        --input_dir "../data" \
        --output_dir "output/$task_name" \
        --split 949,50,1 \
        --max_seq_length 2048 \
        --per_device_train_batch_size 1 \
        --per_device_eval_batch_size 1 \
        --gradient_accumulation_steps 16 \
        --use_flash_attention 0 \
        --use_fused_rms_norm 0 \
        --fp16 0 \
        --fp16_opt_level "O2" \
        --scale_loss 1024 \
        --pipeline_parallel_degree 4 \
        --tensor_parallel_degree 1 \
        --sharding_parallel_degree 1 \
        --learning_rate 0.0001 \
        --min_learning_rate 0.00001 \
        --max_steps 30 \
        --save_steps 5000000 \
        --weight_decay 0.01 \
        --warmup_ratio 0.01 \
        --logging_steps 1 \
        --dataloader_num_workers 1 \
        --sharding_parallel_degree 1 \
        --sharding "stage1" \
        --eval_steps 1000000 \
        --disable_tqdm true \
        --continue_training 0 \
        --recompute 1 \
        --recompute_granularity full \
        --do_train \
        --do_eval \
        --device "gpu" \
        --data_impl "mmap" \
        --enable_auto_parallel 1 \
        --max_grad_norm 1.0 \
        --to_static $to_static

How to modify: change every place in paddlenlp/transformers/llama/modeling_auto.py where recompute is enabled (3 places in total).

[Three images]

The GPT run script is as follows:

    export PYTHONPATH="../../../":$PYTHONPATH
    export FLAGS_cudnn_deterministic=1
    export FLAGS_embedding_deterministic=1
    export NVIDIA_TF32_OVERRIDE=0
    export FLAGS_call_stack_level=3

    to_static=0
    # export TRANSLATOR_DISABLE_NEW_ERROR=0
    # export TRANSLATOR_CODE_LEVEL=100

    task_name="gpt3_auto_dp2mp2pp2_${to_static}"
    log_dir="output/$task_name""_log"
    output_dir="output/$task_name"
    rm -rf $log_dir
    rm -rf $output_dir

    python -u -m paddle.distributed.launch \
        --gpus "0,1,2,3" \
        --log_dir ${log_dir} \
        run_pretrain_auto.py \
        --model_name_or_path gpt2-medium-en \
        --tokenizer_name_or_path gpt2-medium-en \
        --input_dir "../data" \
        --output_dir ${output_dir} \
        --split 949,50,1 \
        --max_seq_length 1024 \
        --per_device_train_batch_size 1 \
        --per_device_eval_batch_size 1 \
        --sharding "" \
        --tensor_parallel_degree 2 \
        --pipeline_parallel_degree 2 \
        --sequence_parallel 0 \
        --fuse_attention_qkv 0 \
        --use_flash_attention 0 \
        --scale_loss 1024 \
        --learning_rate 0.00001 \
        --min_learning_rate 0.000005 \
        --max_steps 30 \
        --save_steps 50000 \
        --weight_decay 0.01 \
        --warmup_ratio 0.01 \
        --max_grad_norm 1.0 \
        --logging_steps 1 \
        --continue_training 0 \
        --dataloader_num_workers 1 \
        --eval_steps 100000 \
        --report_to "visualdl" \
        --disable_tqdm true \
        --recompute 0 \
        --gradient_accumulation_steps 4 \
        --do_train \
        --do_eval \
        --device "gpu" \
        --model_type "gpt" \
        --enable_auto_parallel 1 \
        --to_static ${to_static} \
        --fp16 0 \
        --fp16_opt_level "O2"

paddlenlp/transformers/gpt/modeling_auto.py is modified as follows:

[Image]

@paddle-bot

paddle-bot bot commented Apr 9, 2024

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@paddle-bot paddle-bot bot added the contributor External developers label Apr 9, 2024
@AndSonder AndSonder marked this pull request as ready for review April 9, 2024 01:55
@AndSonder
Contributor Author

@ForFishes @MarioLulab CI is all passing now. Could the maintainers please review this?

@ForFishes
Member

Hello, this PR raises some issues that we need to discuss further internally.

Member

@ForFishes ForFishes left a comment

LGTM

Comment on lines +533 to +543
    input_args = args
    # rearrange `position-args + keyword-args` into `position-args`
    if isinstance(function, paddle.nn.Layer):
        dyfunc_sig = inspect.signature(function.forward)
    else:
        dyfunc_sig = inspect.signature(function)

    bound_args = dyfunc_sig.bind(*args, **kwargs)
    bound_args.apply_defaults()
    input_args = list(bound_args.arguments.values())
    return RecomputeFunction.apply(function, preserve, *input_args)
Member

@SigureMo SigureMo Apr 17, 2024

Consider the following case:

    # From PaddleMIX: https://github.com/PaddlePaddle/PaddleMIX/blob/8b896d533811a3500af3064c5f1952b77003d4c8/ppdiffusers/ppdiffusers/models/unet_2d_blocks.py#L1149-L1155
    def custom_forward(*inputs):
        ...

Using bound_args.arguments is wrong here: no matter how many values are passed in, bound_args.arguments contains only one value, namely the packed inputs.
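The behaviour described above can be reproduced with the standard library alone. This is a minimal sketch in which `custom_forward` is a stand-in for the PaddleMIX function, not the real code:

```python
import inspect


def custom_forward(*inputs):
    # Stand-in for a function that only declares a var-positional parameter.
    return inputs


sig = inspect.signature(custom_forward)
bound = sig.bind(1, 2, 3)
bound.apply_defaults()

values = list(bound.arguments.values())
print(values)       # [(1, 2, 3)]  -> a single packed tuple, not three values
print(len(values))  # 1
```

Unpacking `values` back into the call, as in `custom_forward(*values)`, therefore passes one tuple argument instead of the three original arguments, which changes what the function receives.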

All Parameter kinds need to be taken into account:

    import inspect


    def custom_forward(*inputs, **kwargs):
        return inputs


    def convert_inputs_to_positional_args(fn, *args, **kwargs):
        positional_args = []
        sig = inspect.signature(fn)
        bound_args = sig.bind(*args, **kwargs)
        bound_args.apply_defaults()

        for arg, param in zip(bound_args.arguments.values(), sig.parameters.values()):
            if param.kind == param.VAR_POSITIONAL:
                positional_args.extend(arg)
            elif param.kind in (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD):
                positional_args.append(arg)
            elif param.kind == param.VAR_KEYWORD:
                positional_args.extend(arg.values())
            elif param.kind == param.KEYWORD_ONLY:
                raise ValueError("Currently, keyword-only arguments are not supported.")
            else:
                raise ValueError("Unknown parameter kind.")

        return positional_args


    convert_inputs_to_positional_args(custom_forward, 1, 2, y=2, x=1)

"The main idea is to rearrange position-args + keyword-args into position-args"

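Note that this approach inherently cannot support functions with keyword-only parameters. If those need to be supported, this approach is not viable.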
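The limitation follows from Python's calling convention rather than anything Paddle-specific. The sketch below uses a hypothetical `attention` function whose `mask` parameter is keyword-only, and shows that a flattened positional list cannot be passed back to it:

```python
def attention(query, *, mask=None):
    # `mask` is keyword-only: it can never be supplied positionally.
    return query, mask


flattened = ["q", "m"]  # what a positional-only rearrangement would produce

try:
    attention(*flattened)
    error = None
except TypeError as exc:
    error = exc

print(type(error).__name__)  # TypeError
```

Any scheme that forwards only positional values would need extra bookkeeping, such as remembering which values belong to which keyword, in order to call such a function again during recomputation.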

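In addition, this PR has already affected Stable Diffusion, a high-priority monitored model. I will first open a PR to try reverting it (#63637); in the meantime, please look into how to fix this.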

@AndSonder

Contributor Author

OK, got it.

One more thing I would like to confirm: is the impact on Stable Diffusion an error raised in the case above, or some other problem?

Member

The case above.

@luotao1 luotao1 changed the title 【Hackathon 6th No.35】support kwargs for recompute when use_reentrant == True 【Hackathon 6th No.35】support kwargs for recompute when use_reentrant == True -part Apr 24, 2024