Skip to content

Commit 03a7861

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: Update v1 sdk to support llmparser in import file functions
PiperOrigin-RevId: 756617457
1 parent 869bea0 commit 03a7861

File tree

6 files changed

+131
-17
lines changed

6 files changed

+131
-17
lines changed

tests/unit/vertex_rag/test_rag_constants.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from vertexai.rag import (
2222
Filter,
2323
LayoutParserConfig,
24+
LlmParserConfig,
2425
LlmRanker,
2526
Pinecone,
2627
RagCorpus,
@@ -629,3 +630,26 @@
629630
llm_ranker=LlmRanker(model_name="test-llm-ranker"),
630631
),
631632
)
633+
TEST_LLM_PARSER_CONFIG = LlmParserConfig(
634+
model_name="gemini-1.5-pro-002",
635+
max_parsing_requests_per_min=500,
636+
custom_parsing_prompt="test-custom-parsing-prompt",
637+
)
638+
639+
640+
TEST_IMPORT_FILES_CONFIG_LLM_PARSER = ImportRagFilesConfig(
641+
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER
642+
)
643+
644+
TEST_IMPORT_FILES_CONFIG_LLM_PARSER.rag_file_parsing_config = RagFileParsingConfig(
645+
llm_parser=RagFileParsingConfig.LlmParser(
646+
model_name="gemini-1.5-pro-002",
647+
max_parsing_requests_per_min=500,
648+
custom_parsing_prompt="test-custom-parsing-prompt",
649+
)
650+
)
651+
652+
TEST_IMPORT_REQUEST_LLM_PARSER = ImportRagFilesRequest(
653+
parent=TEST_RAG_CORPUS_RESOURCE_NAME,
654+
import_rag_files_config=TEST_IMPORT_FILES_CONFIG_LLM_PARSER,
655+
)

tests/unit/vertex_rag/test_rag_data.py

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -931,7 +931,7 @@ def test_prepare_import_files_request_valid_layout_parser_with_processor_path(se
931931
corpus_name=test_rag_constants.TEST_RAG_CORPUS_RESOURCE_NAME,
932932
paths=[test_rag_constants.TEST_DRIVE_FOLDER],
933933
transformation_config=create_transformation_config(),
934-
parser=test_rag_constants.TEST_LAYOUT_PARSER_WITH_PROCESSOR_PATH_CONFIG,
934+
layout_parser=test_rag_constants.TEST_LAYOUT_PARSER_WITH_PROCESSOR_PATH_CONFIG,
935935
)
936936
import_files_request_eq(
937937
request,
@@ -945,7 +945,7 @@ def test_prepare_import_files_request_valid_layout_parser_with_processor_version
945945
corpus_name=test_rag_constants.TEST_RAG_CORPUS_RESOURCE_NAME,
946946
paths=[test_rag_constants.TEST_DRIVE_FOLDER],
947947
transformation_config=create_transformation_config(),
948-
parser=test_rag_constants.TEST_LAYOUT_PARSER_WITH_PROCESSOR_VERSION_PATH_CONFIG,
948+
layout_parser=test_rag_constants.TEST_LAYOUT_PARSER_WITH_PROCESSOR_VERSION_PATH_CONFIG,
949949
)
950950
import_files_request_eq(
951951
request,
@@ -961,10 +961,45 @@ def test_prepare_import_files_request_invalid_layout_parser_name(self):
961961
corpus_name=test_rag_constants.TEST_RAG_CORPUS_RESOURCE_NAME,
962962
paths=[test_rag_constants.TEST_DRIVE_FOLDER],
963963
transformation_config=create_transformation_config(),
964-
parser=layout_parser,
964+
layout_parser=layout_parser,
965965
)
966966
e.match("processor_name must be of the format")
967967

