Skip to content

Commit 91e985a

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: LLM - Grounding - Added support for the disable_attribution grounding parameter
PiperOrigin-RevId: 580285757
1 parent 791eff5 commit 91e985a

File tree

2 files changed

+91
-32
lines changed

2 files changed

+91
-32
lines changed

tests/unit/aiplatform/test_language_models.py

Lines changed: 71 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,10 @@
183183
"citations": [
184184
{"url": "url1", "startIndex": 1, "endIndex": 2},
185185
{"url": "url2", "startIndex": 3, "endIndex": 4},
186-
]
186+
],
187+
"searchQueries": [
188+
"searchQuery",
189+
],
187190
},
188191
"content": """
189192
Ingredients:
@@ -211,7 +214,8 @@
211214
"license": None,
212215
"publication_date": None,
213216
},
214-
]
217+
],
218+
"search_queries": ["searchQuery"],
215219
}
216220

217221
_TEST_TEXT_GENERATION_PREDICTION = {
@@ -332,7 +336,8 @@
332336
"endIndex": 2,
333337
"url": "url1",
334338
}
335-
]
339+
],
340+
"searchQueries": ["searchQuery1"],
336341
},
337342
{
338343
"citations": [
@@ -341,7 +346,8 @@
341346
"endIndex": 4,
342347
"url": "url2",
343348
}
344-
]
349+
],
350+
"searchQueries": ["searchQuery2"],
345351
},
346352
],
347353
"candidates": [
@@ -396,10 +402,12 @@
396402
"publication_date": None,
397403
},
398404
],
405+
"search_queries": ["searchQuery1"],
399406
}
400407

401408
_EXPECTED_PARSED_GROUNDING_METADATA_CHAT_NONE = {
402409
"citations": [],
410+
"search_queries": [],
403411
}
404412

405413
_TEST_CHAT_PREDICTION_STREAMING = [
@@ -1567,12 +1575,13 @@ def test_text_generation_multiple_candidates_grounding(self):
15671575
"collections/default_collection/dataStores/test_datastore"
15681576
)
15691577
expected_grounding_sources = [
1570-
{"sources": [{"type": "WEB"}]},
1578+
{"sources": [{"type": "WEB", "disableAttribution": False}]},
15711579
{
15721580
"sources": [
15731581
{
1574-
"type": "ENTERPRISE",
1575-
"enterpriseDatastore": datastore_path,
1582+
"type": "VERTEX_AI_SEARCH",
1583+
"vertexAiSearchDatastore": datastore_path,
1584+
"disableAttribution": False,
15761585
}
15771586
]
15781587
},
@@ -1680,12 +1689,20 @@ async def test_text_generation_multiple_candidates_grounding_async(self):
16801689
"collections/default_collection/dataStores/test_datastore"
16811690
)
16821691
expected_grounding_sources = [
1683-
{"sources": [{"type": "WEB"}]},
16841692
{
16851693
"sources": [
16861694
{
1687-
"type": "ENTERPRISE",
1688-
"enterpriseDatastore": datastore_path,
1695+
"type": "WEB",
1696+
"disableAttribution": False,
1697+
}
1698+
]
1699+
},
1700+
{
1701+
"sources": [
1702+
{
1703+
"type": "VERTEX_AI_SEARCH",
1704+
"vertexAiSearchDatastore": datastore_path,
1705+
"disableAttribution": False,
16891706
}
16901707
]
16911708
},
@@ -2416,12 +2433,20 @@ def test_chat(self):
24162433
"collections/default_collection/dataStores/test_datastore"
24172434
)
24182435
expected_grounding_sources = [
2419-
{"sources": [{"type": "WEB"}]},
24202436
{
24212437
"sources": [
24222438
{
2423-
"type": "ENTERPRISE",
2424-
"enterpriseDatastore": datastore_path,
2439+
"type": "WEB",
2440+
"disableAttribution": False,
2441+
}
2442+
]
2443+
},
2444+
{
2445+
"sources": [
2446+
{
2447+
"type": "VERTEX_AI_SEARCH",
2448+
"vertexAiSearchDatastore": datastore_path,
2449+
"disableAttribution": False,
24252450
}
24262451
]
24272452
},
@@ -2461,12 +2486,20 @@ def test_chat(self):
24612486
"collections/default_collection/dataStores/test_datastore"
24622487
)
24632488
expected_grounding_sources = [
2464-
{"sources": [{"type": "WEB"}]},
24652489
{
24662490
"sources": [
24672491
{
2468-
"type": "ENTERPRISE",
2469-
"enterpriseDatastore": datastore_path,
2492+
"type": "WEB",
2493+
"disableAttribution": False,
2494+
}
2495+
]
2496+
},
2497+
{
2498+
"sources": [
2499+
{
2500+
"type": "VERTEX_AI_SEARCH",
2501+
"vertexAiSearchDatastore": datastore_path,
2502+
"disableAttribution": False,
24702503
}
24712504
]
24722505
},
@@ -2537,12 +2570,20 @@ async def test_chat_async(self):
25372570
"collections/default_collection/dataStores/test_datastore"
25382571
)
25392572
expected_grounding_sources = [
2540-
{"sources": [{"type": "WEB"}]},
25412573
{
25422574
"sources": [
25432575
{
2544-
"type": "ENTERPRISE",
2545-
"enterpriseDatastore": datastore_path,
2576+
"type": "WEB",
2577+
"disableAttribution": False,
2578+
}
2579+
]
2580+
},
2581+
{
2582+
"sources": [
2583+
{
2584+
"type": "VERTEX_AI_SEARCH",
2585+
"vertexAiSearchDatastore": datastore_path,
2586+
"disableAttribution": False,
25462587
}
25472588
]
25482589
},
@@ -2586,12 +2627,20 @@ async def test_chat_async(self):
25862627
"collections/default_collection/dataStores/test_datastore"
25872628
)
25882629
expected_grounding_sources = [
2589-
{"sources": [{"type": "WEB"}]},
25902630
{
25912631
"sources": [
25922632
{
2593-
"type": "ENTERPRISE",
2594-
"enterpriseDatastore": datastore_path,
2633+
"type": "WEB",
2634+
"disableAttribution": False,
2635+
}
2636+
]
2637+
},
2638+
{
2639+
"sources": [
2640+
{
2641+
"type": "VERTEX_AI_SEARCH",
2642+
"vertexAiSearchDatastore": datastore_path,
2643+
"disableAttribution": False,
25952644
}
25962645
]
25972646
},

vertexai/language_models/_language_models.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -709,12 +709,16 @@ def _to_grounding_source_dict(self) -> Dict[str, Any]:
709709

710710
@dataclasses.dataclass
711711
class WebSearch(_GroundingSourceBase):
712-
"""WebSearch represents a grounding source using public web search."""
712+
"""WebSearch represents a grounding source using public web search.
713+
Attributes:
714+
disable_attribution: If set to `True`, skip finding claim attributions (i.e not generate grounding citation). Default: False.
715+
"""
713716

717+
disable_attribution: bool = False
714718
_type: str = dataclasses.field(default="WEB", init=False, repr=False)
715719

716720
def _to_grounding_source_dict(self) -> Dict[str, Any]:
717-
return {"type": self._type}
721+
return {"type": self._type, "disableAttribution": self.disable_attribution}
718722

719723

720724
@dataclasses.dataclass
@@ -723,16 +727,18 @@ class VertexAISearch(_GroundingSourceBase):
723727
Attributes:
724728
data_store_id: Data store ID of the Vertex AI Search datastore.
725729
location: GCP multi region where you have set up your Vertex AI Search data store. Possible values can be `global`, `us`, `eu`, etc.
726-
Learn more about Vertex AI Search location here:
727-
https://cloud.google.com/generative-ai-app-builder/docs/locations
730+
Learn more about Vertex AI Search location here:
731+
https://cloud.google.com/generative-ai-app-builder/docs/locations
728732
project: The project where you have set up your Vertex AI Search.
729-
If not specified, will assume that your Vertex AI Search is within your current project.
733+
If not specified, will assume that your Vertex AI Search is within your current project.
734+
disable_attribution: If set to `True`, skip finding claim attributions (i.e not generate grounding citation). Default: False.
730735
"""
731736

