community[patch]: Support Streaming in Azure Machine Learning (#18246)

- [x] **PR title**: "community: Support streaming in Azure ML and a few
naming changes"

- [x] **PR message**:
- **Description:** Added streaming support to azureml_endpoint. Also renamed
AzureMLEndpointApiType.realtime to AzureMLEndpointApiType.dedicated, added the
new classes CustomOpenAIChatContentFormatter and CustomOpenAIContentFormatter,
and updated LlamaChatContentFormatter and LlamaContentFormatter to emit a
deprecation warning when instantiated.
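
For context, a minimal usage sketch of the new streaming path. The endpoint URL and API key are placeholders, the import path assumes this file's module (`langchain_community.chat_models.azureml_endpoint`), and streaming requires the `openai` package, which the new `_stream`/`_astream` methods import internally:

```python
from langchain_community.chat_models.azureml_endpoint import (
    AzureMLChatOnlineEndpoint,
    AzureMLEndpointApiType,
    CustomOpenAIChatContentFormatter,
)
from langchain_core.messages import HumanMessage

chat = AzureMLChatOnlineEndpoint(
    endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/v1/chat/completions",  # placeholder
    endpoint_api_type=AzureMLEndpointApiType.serverless,
    endpoint_api_key="my-api-key",  # placeholder
    content_formatter=CustomOpenAIChatContentFormatter(),
)

# .stream() drives the new _stream() method in this diff and yields message
# chunks as the endpoint produces tokens.
for chunk in chat.stream([HumanMessage(content="Tell me a joke.")]):
    print(chunk.content, end="", flush=True)
```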

---------

Co-authored-by: Sachin Paryani <saparan@microsoft.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
@@ -1,16 +1,37 @@
 import json
+import warnings
-from typing import Any, Dict, List, Optional, cast
+from typing import (
+    Any,
+    AsyncIterator,
+    Dict,
+    Iterator,
+    List,
+    Mapping,
+    Optional,
+    Type,
+    cast,
+)
 
-from langchain_core.callbacks.manager import CallbackManagerForLLMRun
+from langchain_core.callbacks import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import (
     AIMessage,
+    AIMessageChunk,
     BaseMessage,
+    BaseMessageChunk,
     ChatMessage,
+    ChatMessageChunk,
+    FunctionMessageChunk,
     HumanMessage,
+    HumanMessageChunk,
     SystemMessage,
+    SystemMessageChunk,
+    ToolMessageChunk,
 )
-from langchain_core.outputs import ChatGeneration, ChatResult
+from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 
 from langchain_community.llms.azureml_endpoint import (
     AzureMLBaseEndpoint,
@@ -25,12 +46,12 @@ class LlamaContentFormatter(ContentFormatterBase):
     def __init__(self) -> None:
         raise TypeError(
             "`LlamaContentFormatter` is deprecated for chat models. Use "
-            "`LlamaChatContentFormatter` instead."
+            "`CustomOpenAIContentFormatter` instead."
         )
 
 
-class LlamaChatContentFormatter(ContentFormatterBase):
-    """Content formatter for `LLaMA`."""
+class CustomOpenAIChatContentFormatter(ContentFormatterBase):
+    """Chat Content formatter for models with OpenAI like API scheme."""
 
     SUPPORTED_ROLES: List[str] = ["user", "assistant", "system"]
@@ -55,7 +76,7 @@ class LlamaChatContentFormatter(ContentFormatterBase):
             }
         elif (
             isinstance(message, ChatMessage)
-            and message.role in LlamaChatContentFormatter.SUPPORTED_ROLES
+            and message.role in CustomOpenAIChatContentFormatter.SUPPORTED_ROLES
         ):
             return {
                 "role": message.role,
@@ -63,7 +84,7 @@ class LlamaChatContentFormatter(ContentFormatterBase):
             }
         else:
             supported = ",".join(
-                [role for role in LlamaChatContentFormatter.SUPPORTED_ROLES]
+                [role for role in CustomOpenAIChatContentFormatter.SUPPORTED_ROLES]
             )
             raise ValueError(
                 f"""Received unsupported role.
@@ -72,7 +93,7 @@ class LlamaChatContentFormatter(ContentFormatterBase):
 
     @property
     def supported_api_types(self) -> List[AzureMLEndpointApiType]:
-        return [AzureMLEndpointApiType.realtime, AzureMLEndpointApiType.serverless]
+        return [AzureMLEndpointApiType.dedicated, AzureMLEndpointApiType.serverless]
 
     def format_messages_request_payload(
         self,
@@ -82,10 +103,13 @@ class LlamaChatContentFormatter(ContentFormatterBase):
     ) -> bytes:
         """Formats the request according to the chosen api"""
         chat_messages = [
-            LlamaChatContentFormatter._convert_message_to_dict(message)
+            CustomOpenAIChatContentFormatter._convert_message_to_dict(message)
             for message in messages
         ]
-        if api_type == AzureMLEndpointApiType.realtime:
+        if api_type in [
+            AzureMLEndpointApiType.dedicated,
+            AzureMLEndpointApiType.realtime,
+        ]:
             request_payload = json.dumps(
                 {
                     "input_data": {
@@ -105,10 +129,13 @@ class LlamaChatContentFormatter(ContentFormatterBase):
     def format_response_payload(
         self,
         output: bytes,
-        api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.realtime,
+        api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.dedicated,
     ) -> ChatGeneration:
         """Formats response"""
-        if api_type == AzureMLEndpointApiType.realtime:
+        if api_type in [
+            AzureMLEndpointApiType.dedicated,
+            AzureMLEndpointApiType.realtime,
+        ]:
             try:
                 choice = json.loads(output)["output"]
             except (KeyError, IndexError, TypeError) as e:
@@ -143,6 +170,20 @@ class LlamaChatContentFormatter(ContentFormatterBase):
         raise ValueError(f"`api_type` {api_type} is not supported by this formatter")
 
 
+class LlamaChatContentFormatter(CustomOpenAIChatContentFormatter):
+    """Deprecated: Kept for backwards compatibility
+
+    Chat Content formatter for Llama."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        warnings.warn(
+            """`LlamaChatContentFormatter` will be deprecated in the future.
+                Please use `CustomOpenAIChatContentFormatter` instead.
+            """
+        )
+
+
 class MistralChatContentFormatter(LlamaChatContentFormatter):
     """Content formatter for `Mistral`."""
@@ -187,8 +228,8 @@ class AzureMLChatOnlineEndpoint(BaseChatModel, AzureMLBaseEndpoint):
     Example:
         .. code-block:: python
 
             azure_llm = AzureMLOnlineEndpoint(
-                endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/score",
-                endpoint_api_type=AzureMLApiType.realtime,
+                endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/v1/chat/completions",
+                endpoint_api_type=AzureMLApiType.serverless,
                 endpoint_api_key="my-api-key",
                 content_formatter=chat_content_formatter,
             )
@@ -239,3 +280,143 @@ class AzureMLChatOnlineEndpoint(BaseChatModel, AzureMLBaseEndpoint):
             response_payload, self.endpoint_api_type
         )
         return ChatResult(generations=[generations])
+
+    def _stream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        self.endpoint_url = self.endpoint_url.replace("/chat/completions", "")
+        timeout = None if "timeout" not in kwargs else kwargs["timeout"]
+
+        import openai
+
+        params = {}
+        client_params = {
+            "api_key": self.endpoint_api_key.get_secret_value(),
+            "base_url": self.endpoint_url,
+            "timeout": timeout,
+            "default_headers": None,
+            "default_query": None,
+            "http_client": None,
+        }
+
+        client = openai.OpenAI(**client_params)
+        message_dicts = [
+            CustomOpenAIChatContentFormatter._convert_message_to_dict(m)
+            for m in messages
+        ]
+        params = {"stream": True, "stop": stop, "model": None, **kwargs}
+
+        default_chunk_class = AIMessageChunk
+        for chunk in client.chat.completions.create(messages=message_dicts, **params):
+            if not isinstance(chunk, dict):
+                chunk = chunk.dict()
+            if len(chunk["choices"]) == 0:
+                continue
+            choice = chunk["choices"][0]
+            chunk = _convert_delta_to_message_chunk(
+                choice["delta"], default_chunk_class
+            )
+            generation_info = {}
+            if finish_reason := choice.get("finish_reason"):
+                generation_info["finish_reason"] = finish_reason
+            logprobs = choice.get("logprobs")
+            if logprobs:
+                generation_info["logprobs"] = logprobs
+            default_chunk_class = chunk.__class__
+            chunk = ChatGenerationChunk(
+                message=chunk, generation_info=generation_info or None
+            )
+            if run_manager:
+                run_manager.on_llm_new_token(chunk.text, chunk=chunk, logprobs=logprobs)
+            yield chunk
+
+    async def _astream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[ChatGenerationChunk]:
+        self.endpoint_url = self.endpoint_url.replace("/chat/completions", "")
+        timeout = None if "timeout" not in kwargs else kwargs["timeout"]
+
+        import openai
+
+        params = {}
+        client_params = {
+            "api_key": self.endpoint_api_key.get_secret_value(),
+            "base_url": self.endpoint_url,
+            "timeout": timeout,
+            "default_headers": None,
+            "default_query": None,
+            "http_client": None,
+        }
+
+        async_client = openai.AsyncOpenAI(**client_params)
+        message_dicts = [
+            CustomOpenAIChatContentFormatter._convert_message_to_dict(m)
+            for m in messages
+        ]
+        params = {"stream": True, "stop": stop, "model": None, **kwargs}
+
+        default_chunk_class = AIMessageChunk
+        async for chunk in await async_client.chat.completions.create(
+            messages=message_dicts, **params
+        ):
+            if not isinstance(chunk, dict):
+                chunk = chunk.dict()
+            if len(chunk["choices"]) == 0:
+                continue
+            choice = chunk["choices"][0]
+            chunk = _convert_delta_to_message_chunk(
+                choice["delta"], default_chunk_class
+            )
+            generation_info = {}
+            if finish_reason := choice.get("finish_reason"):
+                generation_info["finish_reason"] = finish_reason
+            logprobs = choice.get("logprobs")
+            if logprobs:
+                generation_info["logprobs"] = logprobs
+            default_chunk_class = chunk.__class__
+            chunk = ChatGenerationChunk(
+                message=chunk, generation_info=generation_info or None
+            )
+            if run_manager:
+                await run_manager.on_llm_new_token(
+                    token=chunk.text, chunk=chunk, logprobs=logprobs
+                )
+            yield chunk
+
+
+def _convert_delta_to_message_chunk(
+    _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
+) -> BaseMessageChunk:
+    role = cast(str, _dict.get("role"))
+    content = cast(str, _dict.get("content") or "")
+    additional_kwargs: Dict = {}
+    if _dict.get("function_call"):
+        function_call = dict(_dict["function_call"])
+        if "name" in function_call and function_call["name"] is None:
+            function_call["name"] = ""
+        additional_kwargs["function_call"] = function_call
+    if _dict.get("tool_calls"):
+        additional_kwargs["tool_calls"] = _dict["tool_calls"]
+
+    if role == "user" or default_class == HumanMessageChunk:
+        return HumanMessageChunk(content=content)
+    elif role == "assistant" or default_class == AIMessageChunk:
+        return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
+    elif role == "system" or default_class == SystemMessageChunk:
+        return SystemMessageChunk(content=content)
+    elif role == "function" or default_class == FunctionMessageChunk:
+        return FunctionMessageChunk(content=content, name=_dict["name"])
+    elif role == "tool" or default_class == ToolMessageChunk:
+        return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"])
+    elif role or default_class == ChatMessageChunk:
+        return ChatMessageChunk(content=content, role=role)
+    else:
+        return default_class(content=content)
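
The async path mirrors the sync one. A minimal sketch of consuming it via `astream` (endpoint URL and key are placeholders, import path assumed as before):

```python
import asyncio

from langchain_community.chat_models.azureml_endpoint import (
    AzureMLChatOnlineEndpoint,
    AzureMLEndpointApiType,
    CustomOpenAIChatContentFormatter,
)
from langchain_core.messages import HumanMessage

chat = AzureMLChatOnlineEndpoint(
    endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/v1/chat/completions",  # placeholder
    endpoint_api_type=AzureMLEndpointApiType.serverless,
    endpoint_api_key="my-api-key",  # placeholder
    content_formatter=CustomOpenAIChatContentFormatter(),
)


async def main() -> None:
    # astream() drives the new _astream() coroutine, which wraps
    # openai.AsyncOpenAI under the hood.
    async for chunk in chat.astream([HumanMessage(content="Tell me a joke.")]):
        print(chunk.content, end="", flush=True)


asyncio.run(main())
```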