Commit cfe0cc6

Ark-kun authored and copybara-github committed
feat: GenAI - Added tokenization support via GenerativeModel.compute_tokens
Usage:

```python
model = GenerativeModel("gemini-1.0-pro")
tokens = model.compute_tokens("Hello world")
```

PiperOrigin-RevId: 651536821
1 parent c5a3535 commit cfe0cc6

File tree

2 files changed (+96, -0 lines)

tests/system/vertexai/test_generative_models.py

Lines changed: 12 additions & 0 deletions
```diff
@@ -479,3 +479,15 @@ def test_additional_request_metadata(self):
             generation_config=generative_models.GenerationConfig(temperature=0),
         )
         assert response
+
+    def test_compute_tokens_from_text(self):
+        model = generative_models.GenerativeModel(GEMINI_MODEL_NAME)
+        response = model.compute_tokens(["Why is sky blue?", "Explain it like I'm 5."])
+        assert len(response.tokens_info) == 2
+        for token_info in response.tokens_info:
+            assert token_info.tokens
+            assert token_info.token_ids
+            assert len(token_info.token_ids) == len(token_info.tokens)
+            assert token_info.role
+            # Lightly validate that the tokens are not Base64 encoded
+            assert b"=" not in token_info.tokens
```

vertexai/generative_models/_generative_models.py

Lines changed: 84 additions & 0 deletions
```diff
@@ -38,11 +38,13 @@
 from google.cloud.aiplatform import utils as aiplatform_utils
 from google.cloud.aiplatform_v1beta1 import types as aiplatform_types
 from google.cloud.aiplatform_v1beta1.services import prediction_service
+from google.cloud.aiplatform_v1beta1.services import llm_utility_service
 from google.cloud.aiplatform_v1beta1.types import (
     content as gapic_content_types,
 )
 from google.cloud.aiplatform_v1beta1.types import (
     prediction_service as gapic_prediction_service_types,
+    llm_utility_service as gapic_llm_utility_service_types,
 )
 from google.cloud.aiplatform_v1beta1.types import tool as gapic_tool_types
 from google.protobuf import json_format
```
```diff
@@ -385,6 +387,34 @@ def _prediction_async_client(
         )
         return self._prediction_async_client_value

+    @property
+    def _llm_utility_client(self) -> llm_utility_service.LlmUtilityServiceClient:
+        # Switch to @functools.cached_property once it's available.
+        if not getattr(self, "_llm_utility_client_value", None):
+            self._llm_utility_client_value = (
+                aiplatform_initializer.global_config.create_client(
+                    client_class=llm_utility_service.LlmUtilityServiceClient,
+                    location_override=self._location,
+                    prediction_client=True,
+                )
+            )
+        return self._llm_utility_client_value
+
+    @property
+    def _llm_utility_async_client(
+        self,
+    ) -> llm_utility_service.LlmUtilityServiceAsyncClient:
+        # Switch to @functools.cached_property once it's available.
+        if not getattr(self, "_llm_utility_async_client_value", None):
+            self._llm_utility_async_client_value = (
+                aiplatform_initializer.global_config.create_client(
+                    client_class=llm_utility_service.LlmUtilityServiceAsyncClient,
+                    location_override=self._location,
+                    prediction_client=True,
+                )
+            )
+        return self._llm_utility_async_client_value
+
     def _prepare_request(
         self,
         contents: ContentsType,
```
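The `# Switch to @functools.cached_property` comments above refer to the standard-library decorator (available since Python 3.8). As a rough sketch, the decorator would replace the manual `getattr` caching like this; `make_client` is a hypothetical stand-in for `aiplatform_initializer.global_config.create_client`:

```python
import functools


def make_client() -> object:
    # Hypothetical factory standing in for create_client(...).
    return object()


class Example:
    @functools.cached_property  # Python 3.8+
    def client(self) -> object:
        # Runs once on first access; the result is then stored on the
        # instance, so later accesses skip this body entirely.
        return make_client()


e = Example()
assert e.client is e.client  # the second access reuses the cached object
```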
```diff
@@ -790,6 +820,60 @@ async def count_tokens_async(
             )
         )

+    def compute_tokens(
+        self, contents: ContentsType
+    ) -> gapic_llm_utility_service_types.ComputeTokensResponse:
+        """Computes tokens.
+
+        Args:
+            contents: Contents to send to the model.
+                Supports either a list of Content objects (passing a multi-turn conversation)
+                or a value that can be converted to a single Content object (passing a single message).
+                Supports
+                * str, Image, Part,
+                * List[Union[str, Image, Part]],
+                * List[Content]
+
+        Returns:
+            A ComputeTokensResponse object that has the following attributes:
+                tokens_info: A list with one entry per input instance; each
+                    entry contains the tokens, the token_ids, and a role.
+        """
+        return self._llm_utility_client.compute_tokens(
+            request=gapic_llm_utility_service_types.ComputeTokensRequest(
+                endpoint=self._prediction_resource_name,
+                model=self._prediction_resource_name,
+                contents=self._prepare_request(contents=contents).contents,
+            )
+        )
+
+    async def compute_tokens_async(
+        self, contents: ContentsType
+    ) -> gapic_llm_utility_service_types.ComputeTokensResponse:
+        """Computes tokens asynchronously.
+
+        Args:
+            contents: Contents to send to the model.
+                Supports either a list of Content objects (passing a multi-turn conversation)
+                or a value that can be converted to a single Content object (passing a single message).
+                Supports
+                * str, Image, Part,
+                * List[Union[str, Image, Part]],
+                * List[Content]
+
+        Returns:
+            An awaitable for a ComputeTokensResponse object that has the following attributes:
+                tokens_info: A list with one entry per input instance; each
+                    entry contains the tokens, the token_ids, and a role.
+        """
+        return await self._llm_utility_async_client.compute_tokens(
+            request=gapic_llm_utility_service_types.ComputeTokensRequest(
+                endpoint=self._prediction_resource_name,
+                model=self._prediction_resource_name,
+                contents=self._prepare_request(contents=contents).contents,
+            )
+        )
+
     def start_chat(
         self,
         *,
```
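Putting the new surface together, a minimal usage sketch, assuming an initialized Vertex AI project with application-default credentials (the model name mirrors the commit message):

```python
import asyncio

from vertexai.generative_models import GenerativeModel

model = GenerativeModel("gemini-1.0-pro")

# Synchronous tokenization, as exercised by the system test above.
response = model.compute_tokens(["Why is sky blue?", "Explain it like I'm 5."])
for token_info in response.tokens_info:
    # Each entry pairs raw token bytes with their integer ids.
    print(token_info.role, list(zip(token_info.tokens, token_info.token_ids)))


# Asynchronous variant via the matching compute_tokens_async method.
async def main() -> None:
    async_response = await model.compute_tokens_async("Hello world")
    print(len(async_response.tokens_info))


asyncio.run(main())
```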