732737
data_store_id: str
733738
location: str
734739
project: Optional[str] = None
735-
_type: str = dataclasses.field(default="ENTERPRISE", init=False, repr=False)
740+
disable_attribution: bool = False
741+
_type: str = dataclasses.field(default="VERTEX_AI_SEARCH", init=False, repr=False)
736742

737743
def _get_datastore_path(self) -> str:
738744
_project = self.project or aiplatform_initializer.global_config.project
@@ -742,7 +748,11 @@ def _get_datastore_path(self) -> str:
742748
)
743749

744750
def _to_grounding_source_dict(self) -> Dict[str, Any]:
745-
return {"type": self._type, "enterpriseDatastore": self._get_datastore_path()}
751+
return {
752+
"type": self._type,
753+
"vertexAiSearchDatastore": self._get_datastore_path(),
754+
"disableAttribution": self.disable_attribution,
755+
}
746756

747757

748758
@dataclasses.dataclass
@@ -790,6 +800,7 @@ class GroundingMetadata:
790800
"""
791801

792802
citations: Optional[List[GroundingCitation]] = None
803+
search_queries: Optional[List[str]] = None
793804

794805
def _parse_citation_from_dict(
795806
self, citation_dict_camel: Dict[str, Any]
@@ -819,6 +830,7 @@ def __init__(self, response: Optional[Dict[str, Any]] = {}):
819830
self._parse_citation_from_dict(citation)
820831
for citation in response.get("citations", [])
821832
]
833+
self.search_queries = response.get("searchQueries", [])
822834

823835

824836
@dataclasses.dataclass
@@ -1521,9 +1533,7 @@ def _prepare_text_embedding_request(
15211533
A `_MultiInstancePredictionRequest` object.
15221534
"""
15231535
if isinstance(texts, str) or not isinstance(texts, Sequence):
1524-
raise TypeError(
1525-
"The `texts` argument must be a list, not a single string."
1526-
)
1536+
raise TypeError("The `texts` argument must be a list, not a single string.")
15271537
instances = []
15281538
for text in texts:
15291539
if isinstance(text, TextEmbeddingInput):

0 commit comments

Comments
 (0)