 import copy, time
 from datetime import datetime
 import torch
+from queue import Queue
 import re, os, logging
 from threading import Thread
 import contextlib
@@ -539,9 +540,11 @@ def predict_stream(**params):
     force_words_ids = params["force_words_ids"] if "force_words_ids" in params else None
     use_hpu_graphs = params["use_hpu_graphs"] if "use_hpu_graphs" in params else False
     use_cache = params["use_cache"] if "use_cache" in params else True
+    return_stats = params["return_stats"] if "return_stats" in params else False
     prompt = params["prompt"]
     model = MODELS[model_name]["model"]
     tokenizer = MODELS[model_name]["tokenizer"]
+    errors_queue = Queue()
     task = params.get("task", "")

     if task != "":
@@ -586,25 +589,28 @@ def predict_stream(**params):
     )

     def generate_output():
-        with torch.no_grad():
-            with torch.cpu.amp.autocast(
-                enabled=True, dtype=torch.bfloat16, cache_enabled=True
-            ):
-                generation_kwargs = dict(
-                    streamer=streamer,
-                    generation_config=generation_config,
-                    return_dict_in_generate=True,
-                )
-                generation_kwargs["stopping_criteria"] = StoppingCriteriaList(
-                    [
-                        StopOnTokens(
-                            min_length=max(max_new_tokens - 20, 0),
-                            start_length=input_token_len,
-                            stop_token_id=stop_token_ids,
-                        )
-                    ]
-                )
-                return model.generate(**input_tokens, **generation_kwargs)
+        try:
+            with torch.no_grad():
+                with torch.cpu.amp.autocast(
+                    enabled=True, dtype=torch.bfloat16, cache_enabled=True
+                ):
+                    generation_kwargs = dict(
+                        streamer=streamer,
+                        generation_config=generation_config,
+                        return_dict_in_generate=True,
+                    )
+                    generation_kwargs["stopping_criteria"] = StoppingCriteriaList(
+                        [
+                            StopOnTokens(
+                                min_length=max(max_new_tokens - 20, 0),
+                                start_length=input_token_len,
+                                stop_token_id=stop_token_ids,
+                            )
+                        ]
+                    )
+                    return model.generate(**input_tokens, **generation_kwargs)
+        except Exception as e:
+            errors_queue.put(e)

     generation_thread = Thread(target=generate_output)
     generation_thread.start()
@@ -655,21 +661,23 @@ def generate_output():
     # generation_config.top_p = top_p
     generation_config.temperature = temperature
     generation_config.repetition_penalty = repetition_penalty
-
     def generate_output():
-        with torch.no_grad():
-            return model.generate(
-                **input_tokens,
-                **generate_kwargs,
-                streamer=streamer,
-                generation_config=generation_config,
-                return_dict_in_generate=True,
-                output_scores=True,
-                max_new_tokens=max_new_tokens,
-                lazy_mode=True,
-                hpu_graphs=use_hpu_graphs,
-                ignore_eos=False,
-            )
+        try:
+            with torch.no_grad():
+                return model.generate(
+                    **input_tokens,
+                    **generate_kwargs,
+                    streamer=streamer,
+                    generation_config=generation_config,
+                    return_dict_in_generate=True,
+                    output_scores=True,
+                    max_new_tokens=max_new_tokens,
+                    lazy_mode=True,
+                    hpu_graphs=use_hpu_graphs,
+                    ignore_eos=False,
+                )
+        except Exception as e:
+            errors_queue.put(e)

     generation_thread = Thread(target=generate_output)
     generation_thread.start()
@@ -679,6 +687,14 @@ def generate_output():
     )
     output_word_len = 0

+    generation_thread.join(0.1)
+    # if generation already finished, surface any exception the worker put on the queue
+    if not generation_thread.is_alive():
+        if not errors_queue.empty():
+            thread_exception = errors_queue.get()
+            raise thread_exception
+    # prevent crash if no words are coming out
+    first_token_output_time = datetime.now()
     for new_text in streamer:
         if len(new_text) == 0:
             continue
@@ -697,14 +713,15 @@ def generate_output():
         if output_word_len != 1
         else 0
     )
-    stats = {
-        "input_token_len": input_token_len,
-        "output_word_len": output_word_len,
-        "duration": duration,
-        "first_word_latency": first_word_latency,
-        "msecond_per_word": msecond_per_word,
-    }
-    yield "END_OF_STREAM_STATS={}".format(stats)
+    if return_stats:
+        stats = {
+            "input_token_len": input_token_len,
+            "output_word_len": output_word_len,
+            "duration": duration,
+            "first_word_latency": first_word_latency,
+            "msecond_per_word": msecond_per_word,
+        }
+        yield "END_OF_STREAM_STATS={}".format(stats)


 def predict(**params):
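For reference, the error-propagation pattern this change introduces around `generate_output()` can be sketched in isolation: the worker thread traps its own exception, hands it to the caller through a `queue.Queue`, and the caller re-raises it after a short `join()` instead of blocking on a streamer that will never produce text. This is a minimal illustrative sketch, not code from this repository; `worker` and the `RuntimeError` are stand-ins for `generate_output()` and a failing `model.generate(...)`.

```python
from queue import Queue
from threading import Thread

errors_queue = Queue()

def worker():
    # stand-in for generate_output(): run generation and trap any failure
    try:
        raise RuntimeError("generation failed")  # e.g. model.generate(...) raising
    except Exception as e:
        errors_queue.put(e)

t = Thread(target=worker)
t.start()
t.join(0.1)
# if the worker already died, surface its exception in the caller thread
if not t.is_alive() and not errors_queue.empty():
    raise errors_queue.get()
```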