This repository was archived by the owner on Oct 25, 2024. It is now read-only.

Commit e1da7e8

Fix crash (#141)
* feat: fix crash for first token
* feat: return thread errors; return stats only if requested
1 parent 5b24317 commit e1da7e8


workflows/chatbot/inference/generate.py

Lines changed: 58 additions & 41 deletions
@@ -2,6 +2,7 @@
 import copy, time
 from datetime import datetime
 import torch
+from queue import Queue
 import re, os, logging
 from threading import Thread
 import contextlib
@@ -539,9 +540,11 @@ def predict_stream(**params):
     force_words_ids = params["force_words_ids"] if "force_words_ids" in params else None
     use_hpu_graphs = params["use_hpu_graphs"] if "use_hpu_graphs" in params else False
     use_cache = params["use_cache"] if "use_cache" in params else True
+    return_stats = params["return_stats"] if "return_stats" in params else False
     prompt = params["prompt"]
     model = MODELS[model_name]["model"]
     tokenizer = MODELS[model_name]["tokenizer"]
+    errors_queue = Queue()
     task = params.get("task", "")

     if task != "":
@@ -586,25 +589,28 @@ def predict_stream(**params):
         )

         def generate_output():
-            with torch.no_grad():
-                with torch.cpu.amp.autocast(
-                    enabled=True, dtype=torch.bfloat16, cache_enabled=True
-                ):
-                    generation_kwargs = dict(
-                        streamer=streamer,
-                        generation_config=generation_config,
-                        return_dict_in_generate=True,
-                    )
-                    generation_kwargs["stopping_criteria"] = StoppingCriteriaList(
-                        [
-                            StopOnTokens(
-                                min_length=max(max_new_tokens - 20, 0),
-                                start_length=input_token_len,
-                                stop_token_id=stop_token_ids,
-                            )
-                        ]
-                    )
-                    return model.generate(**input_tokens, **generation_kwargs)
+            try:
+                with torch.no_grad():
+                    with torch.cpu.amp.autocast(
+                        enabled=True, dtype=torch.bfloat16, cache_enabled=True
+                    ):
+                        generation_kwargs = dict(
+                            streamer=streamer,
+                            generation_config=generation_config,
+                            return_dict_in_generate=True,
+                        )
+                        generation_kwargs["stopping_criteria"] = StoppingCriteriaList(
+                            [
+                                StopOnTokens(
+                                    min_length=max(max_new_tokens - 20, 0),
+                                    start_length=input_token_len,
+                                    stop_token_id=stop_token_ids,
+                                )
+                            ]
+                        )
+                        return model.generate(**input_tokens, **generation_kwargs)
+            except Exception as e:
+                errors_queue.put(e)

         generation_thread = Thread(target=generate_output)
         generation_thread.start()
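
The change above follows a common pattern for surfacing errors from a Python worker thread: exceptions raised inside a Thread target never propagate to the caller, so the target catches them and parks them on a Queue that the main thread can inspect. A minimal, self-contained sketch of that worker-side idea (the failing_worker function and its error message are made up for illustration, not code from this file):

from queue import Queue
from threading import Thread

errors_queue = Queue()

def failing_worker():
    # Stand-in for generate_output(): whatever goes wrong is captured
    # in the queue instead of dying silently inside the thread.
    try:
        raise RuntimeError("generation failed before the first token")
    except Exception as e:
        errors_queue.put(e)

worker = Thread(target=failing_worker)
worker.start()
worker.join()

if not errors_queue.empty():
    print("worker reported:", errors_queue.get())
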
@@ -655,21 +661,23 @@ def generate_output():
         # generation_config.top_p = top_p
         generation_config.temperature = temperature
         generation_config.repetition_penalty = repetition_penalty
-
         def generate_output():
-            with torch.no_grad():
-                return model.generate(
-                    **input_tokens,
-                    **generate_kwargs,
-                    streamer=streamer,
-                    generation_config=generation_config,
-                    return_dict_in_generate=True,
-                    output_scores=True,
-                    max_new_tokens=max_new_tokens,
-                    lazy_mode=True,
-                    hpu_graphs=use_hpu_graphs,
-                    ignore_eos=False,
-                )
+            try:
+                with torch.no_grad():
+                    return model.generate(
+                        **input_tokens,
+                        **generate_kwargs,
+                        streamer=streamer,
+                        generation_config=generation_config,
+                        return_dict_in_generate=True,
+                        output_scores=True,
+                        max_new_tokens=max_new_tokens,
+                        lazy_mode=True,
+                        hpu_graphs=use_hpu_graphs,
+                        ignore_eos=False,
+                    )
+            except Exception as e:
+                errors_queue.put(e)

         generation_thread = Thread(target=generate_output)
         generation_thread.start()
@@ -679,6 +687,14 @@ def generate_output():
     )
     output_word_len = 0

+    generation_thread.join(0.1)
+    if generation_thread.is_alive():
+        pass
+    else:
+        thread_exception = errors_queue.get()
+        raise thread_exception
+    # prevent crash if no words are coming out
+    first_token_output_time = datetime.now()
     for new_text in streamer:
         if len(new_text) == 0:
             continue
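
The main-thread side of the fix gives the generation thread a 0.1-second head start. If the thread is still alive after that, streaming proceeds as before; if it has already exited that quickly, the code treats it as a failure and re-raises the queued exception instead of blocking forever on an empty streamer. Seeding first_token_output_time before the loop likewise keeps the later latency arithmetic from crashing when no tokens arrive at all. A rough, self-contained sketch of that gate, using a hypothetical worker that fails immediately (not code from this file):

from queue import Queue
from threading import Thread

errors_queue = Queue()

def broken_worker():
    # Hypothetical worker that dies before producing any output.
    try:
        raise ValueError("tokenization failed")
    except Exception as e:
        errors_queue.put(e)

generation_thread = Thread(target=broken_worker)
generation_thread.start()

generation_thread.join(0.1)   # brief wait rather than a full join
if generation_thread.is_alive():
    pass                      # still working: go on and iterate the streamer
else:
    # Exited already: assume it failed and surface the queued error here.
    raise errors_queue.get()

Like the diff, the sketch assumes an early exit always means an error was queued; a worker that somehow finished cleanly within the window would leave errors_queue.get() blocking.
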
@@ -697,14 +713,15 @@ def generate_output():
         if output_word_len != 1
         else 0
     )
-    stats = {
-        "input_token_len": input_token_len,
-        "output_word_len": output_word_len,
-        "duration": duration,
-        "first_word_latency": first_word_latency,
-        "msecond_per_word": msecond_per_word,
-    }
-    yield "END_OF_STREAM_STATS={}".format(stats)
+    if return_stats:
+        stats = {
+            "input_token_len": input_token_len,
+            "output_word_len": output_word_len,
+            "duration": duration,
+            "first_word_latency": first_word_latency,
+            "msecond_per_word": msecond_per_word,
+        }
+        yield "END_OF_STREAM_STATS={}".format(stats)


 def predict(**params):
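
With return_stats defaulting to False, the trailing END_OF_STREAM_STATS record is now opt-in. A hypothetical caller might consume the generator as below; the keyword names match parameters visible in this diff, but the import path, the model_name key, and the sentinel parsing are assumptions for illustration, not an API defined by this file beyond the END_OF_STREAM_STATS= prefix. Any other required parameters (device, generation settings) would have to be supplied as the module expects.

from generate import predict_stream  # assumption: run alongside workflows/chatbot/inference/generate.py

STATS_PREFIX = "END_OF_STREAM_STATS="

def run_chat(prompt):
    # Hypothetical client: collect the streamed text and, because
    # return_stats=True was passed, peel off the trailing stats record.
    answer_parts, stats_repr = [], None
    for chunk in predict_stream(
        model_name="chatbot-model",  # assumption: a key already loaded into MODELS
        prompt=prompt,
        return_stats=True,           # opt-in flag introduced by this commit
    ):
        if chunk.startswith(STATS_PREFIX):
            stats_repr = chunk[len(STATS_PREFIX):]  # str() of the stats dict
        else:
            answer_parts.append(chunk)
    return "".join(answer_parts), stats_repr
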
