Add LLaMa Formatter and AzureML Chat Endpoint (#8382)

## Description

Microsoft and Meta recently [announced their
collaboration](https://blogs.microsoft.com/blog/2023/07/18/microsoft-and-meta-expand-their-ai-partnership-with-llama-2-on-azure-and-windows/)
on Llama 2. This PR extends the existing AzureML LLM wrapper and introduces a new
AzureML Chat Model wrapper to support Llama 2.
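
For reference, a minimal usage sketch of the new chat wrapper (the endpoint URL and API key below are placeholders for a deployed Llama 2 endpoint):

```python
from langchain.chat_models.azureml_endpoint import (
    AzureMLChatOnlineEndpoint,
    LlamaContentFormatter,
)
from langchain.schema import HumanMessage

# Placeholder endpoint values; these can also be supplied via the
# AZUREML_ENDPOINT_URL and AZUREML_ENDPOINT_API_KEY environment variables.
chat = AzureMLChatOnlineEndpoint(
    endpoint_url="https://<your-endpoint>.<your-region>.inference.ml.azure.com/score",
    endpoint_api_key="my-api-key",
    content_formatter=LlamaContentFormatter(),
)

response = chat(messages=[HumanMessage(content="Tell me a joke.")])
print(response.content)
```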

## Dependencies

No dependencies added :)

## Twitter Handles

[@matthew_d13](https://twitter.com/matthew_d13)
[@prakhar_in](https://twitter.com/prakhar_in)

maintainers - @hwchase17, @baskaryan
Matthew DeGuzman
2023-07-31 16:26:25 -07:00
committed by GitHub
parent 1ab773c742
commit 844eca98d5
6 changed files with 418 additions and 44 deletions


@@ -0,0 +1,151 @@
import json
from typing import Any, Dict, List, Optional
from pydantic import validator
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import SimpleChatModel
from langchain.llms.azureml_endpoint import AzureMLEndpointClient, ContentFormatterBase
from langchain.schema.messages import (
AIMessage,
BaseMessage,
ChatMessage,
HumanMessage,
SystemMessage,
)
from langchain.utils import get_from_dict_or_env
class LlamaContentFormatter(ContentFormatterBase):
"""Content formatter for LLaMa"""
SUPPORTED_ROLES = ["user", "assistant", "system"]
@staticmethod
def _convert_message_to_dict(message: BaseMessage) -> Dict:
"""Converts message to a dict according to role"""
if isinstance(message, HumanMessage):
return {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
return {"role": "assistant", "content": message.content}
elif isinstance(message, SystemMessage):
return {"role": "system", "content": message.content}
elif (
isinstance(message, ChatMessage)
and message.role in LlamaContentFormatter.SUPPORTED_ROLES
):
return {"role": message.role, "content": message.content}
else:
supported = ",".join(
[role for role in LlamaContentFormatter.SUPPORTED_ROLES]
)
raise ValueError(
f"""Received unsupported role.
Supported roles for the LLaMa Foundation Model: {supported}"""
)
def _format_request_payload(
self, messages: List[BaseMessage], model_kwargs: Dict
) -> bytes:
chat_messages = [
LlamaContentFormatter._convert_message_to_dict(message)
for message in messages
]
prompt = json.dumps(
{"input_data": {"input_string": chat_messages, "parameters": model_kwargs}}
)
return self.format_request_payload(prompt=prompt, model_kwargs=model_kwargs)
def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
"""Formats the request according the the chosen api"""
return str.encode(prompt)
def format_response_payload(self, output: bytes) -> str:
"""Formats response"""
return json.loads(output)["output"]
class AzureMLChatOnlineEndpoint(SimpleChatModel):
"""Azure ML Chat Online Endpoint models.
Example:
.. code-block:: python
azure_chat = AzureMLChatOnlineEndpoint(
endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/score",
endpoint_api_key="my-api-key",
content_formatter=content_formatter,
)
"""
endpoint_url: str = ""
"""URL of pre-existing Endpoint. Should be passed to constructor or specified as
env var `AZUREML_ENDPOINT_URL`."""
endpoint_api_key: str = ""
"""Authentication Key for Endpoint. Should be passed to constructor or specified as
env var `AZUREML_ENDPOINT_API_KEY`."""
http_client: Any = None #: :meta private:
content_formatter: Any = None
"""The content formatter that provides an input and output
transform function to handle formats between the LLM and
the endpoint"""
model_kwargs: Optional[dict] = None
"""Key word arguments to pass to the model."""
@validator("http_client", always=True, allow_reuse=True)
@classmethod
def validate_client(cls, field_value: Any, values: Dict) -> AzureMLEndpointClient:
"""Validate that api key and python package exists in environment."""
endpoint_key = get_from_dict_or_env(
values, "endpoint_api_key", "AZUREML_ENDPOINT_API_KEY"
)
endpoint_url = get_from_dict_or_env(
values, "endpoint_url", "AZUREML_ENDPOINT_URL"
)
http_client = AzureMLEndpointClient(endpoint_url, endpoint_key)
return http_client
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
_model_kwargs = self.model_kwargs or {}
return {
**{"model_kwargs": _model_kwargs},
}
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "azureml_chat_endpoint"
def _call(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call out to an AzureML Managed Online endpoint.
Args:
messages: The messages in the conversation with the chat model.
stop: Optional list of stop words to use when generating.
Returns:
The string generated by the model.
Example:
.. code-block:: python
response = azureml_model("Tell me a joke.")
"""
_model_kwargs = self.model_kwargs or {}
request_payload = self.content_formatter._format_request_payload(
messages, _model_kwargs
)
response_payload = self.http_client.call(request_payload, **kwargs)
generated_text = self.content_formatter.format_response_payload(
response_payload
)
return generated_text
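
As an illustration of what the Llama formatter sends over the wire, a small sketch (it calls the private `_format_request_payload` helper directly; the `temperature` kwarg is only an example of a model parameter):

```python
from langchain.chat_models.azureml_endpoint import LlamaContentFormatter
from langchain.schema.messages import AIMessage, HumanMessage

formatter = LlamaContentFormatter()
body = formatter._format_request_payload(
    [HumanMessage(content="Hello."), AIMessage(content="Hello!")],
    {"temperature": 0.8},  # illustrative model kwarg
)
# `body` is the UTF-8 encoded JSON document posted to the scoring endpoint:
# {"input_data": {"input_string": [{"role": "user", "content": "Hello."},
#                                  {"role": "assistant", "content": "Hello!"}],
#                 "parameters": {"temperature": 0.8}}}
```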


@@ -1,5 +1,6 @@
import json
import urllib.request
import warnings
from abc import abstractmethod
from typing import Any, Dict, List, Mapping, Optional
@@ -14,16 +15,19 @@ class AzureMLEndpointClient(object):
"""AzureML Managed Endpoint client."""
def __init__(
self, endpoint_url: str, endpoint_api_key: str, deployment_name: str = ""
) -> None:
"""Initialize the class."""
if not endpoint_api_key or not endpoint_url:
raise ValueError(
"""A key/token and REST endpoint should
be provided to invoke the endpoint"""
)
self.endpoint_url = endpoint_url
self.endpoint_api_key = endpoint_api_key
self.deployment_name = deployment_name
def call(self, body: bytes, **kwargs: Any) -> bytes:
"""call."""
# The azureml-model-deployment header will force the request to go to a
@@ -32,11 +36,12 @@ class AzureMLEndpointClient(object):
headers = {
"Content-Type": "application/json",
"Authorization": ("Bearer " + self.endpoint_api_key),
"azureml-model-deployment": self.deployment_name,
}
if self.deployment_name != "":
headers["azureml-model-deployment"] = self.deployment_name
req = urllib.request.Request(self.endpoint_url, body, headers)
response = urllib.request.urlopen(req, timeout=kwargs.get("timeout", 50))
result = response.read()
return result
@@ -75,7 +80,26 @@ class ContentFormatterBase:
"""The MIME type of the input data passed to the endpoint"""
accepts: Optional[str] = "application/json"
"""The MIME type of the response data returned form the endpoint"""
"""The MIME type of the response data returned from the endpoint"""
@staticmethod
def escape_special_characters(prompt: str) -> str:
"""Escapes any special characters in `prompt`"""
escape_map = {
"\\": "\\\\",
'"': '\\"',
"\b": "\\b",
"\f": "\\f",
"\n": "\\n",
"\r": "\\r",
"\t": "\\t",
}
# Replace each occurrence of the specified characters with escaped versions
for escape_sequence, escaped_sequence in escape_map.items():
prompt = prompt.replace(escape_sequence, escaped_sequence)
return prompt
@abstractmethod
def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
@@ -92,44 +116,86 @@ class ContentFormatterBase:
"""
class GPT2ContentFormatter(ContentFormatterBase):
"""Content handler for GPT2"""
def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
prompt = ContentFormatterBase.escape_special_characters(prompt)
request_payload = json.dumps(
{"inputs": {"input_string": [f'"{prompt}"']}, "parameters": model_kwargs}
)
return str.encode(request_payload)
def format_response_payload(self, output: bytes) -> str:
return json.loads(output)[0]["0"]
class OSSContentFormatter(GPT2ContentFormatter):
"""Deprecated: Kept for backwards compatibility
Content handler for LLMs from the OSS catalog."""
content_formatter: Any = None
def __init__(self) -> None:
super().__init__()
warnings.warn(
"""`OSSContentFormatter` will be deprecated in the future.
Please use `GPT2ContentFormatter` instead.
"""
)
class HFContentFormatter(ContentFormatterBase):
"""Content handler for LLMs from the HuggingFace catalog."""
def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
prompt = ContentFormatterBase.escape_special_characters(prompt)
request_payload = json.dumps(
{"inputs": [f'"{prompt}"'], "parameters": model_kwargs}
)
return str.encode(request_payload)
def format_response_payload(self, output: bytes) -> str:
return json.loads(output)[0]["generated_text"]
class DollyContentFormatter(ContentFormatterBase):
"""Content handler for the Dolly-v2-12b model"""
def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
prompt = ContentFormatterBase.escape_special_characters(prompt)
request_payload = json.dumps(
{
"input_data": {"input_string": [f'"{prompt}"']},
"parameters": model_kwargs,
}
)
return str.encode(request_payload)
def format_response_payload(self, output: bytes) -> str:
return json.loads(output)[0]
class LlamaContentFormatter(ContentFormatterBase):
"""Content formatter for LLaMa"""
def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
"""Formats the request according the the chosen api"""
prompt = ContentFormatterBase.escape_special_characters(prompt)
request_payload = json.dumps(
{
"input_data": {
"input_string": [f'"{prompt}"'],
"parameters": model_kwargs,
}
}
)
return str.encode(request_payload)
def format_response_payload(self, output: bytes) -> str:
"""Formats response"""
return json.loads(output)[0]["0"]
class AzureMLOnlineEndpoint(LLM, BaseModel):
@@ -138,10 +204,9 @@ class AzureMLOnlineEndpoint(LLM, BaseModel):
Example:
.. code-block:: python
azure_llm = AzureMLOnlineEndpoint(
endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/score",
endpoint_api_key="my-api-key",
deployment_name="my-deployment-name",
content_formatter=content_formatter,
)
""" # noqa: E501
@@ -155,8 +220,8 @@ class AzureMLOnlineEndpoint(LLM, BaseModel):
env var `AZUREML_ENDPOINT_API_KEY`."""
deployment_name: str = ""
"""Deployment Name for Endpoint. Should be passed to constructor or specified as
env var `AZUREML_DEPLOYMENT_NAME`."""
"""Deployment Name for Endpoint. NOT REQUIRED to call endpoint. Should be passed
to constructor or specified as env var `AZUREML_DEPLOYMENT_NAME`."""
http_client: Any = None #: :meta private:
@@ -179,7 +244,7 @@ class AzureMLOnlineEndpoint(LLM, BaseModel):
values, "endpoint_url", "AZUREML_ENDPOINT_URL"
)
deployment_name = get_from_dict_or_env(
values, "deployment_name", "AZUREML_DEPLOYMENT_NAME"
values, "deployment_name", "AZUREML_DEPLOYMENT_NAME", ""
)
http_client = AzureMLEndpointClient(endpoint_url, endpoint_key, deployment_name)
return http_client
@@ -203,7 +268,7 @@ class AzureMLOnlineEndpoint(LLM, BaseModel):
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call out to an AzureML Managed Online endpoint.
Args:
@@ -217,7 +282,11 @@ class AzureMLOnlineEndpoint(LLM, BaseModel):
"""
_model_kwargs = self.model_kwargs or {}
request_payload = self.content_formatter.format_request_payload(
prompt, _model_kwargs
)
response_payload = self.http_client.call(request_payload, **kwargs)
generated_text = self.content_formatter.format_response_payload(
response_payload
)
return generated_text
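
A quick sanity check of the new `escape_special_characters` helper, showing how quotes and newlines in a prompt are escaped before being interpolated into the JSON request body built by the formatters above:

```python
from langchain.llms.azureml_endpoint import ContentFormatterBase

raw_prompt = 'Say "hello"\nand stop.'
escaped = ContentFormatterBase.escape_special_characters(raw_prompt)
# escaped == 'Say \\"hello\\"\\nand stop.'
# Double quotes, backslashes, and control characters are escaped so the prompt
# can safely be embedded in the f'"{prompt}"' string inside the JSON payload.
```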


@@ -0,0 +1,58 @@
"""Test AzureML Chat Endpoint wrapper."""
from langchain.chat_models.azureml_endpoint import (
AzureMLChatOnlineEndpoint,
LlamaContentFormatter,
)
from langchain.schema import (
AIMessage,
BaseMessage,
ChatGeneration,
HumanMessage,
LLMResult,
)
def test_llama_call() -> None:
"""Test valid call to Open Source Foundation Model."""
chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaContentFormatter())
response = chat(messages=[HumanMessage(content="Foo")])
assert isinstance(response, BaseMessage)
assert isinstance(response.content, str)
def test_timeout_kwargs() -> None:
"""Test that timeout kwarg works."""
chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaContentFormatter())
response = chat(messages=[HumanMessage(content="FOO")], timeout=60)
assert isinstance(response, BaseMessage)
assert isinstance(response.content, str)
def test_message_history() -> None:
"""Test that multiple messages works."""
chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaContentFormatter())
response = chat(
messages=[
HumanMessage(content="Hello."),
AIMessage(content="Hello!"),
HumanMessage(content="How are you doing?"),
]
)
assert isinstance(response, BaseMessage)
assert isinstance(response.content, str)
def test_multiple_messages() -> None:
chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaContentFormatter())
message = HumanMessage(content="Hi!")
response = chat.generate([[message], [message]])
assert isinstance(response, LLMResult)
assert len(response.generations) == 2
for generations in response.generations:
assert len(generations) == 1
for generation in generations:
assert isinstance(generation, ChatGeneration)
assert isinstance(generation.text, str)
assert generation.text == generation.message.content


@@ -18,8 +18,8 @@ from langchain.llms.azureml_endpoint import (
from langchain.llms.loading import load_llm
def test_gpt2_call() -> None:
"""Test valid call to GPT2."""
llm = AzureMLOnlineEndpoint(
endpoint_api_key=os.getenv("OSS_ENDPOINT_API_KEY"),
endpoint_url=os.getenv("OSS_ENDPOINT_URL"),
@@ -43,7 +43,7 @@ def test_hf_call() -> None:
def test_dolly_call() -> None:
"""Test valid call to dolly-v2-12b."""
"""Test valid call to dolly-v2."""
llm = AzureMLOnlineEndpoint(
endpoint_api_key=os.getenv("DOLLY_ENDPOINT_API_KEY"),
endpoint_url=os.getenv("DOLLY_ENDPOINT_URL"),