968+
def test_prepare_import_files_request_llm_parser(self):
969+
request = prepare_import_files_request(
970+
corpus_name=test_rag_constants.TEST_RAG_CORPUS_RESOURCE_NAME,
971+
paths=[test_rag_constants.TEST_DRIVE_FOLDER],
972+
transformation_config=create_transformation_config(),
973+
llm_parser=test_rag_constants.TEST_LLM_PARSER_CONFIG,
974+
)
975+
import_files_request_eq(
976+
request,
977+
test_rag_constants.TEST_IMPORT_REQUEST_LLM_PARSER,
978+
)
979+
980+
def test_layout_parser_and_llm_parser_both_set_error(self):
981+
with pytest.raises(ValueError) as e:
982+
rag.import_files(
983+
corpus_name=test_rag_constants.TEST_RAG_CORPUS_RESOURCE_NAME,
984+
paths=[test_rag_constants.TEST_DRIVE_FOLDER],
985+
transformation_config=create_transformation_config(),
986+
layout_parser=test_rag_constants.TEST_LAYOUT_PARSER_WITH_PROCESSOR_PATH_CONFIG,
987+
llm_parser=test_rag_constants.TEST_LLM_PARSER_CONFIG,
988+
)
989+
e.match("Only one of layout_parser or llm_parser may be passed in at a time")
990+
991+
@pytest.mark.asyncio
992+
async def test_layout_parser_and_llm_parser_both_set_error_async(self):
993+
with pytest.raises(ValueError) as e:
994+
await rag.import_files_async(
995+
corpus_name=test_rag_constants.TEST_RAG_CORPUS_RESOURCE_NAME,
996+
paths=[test_rag_constants.TEST_DRIVE_FOLDER],
997+
transformation_config=create_transformation_config(),
998+
layout_parser=test_rag_constants.TEST_LAYOUT_PARSER_WITH_PROCESSOR_PATH_CONFIG,
999+
llm_parser=test_rag_constants.TEST_LLM_PARSER_CONFIG,
1000+
)
1001+
e.match("Only one of layout_parser or llm_parser may be passed in at a time")
1002+
9681003
def test_set_embedding_model_config_set_both_error(self):
9691004
embedding_model_config = rag.RagEmbeddingModelConfig(
9701005
vertex_prediction_endpoint=rag.VertexPredictionEndpoint(

vertexai/rag/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
JiraQuery,
4444
JiraSource,
4545
LayoutParserConfig,
46+
LlmParserConfig,
4647
LlmRanker,
4748
Pinecone,
4849
RagCorpus,
@@ -71,6 +72,7 @@
7172
"JiraQuery",
7273
"JiraSource",
7374
"LayoutParserConfig",
75+
"LlmParserConfig",
7476
"LlmRanker",
7577
"Pinecone",
7678
"RagCorpus",

vertexai/rag/rag_data.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
from vertexai.rag.utils.resources import (
4646
JiraSource,
4747
LayoutParserConfig,
48+
LlmParserConfig,
4849
RagCorpus,
4950
RagFile,
5051
RagVectorDbConfig,
@@ -433,7 +434,8 @@ def import_files(
433434
max_embedding_requests_per_min: int = 1000,
434435
import_result_sink: Optional[str] = None,
435436
partial_failures_sink: Optional[str] = None,
436-
parser: Optional[LayoutParserConfig] = None,
437+
layout_parser: Optional[LayoutParserConfig] = None,
438+
llm_parser: Optional[LlmParserConfig] = None,
437439
) -> ImportRagFilesResponse:
438440
"""
439441
Import files to an existing RagCorpus, wait until completion.
@@ -573,6 +575,10 @@ def import_files(
573575
raise ValueError("Only one of source or paths must be passed in at a time")
574576
if source is None and paths is None:
575577
raise ValueError("One of source or paths must be passed in")
578+
if layout_parser is not None and llm_parser is not None:
579+
raise ValueError(
580+
"Only one of layout_parser or llm_parser may be passed in at a time"
581+
)
576582
corpus_name = _gapic_utils.get_corpus_name(corpus_name)
577583
request = _gapic_utils.prepare_import_files_request(
578584
corpus_name=corpus_name,
@@ -582,7 +588,8 @@ def import_files(
582588
max_embedding_requests_per_min=max_embedding_requests_per_min,
583589
import_result_sink=import_result_sink,
584590
partial_failures_sink=partial_failures_sink,
585-
parser=parser,
591+
layout_parser=layout_parser,
592+
llm_parser=llm_parser,
586593
)
587594
client = _gapic_utils.create_rag_data_service_client()
588595
try:
@@ -601,7 +608,8 @@ async def import_files_async(
601608
max_embedding_requests_per_min: int = 1000,
602609
import_result_sink: Optional[str] = None,
603610
partial_failures_sink: Optional[str] = None,
604-
parser: Optional[LayoutParserConfig] = None,
611+
layout_parser: Optional[LayoutParserConfig] = None,
612+
llm_parser: Optional[LlmParserConfig] = None,
605613
) -> operation_async.AsyncOperation:
606614
"""
607615
Import files to an existing RagCorpus asynchronously.
@@ -741,6 +749,10 @@ async def import_files_async(
741749
raise ValueError("Only one of source or paths must be passed in at a time")
742750
if source is None and paths is None:
743751
raise ValueError("One of source or paths must be passed in")
752+
if layout_parser is not None and llm_parser is not None:
753+
raise ValueError(
754+
"Only one of layout_parser or llm_parser may be passed in at a time"
755+
)
744756
corpus_name = _gapic_utils.get_corpus_name(corpus_name)
745757
request = _gapic_utils.prepare_import_files_request(
746758
corpus_name=corpus_name,
@@ -750,7 +762,8 @@ async def import_files_async(
750762
max_embedding_requests_per_min=max_embedding_requests_per_min,
751763
import_result_sink=import_result_sink,
752764
partial_failures_sink=partial_failures_sink,
753-
parser=parser,
765+
layout_parser=layout_parser,
766+
llm_parser=llm_parser,
754767
)
755768
async_client = _gapic_utils.create_rag_data_service_async_client()
756769
try:

vertexai/rag/utils/_gapic_utils.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
)
4242
from vertexai.rag.utils.resources import (
4343
LayoutParserConfig,
44+
LlmParserConfig,
4445
Pinecone,
4546
RagCorpus,
4647
RagEmbeddingModelConfig,
@@ -381,30 +382,45 @@ def prepare_import_files_request(
381382
max_embedding_requests_per_min: int = 1000,
382383
import_result_sink: Optional[str] = None,
383384
partial_failures_sink: Optional[str] = None,
384-
parser: Optional[LayoutParserConfig] = None,
385+
layout_parser: Optional[LayoutParserConfig] = None,
386+
llm_parser: Optional[LlmParserConfig] = None,
385387
) -> ImportRagFilesRequest:
386388
if len(corpus_name.split("/")) != 6:
387389
raise ValueError(
388390
"corpus_name must be of the format `projects/{project}/locations/{location}/ragCorpora/{rag_corpus}`"
389391
)
390392

391393
rag_file_parsing_config = RagFileParsingConfig()
392-
if parser is not None:
394+
if layout_parser is not None:
393395
if (
394-
re.fullmatch(_VALID_DOCUMENT_AI_PROCESSOR_NAME_REGEX, parser.processor_name)
396+
re.fullmatch(
397+
_VALID_DOCUMENT_AI_PROCESSOR_NAME_REGEX,
398+
layout_parser.processor_name,
399+
)
395400
is None
396401
):
397402
raise ValueError(
398-
"processor_name must be of the format "
399-
"`projects/{project_id}/locations/{location}/processors/{processor_id}`"
400-
"or "
401-
"`projects/{project_id}/locations/{location}/processors/{processor_id}/processorVersions/{processor_version_id}`, "
402-
f"got {parser.processor_name!r}"
403+
"processor_name must be of the format"
404+
" `projects/{project_id}/locations/{location}/processors/{processor_id}`or"
405+
" `projects/{project_id}/locations/{location}/processors/{processor_id}/processorVersions/{processor_version_id}`,"
406+
f" got {layout_parser.processor_name!r}"
403407
)
404408
rag_file_parsing_config.layout_parser = RagFileParsingConfig.LayoutParser(
405-
processor_name=parser.processor_name,
406-
max_parsing_requests_per_min=parser.max_parsing_requests_per_min,
409+
processor_name=layout_parser.processor_name,
410+
max_parsing_requests_per_min=layout_parser.max_parsing_requests_per_min,
411+
)
412+
if llm_parser is not None:
413+
rag_file_parsing_config.llm_parser = RagFileParsingConfig.LlmParser(
414+
model_name=llm_parser.model_name
407415
)
416+
if llm_parser.max_parsing_requests_per_min is not None:
417+
rag_file_parsing_config.llm_parser.max_parsing_requests_per_min = (
418+
llm_parser.max_parsing_requests_per_min
419+
)
420+
if llm_parser.custom_parsing_prompt is not None:
421+
rag_file_parsing_config.llm_parser.custom_parsing_prompt = (
422+
llm_parser.custom_parsing_prompt
423+
)
408424

409425
chunk_size = 1024
410426
chunk_overlap = 200

vertexai/rag/utils/resources.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,3 +445,27 @@ class LayoutParserConfig:
445445

446446
processor_name: str
447447
max_parsing_requests_per_min: Optional[int] = None
448+
449+
450+
@dataclasses.dataclass
451+
class LlmParserConfig:
452+
"""Configuration for the Document AI Layout Parser Processor.
453+
454+
Attributes:
455+
model_name (str):
456+
The full resource name of a Vertex AI model. Format:
457+
- `projects/{project_id}/locations/{location}/publishers/google/models/{model_id}`
458+
- `projects/{project_id}/locations/{location}/models/{model_id}`
459+
max_parsing_requests_per_min (int):
460+
The maximum number of requests the job is allowed to make to the
461+
Vertex AI model per minute. Consult
462+
https://cloud.google.com/vertex-ai/generative-ai/docs/quotas and
463+
the Quota page for your project to set an appropriate value here.
464+
If unspecified, a default value of 120 QPM will be used.
465+
custom_parsing_prompt (str):
466+
A custom prompt to use for parsing.
467+
"""
468+
469+
model_name: str
470+
max_parsing_requests_per_min: Optional[int] = None
471+
custom_parsing_prompt: Optional[str] = None

0 commit comments

Comments
 (0)