|
12 | 12 | import shutil |
13 | 13 | import base64 |
14 | 14 | import os |
15 | | -import re |
| 15 | +import re |
| 16 | +import tempfile |
16 | 17 | from typing import Any, Dict |
17 | 18 | from threading import Thread |
18 | 19 | from haystack.telemetry import send_event_if_public_demo |
|
48 | 49 |
|
49 | 50 | def ask_gm_documents_dense_embedding(folder_path, process_content=False): |
50 | 51 | documents = [] |
51 | | - for dirpath, dirnames, filenames in os.walk(folder_path): |
52 | | - for filename in filenames: |
53 | | - if filename.endswith(".json"): |
54 | | - documents = doc_index.d_load_jsonl_file(os.path.join(dirpath, filename), process_content, documents) |
55 | | - elif filename.endswith(".xlsx"): |
56 | | - documents = doc_index.d_load_xlsx(os.path.join(dirpath, filename), process_content) |
57 | | - else: |
58 | | - print("{} is ignored. Will support this file format soon.".format(filename)) |
59 | | - continue |
60 | | - doc_index.persist_embedding(documents, "/tmp/ask_gm_dense_retrieval_chinese", |
61 | | - model_path="shibing624/text2vec-large-chinese") |
62 | | - doc_index.persist_embedding(documents, "/tmp/ask_gm_dense_retrieval_english", |
63 | | - model_path="hkunlp/instructor-large") |
| 52 | + with tempfile.TemporaryDirectory() as temp_dir: |
| 53 | + for dirpath, dirnames, filenames in os.walk(folder_path): |
| 54 | + for filename in filenames: |
| 55 | + if filename.endswith(".json"): |
| 56 | + documents = doc_index.d_load_jsonl_file(os.path.join(dirpath, filename), process_content, documents) |
| 57 | + elif filename.endswith(".xlsx"): |
| 58 | + documents = doc_index.d_load_xlsx(os.path.join(dirpath, filename), process_content) |
| 59 | + else: |
| 60 | + print("{} is ignored. Will support this file format soon.".format(filename)) |
| 61 | + continue |
| 62 | + doc_index.persist_embedding(documents, temp_dir, model_path="shibing624/text2vec-large-chinese") |
| 63 | + doc_index.persist_embedding(documents, temp_dir, model_path="hkunlp/instructor-large") |
64 | 64 |
|
65 | 65 | def ask_gm_documents_sparse_embedding(folder_path, process_content=False): |
66 | 66 | document_store = ElasticsearchDocumentStore(host="localhost", index="elastic_askgm_sparse", |
@@ -141,27 +141,31 @@ def ask_gm_documents_sparse_embedding(folder_path, process_content=False): |
141 | 141 | stop_token_ids.append(langchain_tok("。", return_tensors="pt").input_ids) |
142 | 142 | stop_token_ids.append(langchain_tok("!", return_tensors="pt").input_ids) |
143 | 143 | langchain_tok.pad_token = langchain_tok.eos_token |
144 | | -langchain_tok.add_special_tokens({'pad_token': '[PAD]'}) |
145 | | -if not os.path.exists("/tmp/young_pat_dense_retrieval"): |
146 | | - documents = doc_index.d_load_young_pat_xlsx("./doc/young_pat/pat.xlsx", True) |
147 | | - doc_index.persist_embedding(documents, "/tmp/young_pat_dense_retrieval", model_path="hkunlp/instructor-large") |
148 | | - |
149 | | -english_embeddings = HuggingFaceInstructEmbeddings(model_name="hkunlp/instructor-large") |
150 | | -chinese_embeddings = HuggingFaceInstructEmbeddings(model_name="shibing624/text2vec-base-chinese") |
151 | | -young_pat_vectordb = Chroma(persist_directory="/tmp/young_pat_dense_retrieval", |
152 | | - embedding_function=english_embeddings) |
153 | | -young_pat_dense_retriever = young_pat_vectordb.as_retriever(search_type = "mmr", |
154 | | - search_kwargs = {"k": 2, "fetch_k": 5}) |
155 | | - |
156 | | -ask_gm_eng_vectordb = Chroma(persist_directory='/tmp/ask_gm_dense_retrieval_english', |
| 144 | +langchain_tok.add_special_tokens({'pad_token': '[PAD]'}) |
| 145 | +with tempfile.TemporaryDirectory() as temp_dir: |
| 146 | +    # A freshly created TemporaryDirectory is always empty, so the former |
| 147 | +    # os.path.exists(temp_dir) guard was always False; build unconditionally. |
| 148 | +    documents = doc_index.d_load_young_pat_xlsx("./doc/young_pat/pat.xlsx", True) |
| 149 | +    doc_index.persist_embedding(documents, temp_dir, model_path="hkunlp/instructor-large") |
| 149 | + |
| 150 | +with tempfile.TemporaryDirectory() as temp_dir: |
| 151 | +    # NOTE(review): all three Chroma stores below now share this single, freshly |
| 152 | +    # created (empty) temp dir, unlike the three distinct persist paths used |
| 153 | +    # before this change — confirm that loading from an empty dir is intended. |
| 151 | + english_embeddings = HuggingFaceInstructEmbeddings(model_name="hkunlp/instructor-large") |
| 152 | + chinese_embeddings = HuggingFaceInstructEmbeddings(model_name="shibing624/text2vec-base-chinese") |
| 153 | + |
| 154 | + young_pat_vectordb = Chroma(persist_directory=temp_dir, |
| 155 | + embedding_function=english_embeddings) |
| 156 | + young_pat_dense_retriever = young_pat_vectordb.as_retriever(search_type="mmr", |
| 157 | + search_kwargs={"k": 2, "fetch_k": 5}) |
| 158 | + |
| 159 | + ask_gm_eng_vectordb = Chroma(persist_directory=temp_dir, |
157 | 160 | embedding_function=english_embeddings) |
158 | | -ask_gm_eng_retriever = ask_gm_eng_vectordb.as_retriever(search_type = "mmr", |
159 | | - search_kwargs = {"k": 2, "fetch_k": 5}) |
| 161 | + ask_gm_eng_retriever = ask_gm_eng_vectordb.as_retriever(search_type="mmr", |
| 162 | + search_kwargs={"k": 2, "fetch_k": 5}) |
160 | 163 |
|
161 | | -ask_gm_chn_vectordb = Chroma(persist_directory='/tmp/ask_gm_dense_retrieval_chinese', |
| 164 | + ask_gm_chn_vectordb = Chroma(persist_directory=temp_dir, |
162 | 165 | embedding_function=chinese_embeddings) |
163 | | -ask_gm_chn_retriever = ask_gm_chn_vectordb.as_retriever(search_type = "mmr", |
164 | | - search_kwargs = {"k": 2, "fetch_k": 5}) |
| 166 | + ask_gm_chn_retriever = ask_gm_chn_vectordb.as_retriever(search_type="mmr", |
| 167 | + search_kwargs={"k": 2, "fetch_k": 5}) |
| 168 | + |
165 | 169 |
|
166 | 170 | class StopOnTokens(StoppingCriteria): |
167 | 171 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: |
@@ -453,13 +457,15 @@ def query(request: QueryRequest): |
453 | 457 | young_pat_pipeline.add_node(component=shaper, name="Shaper", inputs=["Reranker"]) |
454 | 458 | young_pat_pipeline.add_node(component=prompt, name="Prompter", inputs=["Shaper"]) |
455 | 459 | result = _process_request(young_pat_pipeline, request) |
456 | | - elif domain == "Customized": |
457 | | - if request.blob: |
458 | | - file_content = base64.b64decode(request.blob) |
459 | | - random_suffix = str(uuid.uuid4().hex) |
460 | | - file_path = f"/tmp/customized_doc_{random_suffix}" + request.filename |
461 | | - with open(file_path, "wb") as f: |
462 | | - f.write(file_content) |
| 460 | +    elif domain == "Customized": |
| 461 | +        if request.blob: |
| 462 | +            file_content = base64.b64decode(request.blob) |
| 463 | +            random_suffix = str(uuid.uuid4().hex) |
| 464 | +            sanitized_filename = os.path.basename(request.filename) |
| 465 | +            file_path = f"/tmp/customized_doc_{random_suffix}_{sanitized_filename}" |
| 466 | +            with open(file_path, "wb") as f: |
| 467 | +                f.write(file_content) |
| 468 | + |
463 | 469 |
|
464 | 470 | if request.filename.endswith("md"): |
465 | 471 | converter = MarkdownConverter() |
|
0 commit comments