@@ -38,11 +38,13 @@
 from google.cloud.aiplatform import utils as aiplatform_utils
 from google.cloud.aiplatform_v1beta1 import types as aiplatform_types
 from google.cloud.aiplatform_v1beta1.services import prediction_service
+from google.cloud.aiplatform_v1beta1.services import llm_utility_service
 from google.cloud.aiplatform_v1beta1.types import (
     content as gapic_content_types,
 )
 from google.cloud.aiplatform_v1beta1.types import (
     prediction_service as gapic_prediction_service_types,
+    llm_utility_service as gapic_llm_utility_service_types,
 )
 from google.cloud.aiplatform_v1beta1.types import tool as gapic_tool_types
 from google.protobuf import json_format
@@ -385,6 +387,34 @@ def _prediction_async_client(
             )
         return self._prediction_async_client_value
 
+    @property
+    def _llm_utility_client(self) -> llm_utility_service.LlmUtilityServiceClient:
+        # Switch to @functools.cached_property once it's available.
+        if not getattr(self, "_llm_utility_client_value", None):
+            self._llm_utility_client_value = (
+                aiplatform_initializer.global_config.create_client(
+                    client_class=llm_utility_service.LlmUtilityServiceClient,
+                    location_override=self._location,
+                    prediction_client=True,
+                )
+            )
+        return self._llm_utility_client_value
+
+    @property
+    def _llm_utility_async_client(
+        self,
+    ) -> llm_utility_service.LlmUtilityServiceAsyncClient:
+        # Switch to @functools.cached_property once it's available.
+        if not getattr(self, "_llm_utility_async_client_value", None):
+            self._llm_utility_async_client_value = (
+                aiplatform_initializer.global_config.create_client(
+                    client_class=llm_utility_service.LlmUtilityServiceAsyncClient,
+                    location_override=self._location,
+                    prediction_client=True,
+                )
+            )
+        return self._llm_utility_async_client_value
+
     def _prepare_request(
         self,
         contents: ContentsType,
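
Note: the comment in the added properties names the pattern they emulate. A minimal sketch of the functools.cached_property equivalent (a hypothetical simplification for the sync client only; the class and placeholder location are illustrative, while create_client and its parameters come from the change above):

import functools

from google.cloud.aiplatform import initializer as aiplatform_initializer
from google.cloud.aiplatform_v1beta1.services import llm_utility_service


class _Example:
    _location = "us-central1"  # placeholder location

    @functools.cached_property
    def _llm_utility_client(self) -> llm_utility_service.LlmUtilityServiceClient:
        # Evaluated once on first access, then cached on the instance,
        # replacing the getattr-based sentinel check used above.
        return aiplatform_initializer.global_config.create_client(
            client_class=llm_utility_service.LlmUtilityServiceClient,
            location_override=self._location,
            prediction_client=True,
        )
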
@@ -790,6 +820,60 @@ async def count_tokens_async(
             )
         )
 
+    def compute_tokens(
+        self, contents: ContentsType
+    ) -> gapic_llm_utility_service_types.ComputeTokensResponse:
+        """Computes tokens.
+
+        Args:
+            contents: Contents to send to the model.
+                Supports either a list of Content objects (passing a multi-turn conversation)
+                or a value that can be converted to a single Content object (passing a single message).
+                Supports
+                * str, Image, Part,
+                * List[Union[str, Image, Part]],
+                * List[Content]
+
+        Returns:
+            A ComputeTokensResponse object that has the following attributes:
+                tokens_info: A list of TokensInfo entries, one per input instance;
+                    each entry contains that instance's tokens and token IDs.
+        """
+        return self._llm_utility_client.compute_tokens(
+            request=gapic_llm_utility_service_types.ComputeTokensRequest(
+                endpoint=self._prediction_resource_name,
+                model=self._prediction_resource_name,
+                contents=self._prepare_request(contents=contents).contents,
+            )
+        )
+
+    async def compute_tokens_async(
+        self, contents: ContentsType
+    ) -> gapic_llm_utility_service_types.ComputeTokensResponse:
+        """Computes tokens asynchronously.
+
+        Args:
+            contents: Contents to send to the model.
+                Supports either a list of Content objects (passing a multi-turn conversation)
+                or a value that can be converted to a single Content object (passing a single message).
+                Supports
+                * str, Image, Part,
+                * List[Union[str, Image, Part]],
+                * List[Content]
+
+        Returns:
+            An awaitable for a ComputeTokensResponse object that has the following attributes:
+                tokens_info: A list of TokensInfo entries, one per input instance;
+                    each entry contains that instance's tokens and token IDs.
+        """
+        return await self._llm_utility_async_client.compute_tokens(
+            request=gapic_llm_utility_service_types.ComputeTokensRequest(
+                endpoint=self._prediction_resource_name,
+                model=self._prediction_resource_name,
+                contents=self._prepare_request(contents=contents).contents,
+            )
+        )
+
     def start_chat(
         self,
         *,
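
A minimal usage sketch of the new surface (project, location, model name, and prompt are placeholders; assumes the preview GenerativeModel exposing the methods added in this change):

import asyncio

import vertexai
from vertexai.preview.generative_models import GenerativeModel

vertexai.init(project="my-project", location="us-central1")  # placeholders
model = GenerativeModel("gemini-pro")

# Synchronous call: each tokens_info entry maps one input instance to
# its tokens and token IDs.
response = model.compute_tokens("Why is the sky blue?")
for info in response.tokens_info:
    print(len(info.token_ids), info.tokens[:3])

# Asynchronous variant of the same call.
async def main():
    return await model.compute_tokens_async("Why is the sky blue?")

asyncio.run(main())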