Update GenerativeAI model.
qiuosier committed Oct 26, 2023
commit b36aac98d96121644d9e97a46dfafee9ac6bfdda
143 changes: 65 additions & 78 deletions ads/llm/langchain/plugins/llm.py
@@ -4,29 +4,28 @@
# Copyright (c) 2023 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

import oci
import requests
import logging

from enum import Enum
from typing import Any, Mapping, Dict, List, Optional
from ads.common.auth import default_signer

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.pydantic_v1 import root_validator, Field, Extra
from pydantic import BaseModel, root_validator

from ads.common.auth import default_signer
from ads.config import COMPARTMENT_OCID

try:
from oci.generative_ai import GenerativeAiClient, models
except ImportError as e:
print("Pip install `oci` with correct version")
pass

logger = logging.getLogger(__name__)


class StrEnum(str, Enum):
pass


# Move to constant.py
class TASK:
class Task(StrEnum):
TEXT_GENERATION = "text_generation"
SUMMARY_TEXT = "summary_text"
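
Because Task subclasses str (via the StrEnum helper above), its members compare equal to their raw string values, which keeps the enum backward compatible with the string constants it replaces. A minimal sketch:

# Members of the str-based enum compare equal to plain strings.
assert Task.TEXT_GENERATION == "text_generation"
# Lookup by value also works, so serialized task names round-trip cleanly.
assert Task("summary_text") is Task.SUMMARY_TEXT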

@@ -56,9 +55,6 @@ class OCIGenerativeAIModelOptions:
COHERE_COMMAND_LIGHT = "cohere.command-light"


DEFAULT_SERVICE_ENDPOINT = (
"https://generativeai.aiservice.us-chicago-1.oci.oraclecloud.com"
)
DEFAULT_TIME_OUT = 300
DEFAULT_CONTENT_TYPE_JSON = "application/json"

@@ -87,7 +83,38 @@ class BaseLLM(LLM):
"""Stop words to use when generating. Model output is cut off at the first occurrence of any of these substrings."""


class GenerativeAI(BaseLLM):
class GenerativeAiClientModel(BaseModel):
client: Any #: :meta private:
"""OCI GenerativeAiClient."""

compartment_id: str
"""Compartment ID of the caller."""

@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that python package exists in environment."""
try:
# Import GenerativeAiClient here so that importing ads.llm does not fail
# when the installed OCI SDK does not yet support the Generative AI service.
from oci.generative_ai import GenerativeAiClient
except ImportError as ex:
raise ImportError(
"Could not import GenerativeAIClient from oci. "
"The OCI SDK installed does not support generative AI service."
) from ex
# Initialize the client only if the user has not passed one in.
# Users may initialize the OCI client themselves and pass it to this model.
if not values.get("client"):
client_kwargs = values["client_kwargs"] or {}
values["client"] = GenerativeAiClient(**default_signer(), **client_kwargs)
# Set default compartment ID
if "compartment_id" not in values and COMPARTMENT_OCID:
values["compartment_id"] = COMPARTMENT_OCID

return values
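
The validator above creates a client only when one is not supplied, so callers can build and reuse their own GenerativeAiClient. A minimal usage sketch for the GenerativeAI class defined below; the service endpoint URL and compartment OCID are placeholders, not real values:

# Hypothetical usage; the endpoint and OCID below are placeholders.
from oci.generative_ai import GenerativeAiClient
from ads.common.auth import default_signer

client = GenerativeAiClient(
    **default_signer(),  # same config/signer expansion used in validate_environment
    service_endpoint="https://generativeai.<region>.oci.oraclecloud.com",
)
llm = GenerativeAI(
    client=client,  # validate_environment skips client creation
    compartment_id="ocid1.compartment.oc1..<unique>",
)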


class GenerativeAI(GenerativeAiClientModel, BaseLLM):
"""GenerativeAI Service.

To use, you should have the ``oci`` python package installed.
@@ -103,8 +130,8 @@ class GenerativeAI(BaseLLM):

"""

client: Any #: :meta private:
"""OCI GenerativeAiClient."""
task: Task = Task.TEXT_GENERATION
"""Indicates the task."""

model: Optional[str] = OCIGenerativeAIModelOptions.COHERE_COMMAND
"""Model name to use."""
@@ -130,45 +157,12 @@ class GenerativeAI(BaseLLM):
additional_command: str = ""
"""A free-form instruction for modifying how the summaries get generated. """

endpoint_kwargs: Dict[str, Any] = Field(default_factory=dict)
endpoint_kwargs: Dict[str, Any] = {}
"""Optional attributes passed to the generate_text/summarize_text function."""

client_kwargs: Dict[str, Any] = {}
"""Holds any client parametes for creating GenerativeAiClient"""

service_endpoint: str = DEFAULT_SERVICE_ENDPOINT
"""The name of the service endpoint from the OCI GenertiveAI Service."""

compartment_id: str = None
"""Compartment ID of the caller."""

task: str = TASK.TEXT_GENERATION
"""Indicates the task."""

@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that python package exists in environment."""
if values.get("client") is not None:
return values

_auth = values.get("auth", default_signer())
_client_kwargs = values["client_kwargs"] or {}
_service_endpoint = _client_kwargs.get(
"service_endpoint",
values.get("service_endpoint", None) or cls.service_endpoint,
)
_client_kwargs["service_endpoint"] = _service_endpoint
try:
import oci

values["client"] = GenerativeAiClient(**_auth, **_client_kwargs)
except ImportError:
raise ImportError(
"Could not import oci python package. "
"Please install it with `pip install oci`."
)
return values

@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
@@ -190,21 +184,23 @@ def _llm_type(self) -> str:
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling OCIGenerativeAI API."""
from oci.generative_ai.models import OnDemandServingMode

return (
{
"serving_mode": OnDemandServingMode(model_id=self.model),
"compartment_id": self.compartment_id,
"max_tokens": self.max_tokens,
"temperature": self.temperature,
"max_tokens": self.max_tokens,
"top_k": self.k,
"top_p": self.p,
"frequency_penalty": self.frequency_penalty,
"presence_penalty": self.presence_penalty,
"truncate": self.truncate,
"serving_mode": models.OnDemandServingMode(model_id=self.model),
}
if self.task == "text_generation"
if self.task == Task.TEXT_GENERATION
else {
"serving_mode": models.OnDemandServingMode(model_id=self.model),
"serving_mode": OnDemandServingMode(model_id=self.model),
"compartment_id": self.compartment_id,
"temperature": self.temperature,
"length": self.length,
@@ -216,7 +212,7 @@ def _default_params(self) -> Dict[str, Any]:

def _invocation_params(self, stop: Optional[List[str]], **kwargs: Any) -> dict:
params = self._default_params
if self.task == TASK.SUMMARY_TEXT:
if self.task == Task.SUMMARY_TEXT:
return {**params}

if self.stop is not None and stop is not None:
@@ -255,23 +251,16 @@ def _call(

params = self._invocation_params(stop, **kwargs)

try:
response = (
self.completion_with_retry(prompts=[prompt], **params)
if self.task == TASK.TEXT_GENERATION
else self.completion_with_retry(input=prompt, **params)
)
except Exception as ex:
logger.error(
"Error occur when invoking oci service api."
f"DEBUG INTO: task={self.task}, params={params}, prompt={prompt}"
)
raise
response = (
self.completion_with_retry(prompts=[prompt], **params)
if self.task == Task.TEXT_GENERATION
else self.completion_with_retry(input=prompt, **params)
)

return self._process_response(response, params.get("num_generations", 1))

def _process_response(self, response: Any, num_generations: int = 1) -> str:
if self.task == TASK.SUMMARY_TEXT:
if self.task == Task.SUMMARY_TEXT:
return response.data.summary

return (
@@ -281,19 +270,17 @@ def _process_response(self, response: Any, num_generations: int = 1) -> str:
)

def completion_with_retry(self, **kwargs: Any) -> Any:
_model_kwargs = {**kwargs}
_endpoint_kwargs = self.endpoint_kwargs or {}
# TODO: Add retry logic for OCI
from oci.generative_ai.models import GenerateTextDetails, SummarizeTextDetails

if self.task == TASK.TEXT_GENERATION:
if self.task == Task.TEXT_GENERATION:
return self.client.generate_text(
models.GenerateTextDetails(**_model_kwargs), **_endpoint_kwargs
GenerateTextDetails(**kwargs), **self.endpoint_kwargs
)
elif self.task == TASK.SUMMARY_TEXT:
else:
return self.client.summarize_text(
models.SummarizeTextDetails(**_model_kwargs), **_endpoint_kwargs
SummarizeTextDetails(**kwargs), **self.endpoint_kwargs
)
else:
raise ValueError("Unsupported tasks.")

def batch_completion(
self,
@@ -332,9 +319,9 @@ def batch_completion(
responses = gen_ai.batch_completion("Tell me a joke.", num_generations=5)

"""
if self.task == TASK.SUMMARY_TEXT:
if self.task == Task.SUMMARY_TEXT:
raise NotImplementedError(
f"task={TASK.SUMMARY_TEXT} does not support batch_completion. "
f"task={Task.SUMMARY_TEXT} does not support batch_completion. "
)

return self._call(