Skip to content
22 changes: 7 additions & 15 deletions ads/llm/langchain/plugins/llm_md.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@

import requests
from langchain.callbacks.manager import CallbackManagerForLLMRun
from pydantic import root_validator
from ads.common.auth import default_signer
from ads.llm.langchain.plugins.base import BaseLLM


Expand Down Expand Up @@ -38,14 +36,6 @@ def _identifying_params(self) -> Mapping[str, Any]:
**self._default_params,
}

@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Dont do anything if client provided externally."""

_signer = values.get("auth", default_signer()["signer"])
values["signer"] = _signer
return values

def _call(
self,
prompt: str,
Expand Down Expand Up @@ -86,14 +76,14 @@ def send_request(
self,
data,
endpoint: str,
header: dict = {},
header: dict = None,
**kwargs,
) -> Dict:
"""Sends request to the model deployment endpoint.

Parameters
----------
data (Json serializablype):
data (Json serializable):
data need to be sent to the endpoint.
endpoint (str):
The model HTTP endpoint.
Expand All @@ -109,15 +99,17 @@ def send_request(

Returns
-------
A JSON representive of a requests.Response object.
A JSON representation of a requests.Response object.
"""
if not header:
header = {}
header["Content-Type"] = (
header.pop("content_type", DEFAULT_CONTENT_TYPE_JSON)
or DEFAULT_CONTENT_TYPE_JSON
)
request_kwargs = {"json": data}
request_kwargs["headers"] = header
request_kwargs["auth"] = self.signer
request_kwargs["auth"] = self.auth.get("signer")

try:
response = requests.post(endpoint, **request_kwargs, **kwargs)
Expand Down Expand Up @@ -158,7 +150,7 @@ class ModelDeploymentTGI(OCIModelDeployment):

watermark = True

return_full_text = True
return_full_text = False

@property
def _llm_type(self) -> str:
Expand Down