community[patch]: Support Streaming in Azure Machine Learning (#18246)
- [x] **PR title**: "community: Support streaming in Azure ML and few naming changes"
- [x] **PR message**:
  - **Description:** Added support for streaming to azureml_endpoint. Renamed AzureMLEndpointApiType.realtime to AzureMLEndpointApiType.dedicated, added the new classes CustomOpenAIChatContentFormatter and CustomOpenAIContentFormatter, and updated LlamaChatContentFormatter and LlamaContentFormatter to emit a deprecation warning when instantiated.

---------

Co-authored-by: Sachin Paryani <saparan@microsoft.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
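For context, a minimal usage sketch of the streaming path this PR adds (not part of the diff; the endpoint URL and API key below are placeholders):

```python
# Sketch only: stream tokens from an Azure ML serverless endpoint using the
# classes introduced in this PR. Replace the URL and key with real values.
from langchain_community.chat_models.azureml_endpoint import (
    AzureMLChatOnlineEndpoint,
    CustomOpenAIChatContentFormatter,
)
from langchain_community.llms.azureml_endpoint import AzureMLEndpointApiType
from langchain_core.messages import HumanMessage

chat = AzureMLChatOnlineEndpoint(
    endpoint_url="https://<your-endpoint>.<your-region>.inference.ml.azure.com/v1/chat/completions",
    endpoint_api_type=AzureMLEndpointApiType.serverless,
    endpoint_api_key="my-api-key",
    content_formatter=CustomOpenAIChatContentFormatter(),
)

# .stream() drives the new _stream() method below, yielding message chunks
# as they arrive instead of waiting for the full completion.
for chunk in chat.stream([HumanMessage(content="Tell me a joke.")]):
    print(chunk.content, end="", flush=True)
```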
@@ -1,16 +1,37 @@
 import json
-from typing import Any, Dict, List, Optional, cast
+import warnings
+from typing import (
+    Any,
+    AsyncIterator,
+    Dict,
+    Iterator,
+    List,
+    Mapping,
+    Optional,
+    Type,
+    cast,
+)
 
-from langchain_core.callbacks.manager import CallbackManagerForLLMRun
+from langchain_core.callbacks import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import (
     AIMessage,
+    AIMessageChunk,
     BaseMessage,
+    BaseMessageChunk,
     ChatMessage,
+    ChatMessageChunk,
+    FunctionMessageChunk,
     HumanMessage,
+    HumanMessageChunk,
     SystemMessage,
+    SystemMessageChunk,
+    ToolMessageChunk,
 )
-from langchain_core.outputs import ChatGeneration, ChatResult
+from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 
 from langchain_community.llms.azureml_endpoint import (
     AzureMLBaseEndpoint,
@@ -25,12 +46,12 @@ class LlamaContentFormatter(ContentFormatterBase):
     def __init__(self) -> None:
         raise TypeError(
             "`LlamaContentFormatter` is deprecated for chat models. Use "
-            "`LlamaChatContentFormatter` instead."
+            "`CustomOpenAIContentFormatter` instead."
         )
 
 
-class LlamaChatContentFormatter(ContentFormatterBase):
-    """Content formatter for `LLaMA`."""
+class CustomOpenAIChatContentFormatter(ContentFormatterBase):
+    """Chat Content formatter for models with OpenAI like API scheme."""
 
     SUPPORTED_ROLES: List[str] = ["user", "assistant", "system"]
 
@@ -55,7 +76,7 @@ class LlamaChatContentFormatter(ContentFormatterBase):
             }
         elif (
             isinstance(message, ChatMessage)
-            and message.role in LlamaChatContentFormatter.SUPPORTED_ROLES
+            and message.role in CustomOpenAIChatContentFormatter.SUPPORTED_ROLES
         ):
             return {
                 "role": message.role,
@@ -63,7 +84,7 @@ class LlamaChatContentFormatter(ContentFormatterBase):
             }
         else:
             supported = ",".join(
-                [role for role in LlamaChatContentFormatter.SUPPORTED_ROLES]
+                [role for role in CustomOpenAIChatContentFormatter.SUPPORTED_ROLES]
             )
             raise ValueError(
                 f"""Received unsupported role.
@@ -72,7 +93,7 @@ class LlamaChatContentFormatter(ContentFormatterBase):
 
     @property
     def supported_api_types(self) -> List[AzureMLEndpointApiType]:
-        return [AzureMLEndpointApiType.realtime, AzureMLEndpointApiType.serverless]
+        return [AzureMLEndpointApiType.dedicated, AzureMLEndpointApiType.serverless]
 
     def format_messages_request_payload(
         self,
@@ -82,10 +103,13 @@ class LlamaChatContentFormatter(ContentFormatterBase):
     ) -> bytes:
         """Formats the request according to the chosen api"""
         chat_messages = [
-            LlamaChatContentFormatter._convert_message_to_dict(message)
+            CustomOpenAIChatContentFormatter._convert_message_to_dict(message)
            for message in messages
         ]
-        if api_type == AzureMLEndpointApiType.realtime:
+        if api_type in [
+            AzureMLEndpointApiType.dedicated,
+            AzureMLEndpointApiType.realtime,
+        ]:
             request_payload = json.dumps(
                 {
                     "input_data": {
@@ -105,10 +129,13 @@ class LlamaChatContentFormatter(ContentFormatterBase):
     def format_response_payload(
         self,
         output: bytes,
-        api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.realtime,
+        api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.dedicated,
     ) -> ChatGeneration:
         """Formats response"""
-        if api_type == AzureMLEndpointApiType.realtime:
+        if api_type in [
+            AzureMLEndpointApiType.dedicated,
+            AzureMLEndpointApiType.realtime,
+        ]:
             try:
                 choice = json.loads(output)["output"]
             except (KeyError, IndexError, TypeError) as e:
@@ -143,6 +170,20 @@ class LlamaChatContentFormatter(ContentFormatterBase):
         raise ValueError(f"`api_type` {api_type} is not supported by this formatter")
 
 
+class LlamaChatContentFormatter(CustomOpenAIChatContentFormatter):
+    """Deprecated: Kept for backwards compatibility
+
+    Chat Content formatter for Llama."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        warnings.warn(
+            """`LlamaChatContentFormatter` will be deprecated in the future.
+        Please use `CustomOpenAIChatContentFormatter` instead.
+        """
+        )
+
+
 class MistralChatContentFormatter(LlamaChatContentFormatter):
     """Content formatter for `Mistral`."""
 
@@ -187,8 +228,8 @@ class AzureMLChatOnlineEndpoint(BaseChatModel, AzureMLBaseEndpoint):
     Example:
         .. code-block:: python
             azure_llm = AzureMLOnlineEndpoint(
-                endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/score",
-                endpoint_api_type=AzureMLApiType.realtime,
+                endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/v1/chat/completions",
+                endpoint_api_type=AzureMLApiType.serverless,
                 endpoint_api_key="my-api-key",
                 content_formatter=chat_content_formatter,
             )
@@ -239,3 +280,143 @@ class AzureMLChatOnlineEndpoint(BaseChatModel, AzureMLBaseEndpoint):
             response_payload, self.endpoint_api_type
         )
         return ChatResult(generations=[generations])
+
+    def _stream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        self.endpoint_url = self.endpoint_url.replace("/chat/completions", "")
+        timeout = None if "timeout" not in kwargs else kwargs["timeout"]
+
+        import openai
+
+        params = {}
+        client_params = {
+            "api_key": self.endpoint_api_key.get_secret_value(),
+            "base_url": self.endpoint_url,
+            "timeout": timeout,
+            "default_headers": None,
+            "default_query": None,
+            "http_client": None,
+        }
+
+        client = openai.OpenAI(**client_params)
+        message_dicts = [
+            CustomOpenAIChatContentFormatter._convert_message_to_dict(m)
+            for m in messages
+        ]
+        params = {"stream": True, "stop": stop, "model": None, **kwargs}
+
+        default_chunk_class = AIMessageChunk
+        for chunk in client.chat.completions.create(messages=message_dicts, **params):
+            if not isinstance(chunk, dict):
+                chunk = chunk.dict()
+            if len(chunk["choices"]) == 0:
+                continue
+            choice = chunk["choices"][0]
+            chunk = _convert_delta_to_message_chunk(
+                choice["delta"], default_chunk_class
+            )
+            generation_info = {}
+            if finish_reason := choice.get("finish_reason"):
+                generation_info["finish_reason"] = finish_reason
+            logprobs = choice.get("logprobs")
+            if logprobs:
+                generation_info["logprobs"] = logprobs
+            default_chunk_class = chunk.__class__
+            chunk = ChatGenerationChunk(
+                message=chunk, generation_info=generation_info or None
+            )
+            if run_manager:
+                run_manager.on_llm_new_token(chunk.text, chunk=chunk, logprobs=logprobs)
+            yield chunk
+
+    async def _astream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[ChatGenerationChunk]:
+        self.endpoint_url = self.endpoint_url.replace("/chat/completions", "")
+        timeout = None if "timeout" not in kwargs else kwargs["timeout"]
+
+        import openai
+
+        params = {}
+        client_params = {
+            "api_key": self.endpoint_api_key.get_secret_value(),
+            "base_url": self.endpoint_url,
+            "timeout": timeout,
+            "default_headers": None,
+            "default_query": None,
+            "http_client": None,
+        }
+
+        async_client = openai.AsyncOpenAI(**client_params)
+        message_dicts = [
+            CustomOpenAIChatContentFormatter._convert_message_to_dict(m)
+            for m in messages
+        ]
+        params = {"stream": True, "stop": stop, "model": None, **kwargs}
+
+        default_chunk_class = AIMessageChunk
+        async for chunk in await async_client.chat.completions.create(
+            messages=message_dicts, **params
+        ):
+            if not isinstance(chunk, dict):
+                chunk = chunk.dict()
+            if len(chunk["choices"]) == 0:
+                continue
+            choice = chunk["choices"][0]
+            chunk = _convert_delta_to_message_chunk(
+                choice["delta"], default_chunk_class
+            )
+            generation_info = {}
+            if finish_reason := choice.get("finish_reason"):
+                generation_info["finish_reason"] = finish_reason
+            logprobs = choice.get("logprobs")
+            if logprobs:
+                generation_info["logprobs"] = logprobs
+            default_chunk_class = chunk.__class__
+            chunk = ChatGenerationChunk(
+                message=chunk, generation_info=generation_info or None
+            )
+            if run_manager:
+                await run_manager.on_llm_new_token(
+                    token=chunk.text, chunk=chunk, logprobs=logprobs
+                )
+            yield chunk
+
+
+def _convert_delta_to_message_chunk(
+    _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
+) -> BaseMessageChunk:
+    role = cast(str, _dict.get("role"))
+    content = cast(str, _dict.get("content") or "")
+    additional_kwargs: Dict = {}
+    if _dict.get("function_call"):
+        function_call = dict(_dict["function_call"])
+        if "name" in function_call and function_call["name"] is None:
+            function_call["name"] = ""
+        additional_kwargs["function_call"] = function_call
+    if _dict.get("tool_calls"):
+        additional_kwargs["tool_calls"] = _dict["tool_calls"]
+
+    if role == "user" or default_class == HumanMessageChunk:
+        return HumanMessageChunk(content=content)
+    elif role == "assistant" or default_class == AIMessageChunk:
+        return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
+    elif role == "system" or default_class == SystemMessageChunk:
+        return SystemMessageChunk(content=content)
+    elif role == "function" or default_class == FunctionMessageChunk:
+        return FunctionMessageChunk(content=content, name=_dict["name"])
+    elif role == "tool" or default_class == ToolMessageChunk:
+        return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"])
+    elif role or default_class == ChatMessageChunk:
+        return ChatMessageChunk(content=content, role=role)
+    else:
+        return default_class(content=content)
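The `LlamaChatContentFormatter` shim in the diff keeps existing code working while steering users toward the new class name. A quick sketch of that behavior (assuming this patched version of langchain_community is installed):

```python
# Sketch only: the deprecated formatter still constructs, but warns.
import warnings

from langchain_community.chat_models.azureml_endpoint import LlamaChatContentFormatter

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    formatter = LlamaChatContentFormatter()  # no TypeError, unlike LlamaContentFormatter

# The shim warns instead of breaking, and the message names the replacement.
assert any("CustomOpenAIChatContentFormatter" in str(w.message) for w in caught)
```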