Upload 4 files
Browse files- app.py +172 -0
- languages.py +147 -0
- requirements.txt +5 -0
- subtitle_manager.py +52 -0
app.py ADDED
| @@ -0,0 +1,172 @@ | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
| | |
import gradio as gr
import time
import logging
import torch
from sys import platform
from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor
from transformers.utils import is_flash_attn_2_available
from languages import get_language_names
from subtitle_manager import Subtitle


logging.basicConfig(level=logging.INFO)
# Most recently loaded model id; lets repeated transcriptions with the same
# model selection skip reloading the checkpoint.
last_model = None
def write_file(output_file, subtitle):
    """Write `subtitle` text to `output_file` as UTF-8, replacing any existing file."""
    with open(output_file, "w", encoding="utf-8") as handle:
        handle.write(subtitle)
| 18 | + |
def create_pipe(model, flash):
    """Build a Hugging Face automatic-speech-recognition pipeline.

    `model` is a hub checkpoint id; `flash` requests Flash Attention 2,
    which is honoured only when the kernel is actually importable,
    otherwise SDPA is used (needs torch>=2.1.1).
    """
    # Pick the best available accelerator: CUDA > Apple MPS > CPU.
    if torch.cuda.is_available():
        device = "cuda:0"
    else:
        device = "mps" if platform == "darwin" else "cpu"
    # Half precision only on CUDA; CPU/MPS stay in float32.
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    use_flash = flash and is_flash_attn_2_available()
    # Attention backends: "eager" (manual), "flash_attention_2", "sdpa"
    # (torch.nn.functional.scaled_dot_product_attention).
    seq2seq_model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True,
        attn_implementation="flash_attention_2" if use_flash else "sdpa",
    )
    seq2seq_model.to(device)

    processor = AutoProcessor.from_pretrained(model)

    return pipeline(
        "automatic-speech-recognition",
        model=seq2seq_model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=torch_dtype,
        device=device,
    )
| 56 | + |
def transcribe_webui_simple_progress(modelName, languageName, urlData, multipleFiles, microphoneData, task, flash,
                                     chunk_length_s, batch_size, progress=gr.Progress()):
    """Transcribe/translate the given audio inputs with a Whisper pipeline.

    Parameters mirror the Gradio inputs: model id, language name (or
    "Automatic Detection"), a URL, uploaded file paths, a microphone
    recording path, the task ("transcribe"/"translate"), the Flash
    Attention toggle, and the chunking/batching knobs.

    Returns a tuple of (list of written subtitle file paths, last file's
    VTT text, last file's TXT text) for the three output widgets.
    """
    # Module-level pipeline cache: the previous version never assigned
    # `pipe` in the "Model not changed" branch, so any second call with the
    # same model crashed with UnboundLocalError.
    global last_model, last_pipe

    progress(0, desc="Loading Audio..")
    logging.info(f"urlData:{urlData}")
    logging.info(f"multipleFiles:{multipleFiles}")
    logging.info(f"microphoneData:{microphoneData}")
    logging.info(f"task: {task}")
    logging.info(f"is_flash_attn_2_available: {is_flash_attn_2_available()}")
    logging.info(f"chunk_length_s: {chunk_length_s}")
    logging.info(f"batch_size: {batch_size}")

    if last_model is None:
        logging.info("first model")
        progress(0.1, desc="Loading Model..")
        last_pipe = create_pipe(modelName, flash)
    elif modelName != last_model:
        logging.info("new model")
        torch.cuda.empty_cache()  # release the old model's VRAM first
        progress(0.1, desc="Loading Model..")
        last_pipe = create_pipe(modelName, flash)
    else:
        logging.info("Model not changed")
    last_model = modelName
    pipe = last_pipe

    srt_sub = Subtitle("srt")
    vtt_sub = Subtitle("vtt")
    txt_sub = Subtitle("txt")

    # Collect every provided audio source into one work list.
    files = []
    if multipleFiles:
        files += multipleFiles
    if urlData:
        files.append(urlData)
    if microphoneData:
        files.append(microphoneData)
    logging.info(files)

    generate_kwargs = {}
    # English-only checkpoints (".en") accept neither a language nor a task.
    if not modelName.endswith(".en"):
        if languageName != "Automatic Detection":
            generate_kwargs["language"] = languageName
        generate_kwargs["task"] = task

    files_out = []
    vtt = txt = ""  # keep the return well-defined even with no input files
    for file in progress.tqdm(files, desc="Working..."):
        start_time = time.time()
        logging.info(file)
        outputs = pipe(
            file,
            chunk_length_s=chunk_length_s,  # UI default 30
            batch_size=batch_size,          # UI default 24
            generate_kwargs=generate_kwargs,
            return_timestamps=True,
        )
        logging.debug(outputs)
        # (was `logging.info(print(...))`, which always logged "None")
        logging.info(f"transcribe: {time.time() - start_time} sec.")

        file_out = file.split('/')[-1]
        srt = srt_sub.get_subtitle(outputs["chunks"])
        vtt = vtt_sub.get_subtitle(outputs["chunks"])
        txt = txt_sub.get_subtitle(outputs["chunks"])
        write_file(file_out + ".srt", srt)
        write_file(file_out + ".vtt", vtt)
        write_file(file_out + ".txt", txt)
        files_out += [file_out + ".srt", file_out + ".vtt", file_out + ".txt"]

    progress(1, desc="Completed!")

    return files_out, vtt, txt
