
Commit f0e97b0

Merge pull request NVIDIA-NeMo#654: Fix/nvidia ai endpoints streaming
2 parents 8a8921c + 0c7e78f commit f0e97b0

2 files changed: 193 additions, 0 deletions
nemoguardrails/llm/providers/_langchain_nvidia_ai_endpoints_patch.py

Lines changed: 191 additions & 0 deletions
@@ -0,0 +1,191 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Type

import pkg_resources
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import generate_from_stream
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    BaseMessageChunk,
    ChatMessage,
    ChatMessageChunk,
    FunctionMessageChunk,
    HumanMessageChunk,
    SystemMessageChunk,
    ToolMessageChunk,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import Field
from langchain_nvidia_ai_endpoints import ChatNVIDIA
from packaging import version

log = logging.getLogger(__name__)

def _convert_delta_to_message_chunk(
    _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk:
    """Map a streamed response delta to the matching message chunk class."""
    role = _dict.get("role")
    content = _dict.get("content") or ""
    additional_kwargs: Dict = {}
    if _dict.get("function_call"):
        function_call = dict(_dict["function_call"])
        if "name" in function_call and function_call["name"] is None:
            function_call["name"] = ""
        additional_kwargs["function_call"] = function_call
    if _dict.get("tool_calls"):
        additional_kwargs["tool_calls"] = _dict["tool_calls"]

    if role == "user" or default_class == HumanMessageChunk:
        return HumanMessageChunk(content=content)
    elif role == "assistant" or default_class == AIMessageChunk:
        return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
    elif role == "system" or default_class == SystemMessageChunk:
        return SystemMessageChunk(content=content)
    elif role == "function" or default_class == FunctionMessageChunk:
        return FunctionMessageChunk(content=content, name=_dict["name"])
    elif role == "tool" or default_class == ToolMessageChunk:
        return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"])
    elif role or default_class == ChatMessageChunk:
        return ChatMessageChunk(content=content, role=role)  # type: ignore[arg-type]
    else:
        return default_class(content=content)  # type: ignore[call-arg]

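# Example (illustrative): a first delta like {"role": "assistant", "content": "Hi"}
# maps to AIMessageChunk(content="Hi"); later deltas typically omit "role", so
# _stream threads the previous chunk's class back in via `default_class`.
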
class PatchedChatNVIDIAV1(ChatNVIDIA):
    streaming: bool = Field(
        default=False, description="Whether to use streaming or not"
    )

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ChatResult:
        should_stream = stream if stream is not None else self.streaming
        if should_stream:
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)
        inputs = self._custom_preprocess(messages)
        payload = self._get_payload(inputs=inputs, stop=stop, stream=False, **kwargs)
        response = self._client.client.get_req(payload=payload)
        responses, _ = self._client.client.postprocess(response)
        self._set_callback_out(responses, run_manager)
        message = ChatMessage(**self._custom_postprocess(responses))
        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation], llm_output=responses)

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[Sequence[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream response chunks from the model."""
        inputs = self._custom_preprocess(messages)
        payload = self._get_payload(inputs=inputs, stop=stop, stream=True, **kwargs)
        default_chunk_class = AIMessageChunk
        for response in self._client.client.get_req_stream(payload=payload):
            self._set_callback_out(response, run_manager)
            chunk = _convert_delta_to_message_chunk(response, default_chunk_class)
            default_chunk_class = chunk.__class__
            cg_chunk = ChatGenerationChunk(message=chunk)
            if run_manager:
                run_manager.on_llm_new_token(cg_chunk.text, chunk=cg_chunk)
            yield cg_chunk

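# PatchedChatNVIDIAV2 covers the 0.2.x line, where message preprocessing moved
# to convert_message_to_dict/_nv_vlm_adjust_input and _custom_postprocess takes
# a `streaming` flag; only the non-streaming path needs overriding here.
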
class PatchedChatNVIDIAV2(ChatNVIDIA):
    streaming: bool = Field(
        default=False, description="Whether to use streaming or not"
    )

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ChatResult:
        # These helpers only exist for the 0.2.x line, so import them lazily
        # here (importing them inside the factory would leave them out of scope
        # for this method).
        from langchain_community.adapters.openai import convert_message_to_dict
        from langchain_nvidia_ai_endpoints.chat_models import _nv_vlm_adjust_input

        should_stream = stream if stream is not None else self.streaming
        if should_stream:
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)
        inputs = [
            _nv_vlm_adjust_input(message)
            for message in [convert_message_to_dict(message) for message in messages]
        ]
        payload = self._get_payload(inputs=inputs, stop=stop, stream=False, **kwargs)
        response = self._client.client.get_req(payload=payload)
        responses, _ = self._client.client.postprocess(response)
        self._set_callback_out(responses, run_manager)
        parsed_response = self._custom_postprocess(responses, streaming=False)
        # For pre-0.2 compatibility with ChatMessage: ChatMessage had a `role`
        # property that is not present on AIMessage.
        parsed_response.update({"role": "assistant"})
        generation = ChatGeneration(message=AIMessage(**parsed_response))
        return ChatResult(generations=[generation], llm_output=responses)

class ChatNVIDIAFactory:
    RANGE1 = (version.parse("0.1.0"), version.parse("0.2.0"))
    RANGE2 = (version.parse("0.2.0"), version.parse("0.3.0"))

    @staticmethod
    def get_package_version(package_name):
        return version.parse(pkg_resources.get_distribution(package_name).version)

    @staticmethod
    def is_version_in_range(current, version_range):
        return version_range[0] <= current < version_range[1]

    @classmethod
    def create(cls):
        current_version = cls.get_package_version("langchain_nvidia_ai_endpoints")

        if cls.is_version_in_range(current_version, cls.RANGE1):
            log.debug(
                f"Using patched version of ChatNVIDIA for version {current_version}"
            )
            return PatchedChatNVIDIAV1
        elif cls.is_version_in_range(current_version, cls.RANGE2):
            log.debug(
                f"Using patched version of ChatNVIDIA for version {current_version}"
            )
            return PatchedChatNVIDIAV2
        else:
            return ChatNVIDIA


ChatNVIDIA = ChatNVIDIAFactory.create()

__all__ = ["ChatNVIDIA"]
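
For reference, a minimal usage sketch of this module (not part of the commit). It assumes the patch module lives at the path shown in the providers.py hunk below, that a compatible `langchain-nvidia-ai-endpoints` is installed, that `NVIDIA_API_KEY` is set in the environment, and that the model name is illustrative only:

from nemoguardrails.llm.providers._langchain_nvidia_ai_endpoints_patch import ChatNVIDIA

# ChatNVIDIAFactory.create() has already resolved this name to the patched
# class matching the installed langchain_nvidia_ai_endpoints version.
llm = ChatNVIDIA(model="meta/llama3-8b-instruct", streaming=True)

# With streaming=True, _generate delegates to _stream and the chunks are
# folded back into a single message by generate_from_stream.
print(llm.invoke("Write a haiku about GPUs.").content)

# Or consume the chunks directly:
for chunk in llm.stream("Write a haiku about GPUs."):
    print(chunk.content, end="", flush=True)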

nemoguardrails/llm/providers/providers.py

Lines changed: 2 additions & 0 deletions
@@ -242,6 +242,8 @@ def get_llm_provider(model_config: Model) -> Type[BaseLanguageModel]:
    try:
        from langchain_nvidia_ai_endpoints import ChatNVIDIA

+        from ._langchain_nvidia_ai_endpoints_patch import ChatNVIDIA
+
        return ChatNVIDIA
    except ImportError:
        raise ImportError(
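
Net effect: the second import rebinds `ChatNVIDIA` to the patched class, so `get_llm_provider` returns it whenever `langchain_nvidia_ai_endpoints` is importable, while the import above it still fails fast into the `except ImportError` branch when the package is missing. A minimal sketch of exercising that resolution; the `Model` import path, its field names, and the `nvidia_ai_endpoints` engine name are assumptions, not confirmed by this diff:

from nemoguardrails.llm.providers.providers import get_llm_provider
from nemoguardrails.rails.llm.config import Model  # assumed location of Model

# Hypothetical model config matching the get_llm_provider signature in the
# hunk header above.
model_config = Model(
    type="main",
    engine="nvidia_ai_endpoints",  # assumed engine name for this provider
    model="meta/llama3-8b-instruct",
)

provider_cls = get_llm_provider(model_config)
# With this commit applied, provider_cls is PatchedChatNVIDIAV1 or V2
# (or plain ChatNVIDIA), depending on the installed package version.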
