@@ -52,6 +52,7 @@ function llama_case_list_auto() {
     llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP1
     llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP2
     llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP2-VPP2-Sharding2_stage2
+    llama_static_auto_recompute_bs16_fp16_DP2-MP2-PP2-VPP2-Sharding2_stage2
 }
 
 function gpt_case_list_auto_pir() {
@@ -1168,6 +1169,75 @@ function llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP2-VPP2-Sharding2_stage2
     echo "=========== $FUNCNAME run end ==========="
 }
 
+function llama_static_auto_recompute_bs16_fp16_DP2-MP2-PP2-VPP2-Sharding2_stage2() {
+    echo "=========== $FUNCNAME run begin ==========="
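+    # fp16 (O2 AMP with master grad) variant of the bs16 recompute DP2-MP2-PP2-VPP2-Sharding2_stage2 case above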
+    export PYTHONPATH=$root_path/:$PYTHONPATH
+    export FLAGS_call_stack_level=2
+
+    task_name="llama_auto_bs16_fp16_dp2mp2pp2vpp2sharding2"
+    case_out_dir="output/$task_name"
+    case_log_dir="output/$task_name""_log"
+    rm -rf $case_out_dir
+    rm -rf $case_log_dir
+
+    python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" --log_dir $case_log_dir run_pretrain_auto.py \
+        --model_type "llama" \
+        --model_name_or_path "facebook/llama-7b" \
+        --tokenizer_name_or_path "facebook/llama-7b" \
+        --hidden_size 1024 \
+        --intermediate_size 3072 \
+        --num_hidden_layers 8 \
+        --num_attention_heads 32 \
+        --input_dir "./data" \
+        --output_dir $case_out_dir \
+        --split 949,50,1 \
+        --max_seq_length 2048 \
+        --per_device_train_batch_size 1 \
+        --per_device_eval_batch_size 8 \
+        --gradient_accumulation_steps 8 \
+        --use_flash_attention 0 \
+        --use_fused_rms_norm 0 \
+        --fp16 1 \
+        --fp16_opt_level "O2" \
+        --amp_master_grad 1 \
+        --scale_loss 1024 \
+        --tensor_parallel_degree 2 \
+        --pipeline_parallel_degree 2 \
+        --virtual_pp_degree 2 \
+        --pipeline_schedule_mode "VPP" \
+        --sharding_parallel_degree 2 \
+        --sharding "stage2" \
+        --learning_rate 0.0001 \
+        --min_learning_rate 0.00001 \
+        --max_steps 10 \
+        --save_steps 5000 \
+        --weight_decay 0.01 \
+        --warmup_ratio 0.01 \
+        --max_grad_norm 1.0 \
+        --logging_steps 1 \
+        --dataloader_num_workers 1 \
+        --eval_steps 1000 \
+        --report_to "visualdl" \
+        --disable_tqdm true \
+        --continue_training 0 \
+        --recompute 1 \
+        --do_train \
+        --do_eval \
+        --device "gpu" \
+        --data_impl "mmap" \
+        --parallel_mode "auto" \
+        >> ${log_path}/$FUNCNAME 2>&1
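+    # parse the loss reported at global_step 10 from workerlog.3 for comparison against the recorded baseline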
+    loss=`cat $case_log_dir/workerlog.3 | grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
+    ips=-1
+    mem=-1
+    echo "result: loss=$loss ips=$ips mem=$mem"
+    loss_base=10.0859375
+    ips_base=-1
+    mem_base=-1
+    check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
+    echo "=========== $FUNCNAME run end ==========="
+}
+
 function llama_dygraph_auto_bs4_fp32_DP2-MP2-PP2() {
     echo "=========== $FUNCNAME run begin ==========="
     export PYTHONPATH=$root_path/:$PYTHONPATH
@@ -1233,7 +1303,6 @@ function llama_dygraph_auto_bs4_fp32_DP2-MP2-PP2() {
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
     echo "=========== $FUNCNAME run end ==========="
 }
-
 # ########### case end ############
 
 function check_result() {