| 128 | + |
| 129 | + |
# Top-level Gradio UI: one gr.Interface wrapped in Blocks so the browser
# tab gets a title; launched only when run as a script.
with gr.Blocks(title="Insanely Fast Whisper") as demo:
    description = "An opinionated CLI to transcribe Audio files w/ Whisper on-device! Powered by 🤗 Transformers, Optimum & flash-attn"
    article = "Read the [documentation here](https://github.com/Vaibhavs10/insanely-fast-whisper#cli-options)."
    # Checkpoints offered in the model dropdown: OpenAI originals,
    # distil-whisper variants, and one community fine-tune.
    whisper_models = [
        "openai/whisper-tiny", "openai/whisper-tiny.en",
        "openai/whisper-base", "openai/whisper-base.en",
        "openai/whisper-small", "openai/whisper-small.en", "distil-whisper/distil-small.en",
        "openai/whisper-medium", "openai/whisper-medium.en", "distil-whisper/distil-medium.en",
        "openai/whisper-large",
        "openai/whisper-large-v1",
        "openai/whisper-large-v2", "distil-whisper/distil-large-v2",
        "openai/whisper-large-v3", "xaviviro/whisper-large-v3-catalan-finetuned-v2",
    ]
    # Shared waveform styling for the microphone widget.
    waveform_options=gr.WaveformOptions(
        waveform_color="#01C6FF",
        waveform_progress_color="#0066B4",
        skip_length=2,
        show_controls=False,
    )

    # Inputs/outputs below must stay in the same order as the parameters
    # and return values of transcribe_webui_simple_progress.
    simple_transcribe = gr.Interface(fn=transcribe_webui_simple_progress,
        description=description,
        article=article,
        inputs=[
            gr.Dropdown(choices=whisper_models, value="distil-whisper/distil-large-v2", label="Model", info="Select whisper model", interactive = True,),
            gr.Dropdown(choices=["Automatic Detection"] + sorted(get_language_names()), value="Automatic Detection", label="Language", info="Select audio voice language", interactive = True,),
            gr.Text(label="URL", info="(YouTube, etc.)", interactive = True),
            gr.File(label="Upload Files", file_count="multiple"),
            gr.Audio(sources=["microphone"], type="filepath", label="Microphone Input", waveform_options = waveform_options),
            gr.Dropdown(choices=["transcribe", "translate"], label="Task", value="transcribe", interactive = True),
            gr.Checkbox(label='Flash',info='Use Flash Attention 2'),
            gr.Number(label='chunk_length_s',value=30, interactive = True),
            gr.Number(label='batch_size',value=24, interactive = True)
        ], outputs=[
            gr.File(label="Download"),
            gr.Text(label="Transcription"),
            gr.Text(label="Segments")
        ]
    )

if __name__ == "__main__":
    demo.launch()
| 172 | + |
languages.py ADDED
| @@ -0,0 +1,147 @@ | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
| | |
class Language():
    """A spoken language: a Whisper language code plus its English name."""

    def __init__(self, code, name):
        # `code` is the short code Whisper uses (e.g. "en"),
        # `name` the human-readable English name (e.g. "English").
        self.code = code
        self.name = name

    def __str__(self):
        return "Language(code={}, name={})".format(self.code, self.name)

    # Without this, containers of Language (e.g. the LANGUAGES table)
    # printed as opaque object addresses; reuse the __str__ text.
    __repr__ = __str__
