@@ -22,7 +22,7 @@ def skip(*args, **kwargs):
 torch.nn.init.kaiming_uniform_ = skip
 torch.nn.init.uniform_ = skip
 torch.nn.init.normal_ = skip
-
+import re
 import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel as DDP
 import logging
@@ -69,7 +69,7 @@ def parse_args():
     parser.add_argument(
         "--calibration_dataset_name",
         type=str,
-        default="wikitext-2-raw-v1",
+        default="NeelNanda/pile-10k",  # e.g. wikitext-2-raw-v1
         help="The name of the pruning dataset to use (via the datasets library).",
     )
     parser.add_argument(
@@ -128,6 +128,12 @@ def parse_args():
         default=16,
         help="Batch size (per device) for the evaluation dataloader.",
     )
+    parser.add_argument(
+        "--calib_size",
+        type=int,
+        default=128,
+        help="sample size for the calibration dataset.",
+    )
     parser.add_argument(
         "--learning_rate",
         type=float,
@@ -268,10 +274,12 @@ def parse_args():
     parser.add_argument("--tasks", default=["lambada_openai"],
                         help="Usually chosen with ['lambada_openai','hellaswag','winogrande','piqa']",
                         )
-    parser.add_argument("--eval_fp16", action='store_true',
-                        help=" fp16")
     parser.add_argument("--use_accelerate", action='store_true',
-                        help="Usually use to accelerate evaluation for large models")
+                        help="Usually use to accelerate evaluation for large models"
+                        )
+    parser.add_argument("--eval_dtype", default='fp32',
+                        help="choose in bf16, fp16 and fp32"
+                        )
 
     args = parser.parse_args()
 
@@ -376,34 +384,33 @@ def main():
         logger.warning("You are instantiating a new config instance from scratch.")
 
     is_llama = bool("llama" in args.model_name_or_path)
-    is_t5 = bool("t5" in args.model_name_or_path)
     if args.tokenizer_name:
         tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=not args.use_slow_tokenizer)
     elif args.model_name_or_path:
         if is_llama:
             tokenizer = transformers.LlamaTokenizer.from_pretrained(args.model_name_or_path)
         else:
-            tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer, trust_remote_code=True)
+            tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path,
+                                                      use_fast=not args.use_slow_tokenizer, trust_remote_code=True)
     else:
         raise ValueError(
             "You are instantiating a new tokenizer from scratch. This is not supported by this script."
             "You can do it from another script, save it, and load it from here, using --tokenizer_name."
         )
 
     if args.model_name_or_path:
-        if is_t5:
-            model = T5ForConditionalGeneration.from_pretrained(
-                args.model_name_or_path,
-                config=config,
-            )
+        if re.search("chatglm", args.model_name_or_path.lower()):
+            model = AutoModel.from_pretrained(args.model_name_or_path,
+                                              trust_remote_code=args.trust_remote_code)  # .half()
         else:
             model = AutoModelForCausalLM.from_pretrained(
                 args.model_name_or_path,
                 from_tf=bool(".ckpt" in args.model_name_or_path),
                 config=config,
                 trust_remote_code=args.trust_remote_code,
-                low_cpu_mem_usage=args.low_cpu_mem_usage,
+                low_cpu_mem_usage=args.low_cpu_mem_usage
             )
+
 
     else:
         logger.info("Training new model from scratch")
@@ -492,7 +499,7 @@ def group_texts(examples):
     train_dataset = lm_datasets["train"]
 
     # DataLoaders creation:
-    train_dataset = train_dataset.shuffle(seed=42).select(range(128))
+    train_dataset = train_dataset.shuffle(seed=42).select(range(args.calib_size))
     total_batch_size = args.per_device_train_batch_size
     if local_rank != -1:
         total_batch_size *= WORLD_SIZE
@@ -543,8 +550,10 @@ def group_texts(examples):
     torch.backends.cudnn.allow_tf32 = False
     use_cache = model.config.use_cache
     model.config.use_cache = False
-
+    import time
+    s = time.time()
     pruning = prepare_pruning(model, configs, dataloader=train_dataloader, device=device)
+    logger.info(f"cost time: {time.time() - s}")
     model.config.use_cache = use_cache
 
     if args.output_dir is not None:
@@ -555,20 +564,28 @@ def group_texts(examples):
         logger.info(f"The model has been exported to {output_dir}")
 
     if device != 'cpu':
-        model = model.to(device)
+        if not args.use_accelerate:
+            model = model.to(device)
+        else:
+            model = model.cpu()
         logger.info(f"***** Evaluation in GPU mode. *****")
     else:
         logger.info(f"***** Evaluation in CPU mode. *****")
     model.eval()
 
     model_name = args.model_name_or_path
-    dtype = 'float32'
-    if args.eval_fp16:
-        if (hasattr(model, 'config') and model.config.torch_dtype is torch.bfloat16):
-            dtype = 'bfloat16'
-        else:
-            dtype = 'float16'
-    model_args = f'pretrained={model_name},tokenizer={model_name},dtype={dtype},use_accelerate={args.use_accelerate}'
+    dtype = None
+    if args.eval_dtype == 'bf16':
+        model = model.to(dtype=torch.bfloat16)
+        dtype = 'bfloat16'
+    elif args.eval_dtype == 'fp16':
+        dtype = 'float16'
+        model = model.to(dtype=torch.float16)
+    else:
+        dtype = 'float32'
+        model = model.to(dtype=torch.float32)
+
+    model_args = f'pretrained={model_name},tokenizer={model_name},dtype={dtype},use_accelerate={args.use_accelerate},trust_remote_code={args.trust_remote_code}'
     eval_batch = args.per_device_eval_batch_size
     user_model = None if args.use_accelerate else model
     results = evaluate(