@@ -284,6 +284,15 @@ def finetune(self):
                 raise NotImplementedError(
                     f"Unsupported bits {finetune_args.bits}, only support 4 and 8 now."
                 )
+            if finetune_args.full_finetune:
+                raise ValueError(
+                    f"qlora and full_finetune can't be True at the same time."
+                )
+        elif finetune_args.full_finetune:
+            if finetune_args.bits not in [16, 32]:
+                raise ValueError(
+                    f"full finetune only support 16 and 32 bits."
+                )
 
         config = self.load_model_config(self.model_args)
         if config.architectures[0].endswith("ForCausalLM") \
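Read together with the surrounding qlora block, the added checks make `qlora` and `full_finetune` mutually exclusive and restrict full fine-tuning to 16- or 32-bit weights. Below is a minimal sketch of the assumed control flow; the nesting is inferred from the hunk, and `validate_bits` is a hypothetical wrapper used only for illustration, not a function in the repo:

from types import SimpleNamespace

def validate_bits(finetune_args):
    """Hypothetical mirror of the validation added above (nesting inferred from the diff)."""
    if finetune_args.qlora:
        if finetune_args.bits not in [4, 8]:
            raise NotImplementedError(
                f"Unsupported bits {finetune_args.bits}, only support 4 and 8 now."
            )
        if finetune_args.full_finetune:
            raise ValueError("qlora and full_finetune can't be True at the same time.")
    elif finetune_args.full_finetune:
        if finetune_args.bits not in [16, 32]:
            raise ValueError("full finetune only support 16 and 32 bits.")

# QLoRA with 4-bit weights passes; qlora=True together with full_finetune=True raises.
validate_bits(SimpleNamespace(qlora=True, bits=4, full_finetune=False))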
@@ -482,48 +491,50 @@ def concatenate_data(dataset, max_seq_length):
         )
 
         if training_args.do_train:
-            # PEFT settings
-            if finetune_args.peft == "lora":
-                if finetune_args.lora_all_linear:
-                    target_modules = self.find_all_linear_names(model)
-                else:
-                    target_modules = finetune_args.lora_target_modules
-
-                peft_config = LoraConfig(
-                    r=finetune_args.lora_rank,
-                    lora_alpha=finetune_args.lora_alpha,
-                    lora_dropout=finetune_args.lora_dropout,
-                    target_modules=target_modules,
-                    bias="none",
-                    task_type=TaskType.CAUSAL_LM,
-                )
-            elif finetune_args.peft == "llama_adapter":
-                peft_config = AdaptionPromptConfig(
-                    adapter_layers=finetune_args.adapter_layers,
-                    adapter_len=finetune_args.adapter_len,
-                    task_type="CAUSAL_LM",
-                )
-            elif finetune_args.peft == "ptun":
-                peft_config = PromptEncoderConfig(
-                    num_virtual_tokens=finetune_args.num_virtual_tokens,
-                    encoder_hidden_size=finetune_args.ptun_hidden_size,
-                    task_type="CAUSAL_LM",
-                )
-            elif finetune_args.peft == "prefix":
-                peft_config = PrefixTuningConfig(
-                    num_virtual_tokens=finetune_args.num_virtual_tokens,
-                    task_type="CAUSAL_LM",
-                )
-            elif finetune_args.peft == "prompt":
-                peft_config = PromptTuningConfig(
-                    num_virtual_tokens=finetune_args.num_virtual_tokens,
-                    task_type="CAUSAL_LM",
-                )
+            if not finetune_args.full_finetune:
+                # PEFT settings
+                if finetune_args.peft == "lora":
+                    if finetune_args.lora_all_linear:
+                        target_modules = self.find_all_linear_names(model)
+                    else:
+                        target_modules = finetune_args.lora_target_modules
+
+                    peft_config = LoraConfig(
+                        r=finetune_args.lora_rank,
+                        lora_alpha=finetune_args.lora_alpha,
+                        lora_dropout=finetune_args.lora_dropout,
+                        target_modules=target_modules,
+                        bias="none",
+                        task_type=TaskType.CAUSAL_LM,
+                    )
+                elif finetune_args.peft == "llama_adapter":
+                    peft_config = AdaptionPromptConfig(
+                        adapter_layers=finetune_args.adapter_layers,
+                        adapter_len=finetune_args.adapter_len,
+                        task_type="CAUSAL_LM",
+                    )
+                elif finetune_args.peft == "ptun":
+                    peft_config = PromptEncoderConfig(
+                        num_virtual_tokens=finetune_args.num_virtual_tokens,
+                        encoder_hidden_size=finetune_args.ptun_hidden_size,
+                        task_type="CAUSAL_LM",
+                    )
+                elif finetune_args.peft == "prefix":
+                    peft_config = PrefixTuningConfig(
+                        num_virtual_tokens=finetune_args.num_virtual_tokens,
+                        task_type="CAUSAL_LM",
+                    )
+                elif finetune_args.peft == "prompt":
+                    peft_config = PromptTuningConfig(
+                        num_virtual_tokens=finetune_args.num_virtual_tokens,
+                        task_type="CAUSAL_LM",
+                    )
+
+                model = get_peft_model(model, peft_config)
+                model.print_trainable_parameters()
 
-            model = get_peft_model(model, peft_config)
             if model_dtype == torch.bfloat16:
                 model = model.to(model_dtype)
-            model.print_trainable_parameters()
 
         if finetune_args.device != 'hpu':
             # Initialize our Trainer
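For the non-full-finetune path, the selected PEFT config is applied with `get_peft_model`, which freezes the base weights and leaves only the adapter parameters trainable. A minimal standalone sketch of that step, assuming the `peft` and `transformers` packages are installed; the checkpoint and target module names are illustrative, not taken from the repo:

from transformers import AutoModelForCausalLM
from peft import LoraConfig, TaskType, get_peft_model

# Small public checkpoint used only for illustration.
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

peft_config = LoraConfig(
    r=8,                                  # LoRA rank
    lora_alpha=16,                        # scaling factor
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],  # attention projections in OPT
    bias="none",
    task_type=TaskType.CAUSAL_LM,
)

model = get_peft_model(model, peft_config)
# Reports the adapter parameter count as a small fraction of the total.
model.print_trainable_parameters()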
@@ -806,24 +817,33 @@ def preprocess_logits_for_metrics(logits, labels):
         else:
             raise ValueError("Must provide model_name_or_path to load a pretrained Seq2SeqLM model.")
 
-        # PEFT settings
-        if finetune_args.peft == "lora":
-            if finetune_args.lora_all_linear:
-                target_modules = self.find_all_linear_names(model)
-            else:
-                target_modules = finetune_args.lora_target_modules
-            peft_config = LoraConfig(
-                r=finetune_args.lora_rank,
-                lora_alpha=finetune_args.lora_alpha,
-                lora_dropout=finetune_args.lora_dropout,
-                target_modules=target_modules,
-                bias="none",
-                task_type=TaskType.SEQ_2_SEQ_LM,
+        if finetune_args.qlora:
+            model = prepare_model_for_kbit_training(
+                model, use_gradient_checkpointing=training_args.gradient_checkpointing
             )
 
-        # model = prepare_model_for_int8_training(model)
-        model = get_peft_model(model, peft_config)
-        model.print_trainable_parameters()
+        if not finetune_args.full_finetune:
+            # PEFT settings
+            if finetune_args.peft == "lora":
+                if finetune_args.lora_all_linear:
+                    target_modules = self.find_all_linear_names(model)
+                else:
+                    target_modules = finetune_args.lora_target_modules
+                peft_config = LoraConfig(
+                    r=finetune_args.lora_rank,
+                    lora_alpha=finetune_args.lora_alpha,
+                    lora_dropout=finetune_args.lora_dropout,
+                    target_modules=target_modules,
+                    bias="none",
+                    task_type=TaskType.SEQ_2_SEQ_LM,
+                )
+
+            # model = prepare_model_for_int8_training(model)
+            model = get_peft_model(model, peft_config)
+            model.print_trainable_parameters()
+
+        if model_dtype == torch.bfloat16:
+            model = model.to(model_dtype)
 
         if training_args.do_eval and not training_args.do_train:
             config = PeftConfig.from_pretrained(model_args.model_name_or_path)
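On the Seq2Seq path, QLoRA now routes the quantized model through `prepare_model_for_kbit_training` before the LoRA adapter is attached. A sketch of that preparation step, assuming `bitsandbytes` and a CUDA device are available; the checkpoint name is illustrative only:

import torch
from transformers import AutoModelForSeq2SeqLM, BitsAndBytesConfig
from peft import prepare_model_for_kbit_training

# Illustrative checkpoint; any Seq2SeqLM is handled the same way.
model = AutoModelForSeq2SeqLM.from_pretrained(
    "google/flan-t5-small",
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    ),
)

# Freezes the quantized base weights, upcasts norm/embedding layers for
# numerical stability, and optionally enables gradient checkpointing.
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)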
@@ -839,9 +859,26 @@ def preprocess_logits_for_metrics(logits, labels):
             label_pad_token_id=label_pad_token_id,
             pad_to_multiple_of=8)
 
-        # Create Trainer instance
-        trainer = Seq2SeqTrainer(
+
+        if finetune_args.device != 'hpu':
+            # Create Trainer instance
+            trainer = Seq2SeqTrainer(
+                model=model,
+                args=training_args,
+                data_collator=data_collator,
+                train_dataset=train_dataset if training_args.do_train else None,
+                eval_dataset=eval_dataset if training_args.do_eval else None,
+                compute_metrics=compute_metrics,
+                preprocess_logits_for_metrics=preprocess_logits_for_metrics if training_args.do_eval else None,
+            )
+        else:
+            from optimum.habana import GaudiConfig, GaudiSeq2SeqTrainer  # pylint: disable=E0611 E0401
+            gaudi_config = GaudiConfig()
+            gaudi_config.use_fused_adam = True
+            gaudi_config.use_fused_clip_norm = True
+            trainer = GaudiSeq2SeqTrainer(
             model=model,
+            gaudi_config=gaudi_config,
             args=training_args,
             data_collator=data_collator,
             train_dataset=train_dataset if training_args.do_train else None,
@@ -850,6 +887,7 @@ def preprocess_logits_for_metrics(logits, labels):
             preprocess_logits_for_metrics=preprocess_logits_for_metrics if training_args.do_eval else None,
         )
 
+
         # Training
         if training_args.do_train:
             checkpoint = None
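The Seq2Seq path now mirrors the CLM path's device switch: a plain `Seq2SeqTrainer` off HPU, and a `GaudiSeq2SeqTrainer` on Gaudi, distinguished only by the extra `gaudi_config` argument. A short sketch of the Gaudi-side configuration used above, assuming the `optimum-habana` package and its Habana software stack are installed (outside an HPU environment the import itself may fail):

from optimum.habana import GaudiConfig  # requires optimum-habana on a Gaudi host

# Enable Habana's fused kernels, as in the hunk above; these flags only take
# effect once the config is passed to a Gaudi trainer running on HPU.
gaudi_config = GaudiConfig()
gaudi_config.use_fused_adam = True
gaudi_config.use_fused_clip_norm = True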