| 8 | + |
# All languages supported by Whisper, as (code, English name) pairs.
# Table order is preserved by get_language_names().
LANGUAGES = [
    Language('en', 'English'),
    Language('zh', 'Chinese'),
    Language('de', 'German'),
    Language('es', 'Spanish'),
    Language('ru', 'Russian'),
    Language('ko', 'Korean'),
    Language('fr', 'French'),
    Language('ja', 'Japanese'),
    Language('pt', 'Portuguese'),
    Language('tr', 'Turkish'),
    Language('pl', 'Polish'),
    Language('ca', 'Catalan'),
    Language('nl', 'Dutch'),
    Language('ar', 'Arabic'),
    Language('sv', 'Swedish'),
    Language('it', 'Italian'),
    Language('id', 'Indonesian'),
    Language('hi', 'Hindi'),
    Language('fi', 'Finnish'),
    Language('vi', 'Vietnamese'),
    Language('he', 'Hebrew'),
    Language('uk', 'Ukrainian'),
    Language('el', 'Greek'),
    Language('ms', 'Malay'),
    Language('cs', 'Czech'),
    Language('ro', 'Romanian'),
    Language('da', 'Danish'),
    Language('hu', 'Hungarian'),
    Language('ta', 'Tamil'),
    Language('no', 'Norwegian'),
    Language('th', 'Thai'),
    Language('ur', 'Urdu'),
    Language('hr', 'Croatian'),
    Language('bg', 'Bulgarian'),
    Language('lt', 'Lithuanian'),
    Language('la', 'Latin'),
    Language('mi', 'Maori'),
    Language('ml', 'Malayalam'),
    Language('cy', 'Welsh'),
    Language('sk', 'Slovak'),
    Language('te', 'Telugu'),
    Language('fa', 'Persian'),
    Language('lv', 'Latvian'),
    Language('bn', 'Bengali'),
    Language('sr', 'Serbian'),
    Language('az', 'Azerbaijani'),
    Language('sl', 'Slovenian'),
    Language('kn', 'Kannada'),
    Language('et', 'Estonian'),
    Language('mk', 'Macedonian'),
    Language('br', 'Breton'),
    Language('eu', 'Basque'),
    Language('is', 'Icelandic'),
    Language('hy', 'Armenian'),
    Language('ne', 'Nepali'),
    Language('mn', 'Mongolian'),
    Language('bs', 'Bosnian'),
    Language('kk', 'Kazakh'),
    Language('sq', 'Albanian'),
    Language('sw', 'Swahili'),
    Language('gl', 'Galician'),
    Language('mr', 'Marathi'),
    Language('pa', 'Punjabi'),
    Language('si', 'Sinhala'),
    Language('km', 'Khmer'),
    Language('sn', 'Shona'),
    Language('yo', 'Yoruba'),
    Language('so', 'Somali'),
    Language('af', 'Afrikaans'),
    Language('oc', 'Occitan'),
    Language('ka', 'Georgian'),
    Language('be', 'Belarusian'),
    Language('tg', 'Tajik'),
    Language('sd', 'Sindhi'),
    Language('gu', 'Gujarati'),
    Language('am', 'Amharic'),
    Language('yi', 'Yiddish'),
    Language('lo', 'Lao'),
    Language('uz', 'Uzbek'),
    Language('fo', 'Faroese'),
    Language('ht', 'Haitian creole'),
    Language('ps', 'Pashto'),
    Language('tk', 'Turkmen'),
    Language('nn', 'Nynorsk'),
    Language('mt', 'Maltese'),
    Language('sa', 'Sanskrit'),
    Language('lb', 'Luxembourgish'),
    Language('my', 'Myanmar'),
    Language('bo', 'Tibetan'),
    Language('tl', 'Tagalog'),
    Language('mg', 'Malagasy'),
    Language('as', 'Assamese'),
    Language('tt', 'Tatar'),
    Language('haw', 'Hawaiian'),
    Language('ln', 'Lingala'),
    Language('ha', 'Hausa'),
    Language('ba', 'Bashkir'),
    Language('jw', 'Javanese'),
    Language('su', 'Sundanese')
]
| 110 | + |
# Historical/alternate English names Whisper accepts, mapped to the
# canonical language code they refer to.
_ALIAS_TO_CODE = {
    "burmese": "my",
    "valencian": "ca",
    "flemish": "nl",
    "haitian": "ht",
    "letzeburgesch": "lb",
    "pushto": "ps",
    "panjabi": "pa",
    "moldavian": "ro",
    "moldovan": "ro",
    "sinhalese": "si",
    "castilian": "es",
}

# Lookup table keyed by language code (plus the alias names above).
# Every value is a Language object; the previous version mapped the alias
# keys to bare code strings, so get_language_from_code() returned
# inconsistent types depending on the key.
_TO_LANGUAGE_CODE = {language.code: language for language in LANGUAGES}
_TO_LANGUAGE_CODE.update(
    {alias: _TO_LANGUAGE_CODE[code] for alias, code in _ALIAS_TO_CODE.items()}
)
| 125 | + |
# Case-insensitive English name -> Language lookup table.
_FROM_LANGUAGE_NAME = {language.name.lower(): language for language in LANGUAGES}
| 129 | + |
def get_language_from_code(language_code, default=None) -> Language:
    """Return the table entry registered for `language_code` (e.g. "en").

    Returns `default` when the code is unknown. NOTE(review): the lookup
    table also contains some alternate language-name strings whose values
    are bare code strings, so those keys return a str rather than a
    Language — confirm whether any caller relies on that.
    """
    return _TO_LANGUAGE_CODE.get(language_code, default)
| 133 | + |
def get_language_from_name(language, default=None) -> Language:
    """Case-insensitive lookup of a Language by its English name."""
    if not language:
        # None/empty can never match a name; behave like a failed lookup.
        return default
    return _FROM_LANGUAGE_NAME.get(language.lower(), default)
| 137 | + |
def get_language_names():
    """Return the English names of all supported languages, in table order."""
    return list(map(lambda entry: entry.name, LANGUAGES))
| 141 | + |
if __name__ == "__main__":
    # Manual smoke test: exercise both lookup directions and the name list.
    print(get_language_from_code('en'))
    print(get_language_from_name('English'))

    print(get_language_names())
requirements.txt ADDED
| @@ -0,0 +1,5 @@ | |
| | |
| | |
| | |
| | |
| |
| | |
gradio
transformers
--extra-index-url https://download.pytorch.org/whl/cu121
torch>=2.1.1
torchvision
torchaudio
subtitle_manager.py ADDED
| @@ -0,0 +1,52 @@ | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
| | |
| 1 | + import re |
| 2 | + |
class Subtitle():
    """Render Hugging Face ASR `chunks` into subtitle text.

    Supported formats: "srt" and "vtt" (timestamped cues) and "txt"
    (text only). Each chunk is a dict with a `timestamp` (start, end)
    pair in seconds and a `text` string.
    """

    def __init__(self, ext="srt"):
        # Per-format settings: the decimal separator inside timestamps, a
        # file header, and a formatter turning (index, chunk) into a cue.
        # An open-ended chunk (end timestamp of None) reuses its start time.
        sub_dict = {
            "srt": {
                "coma": ",",
                "header": "",
                "format": lambda i, segment: f"{i + 1}\n{self.timeformat(segment['timestamp'][0])} --> {self.timeformat(segment['timestamp'][1] if segment['timestamp'][1] is not None else segment['timestamp'][0])}\n{segment['text']}\n\n",
            },
            "vtt": {
                "coma": ".",
                # The WebVTT spec requires the file to begin with the exact
                # string "WEBVTT"; the previous "WebVTT" header was invalid.
                "header": "WEBVTT\n\n",
                "format": lambda i, segment: f"{self.timeformat(segment['timestamp'][0])} --> {self.timeformat(segment['timestamp'][1] if segment['timestamp'][1] is not None else segment['timestamp'][0])}\n{segment['text']}\n\n",
            },
            "txt": {
                "coma": "",
                "header": "",
                "format": lambda i, segment: f"{segment['text']}\n",
            },
        }

        self.ext = ext
        self.coma = sub_dict[ext]["coma"]
        self.header = sub_dict[ext]["header"]
        self.format = sub_dict[ext]["format"]

    def timeformat(self, time):
        """Format a time in seconds as HH:MM:SS<sep>mmm for this format."""
        hours = time // 3600
        minutes = (time - hours * 3600) // 60
        seconds = time - hours * 3600 - minutes * 60
        milliseconds = (time - int(time)) * 1000
        return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}{self.coma}{int(milliseconds):03d}"

    def get_subtitle(self, segments):
        """Return the full subtitle document for `segments` as one string.

        A single leading space (added by Whisper tokenization) is stripped
        from each chunk's text. Works on a shallow copy of each chunk, so
        the caller's dicts are no longer mutated in place. A malformed
        chunk is reported and skipped rather than aborting the whole file.
        """
        output = self.header
        for i, segment in enumerate(segments):
            text = segment['text']
            if text.startswith(' '):
                text = text[1:]
            try:
                output += self.format(i, {**segment, 'text': text})
            except Exception as e:
                # Best-effort: report and continue with the remaining chunks.
                print(e, segment)

        return output

    def write_subtitle(self, segments, output_file):
        """Render `segments` and write them to `output_file` + ".<ext>" as UTF-8."""
        output_file += "." + self.ext
        subtitle = self.get_subtitle(segments)

        with open(output_file, 'w', encoding='utf-8') as f:
            f.write(subtitle)