Mirror of https://github.com/hwchase17/langchain.git (synced 2025-09-01 11:02:37 +00:00)
Add LLaMa Formatter and AzureML Chat Endpoint (#8382)

## Description
Microsoft and Meta recently [announced their collaboration](https://blogs.microsoft.com/blog/2023/07/18/microsoft-and-meta-expand-their-ai-partnership-with-llama-2-on-azure-and-windows/) on LLaMa2. This PR extends the current LLM wrapper and introduces a new Chat Model wrapper for AzureML to support LLaMa2.

## Dependencies
No dependencies added :)

## Twitter Handles
[@matthew_d13](https://twitter.com/matthew_d13), [@prakhar_in](https://twitter.com/prakhar_in)

Maintainers: @hwchase17, @baskaryan
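For reviewers, a minimal usage sketch of the new chat wrapper (a sketch only: the endpoint URL and API key are placeholders, and it assumes a LLaMa2 chat model already deployed behind an AzureML Managed Online Endpoint):

```python
# Minimal sketch, assuming a LLaMa2 chat deployment behind an AzureML
# Managed Online Endpoint; the URL and key below are placeholders.
from langchain.chat_models.azureml_endpoint import (
    AzureMLChatOnlineEndpoint,
    LlamaContentFormatter,
)
from langchain.schema.messages import HumanMessage, SystemMessage

chat = AzureMLChatOnlineEndpoint(
    endpoint_url="https://<your-endpoint>.<your-region>.inference.ml.azure.com/score",
    endpoint_api_key="<your-api-key>",
    content_formatter=LlamaContentFormatter(),
)

# The wrapper is called with a list of messages and returns an AIMessage
# whose content is the generated string.
response = chat(
    messages=[
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="Summarize this PR in one sentence."),
    ]
)
print(response.content)
```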
libs/langchain/langchain/chat_models/azureml_endpoint.py (new file, 151 lines)
@@ -0,0 +1,151 @@
import json
from typing import Any, Dict, List, Optional

from pydantic import validator

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import SimpleChatModel
from langchain.llms.azureml_endpoint import AzureMLEndpointClient, ContentFormatterBase
from langchain.schema.messages import (
    AIMessage,
    BaseMessage,
    ChatMessage,
    HumanMessage,
    SystemMessage,
)
from langchain.utils import get_from_dict_or_env


class LlamaContentFormatter(ContentFormatterBase):
    """Content formatter for LLaMa"""

    SUPPORTED_ROLES = ["user", "assistant", "system"]

    @staticmethod
    def _convert_message_to_dict(message: BaseMessage) -> Dict:
        """Converts message to a dict according to role"""
        if isinstance(message, HumanMessage):
            return {"role": "user", "content": message.content}
        elif isinstance(message, AIMessage):
            return {"role": "assistant", "content": message.content}
        elif isinstance(message, SystemMessage):
            return {"role": "system", "content": message.content}
        elif (
            isinstance(message, ChatMessage)
            and message.role in LlamaContentFormatter.SUPPORTED_ROLES
        ):
            return {"role": message.role, "content": message.content}
        else:
            supported = ",".join(LlamaContentFormatter.SUPPORTED_ROLES)
            raise ValueError(
                f"""Received unsupported role.
                Supported roles for the LLaMa Foundation Model: {supported}"""
            )

    def _format_request_payload(
        self, messages: List[BaseMessage], model_kwargs: Dict
    ) -> bytes:
        chat_messages = [
            LlamaContentFormatter._convert_message_to_dict(message)
            for message in messages
        ]
        prompt = json.dumps(
            {"input_data": {"input_string": chat_messages, "parameters": model_kwargs}}
        )
        return self.format_request_payload(prompt=prompt, model_kwargs=model_kwargs)

    def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
        """Formats the request according to the chosen API"""
        return str.encode(prompt)

    def format_response_payload(self, output: bytes) -> str:
        """Formats response"""
        return json.loads(output)["output"]


class AzureMLChatOnlineEndpoint(SimpleChatModel):
    """Azure ML Chat Online Endpoint models.

    Example:
        .. code-block:: python

            azure_chat = AzureMLChatOnlineEndpoint(
                endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/score",
                endpoint_api_key="my-api-key",
                content_formatter=content_formatter,
            )
    """

    endpoint_url: str = ""
    """URL of pre-existing Endpoint. Should be passed to constructor or specified as
    env var `AZUREML_ENDPOINT_URL`."""

    endpoint_api_key: str = ""
    """Authentication Key for Endpoint. Should be passed to constructor or specified as
    env var `AZUREML_ENDPOINT_API_KEY`."""

    http_client: Any = None  #: :meta private:

    content_formatter: Any = None
    """The content formatter that provides an input and output
    transform function to handle formats between the LLM and
    the endpoint"""

    model_kwargs: Optional[dict] = None
    """Keyword arguments to pass to the model."""

    @validator("http_client", always=True, allow_reuse=True)
    @classmethod
    def validate_client(cls, field_value: Any, values: Dict) -> AzureMLEndpointClient:
        """Validate that api key and python package exist in environment."""
        endpoint_key = get_from_dict_or_env(
            values, "endpoint_api_key", "AZUREML_ENDPOINT_API_KEY"
        )
        endpoint_url = get_from_dict_or_env(
            values, "endpoint_url", "AZUREML_ENDPOINT_URL"
        )
        http_client = AzureMLEndpointClient(endpoint_url, endpoint_key)
        return http_client

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        _model_kwargs = self.model_kwargs or {}
        return {
            **{"model_kwargs": _model_kwargs},
        }

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "azureml_chat_endpoint"

    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to an AzureML Managed Online endpoint.

        Args:
            messages: The messages in the conversation with the chat model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                response = azureml_model("Tell me a joke.")
        """
        _model_kwargs = self.model_kwargs or {}

        request_payload = self.content_formatter._format_request_payload(
            messages, _model_kwargs
        )
        response_payload = self.http_client.call(request_payload, **kwargs)
        generated_text = self.content_formatter.format_response_payload(
            response_payload
        )
        return generated_text
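To make the wire format concrete, here is an illustrative sketch of the JSON body that `_format_request_payload` above assembles for a short conversation (`temperature` and `max_new_tokens` are hypothetical model kwargs):

```python
# Illustrative only: shows the payload LlamaContentFormatter builds.
import json

from langchain.chat_models.azureml_endpoint import LlamaContentFormatter
from langchain.schema.messages import AIMessage, HumanMessage

formatter = LlamaContentFormatter()
payload = formatter._format_request_payload(
    messages=[
        HumanMessage(content="Hello."),
        AIMessage(content="Hello!"),
        HumanMessage(content="Tell me a joke."),
    ],
    model_kwargs={"temperature": 0.8, "max_new_tokens": 128},  # hypothetical kwargs
)
print(json.loads(payload))
# {'input_data': {'input_string': [
#     {'role': 'user', 'content': 'Hello.'},
#     {'role': 'assistant', 'content': 'Hello!'},
#     {'role': 'user', 'content': 'Tell me a joke.'}],
#   'parameters': {'temperature': 0.8, 'max_new_tokens': 128}}}
```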

libs/langchain/langchain/llms/azureml_endpoint.py:

@@ -1,5 +1,6 @@
 import json
 import urllib.request
+import warnings
 from abc import abstractmethod
 from typing import Any, Dict, List, Mapping, Optional
@@ -14,16 +15,19 @@ class AzureMLEndpointClient(object):
     """AzureML Managed Endpoint client."""
 
     def __init__(
-        self, endpoint_url: str, endpoint_api_key: str, deployment_name: str
+        self, endpoint_url: str, endpoint_api_key: str, deployment_name: str = ""
     ) -> None:
         """Initialize the class."""
-        if not endpoint_api_key:
-            raise ValueError("A key should be provided to invoke the endpoint")
+        if not endpoint_api_key or not endpoint_url:
+            raise ValueError(
+                """A key/token and REST endpoint should
+                be provided to invoke the endpoint"""
+            )
         self.endpoint_url = endpoint_url
         self.endpoint_api_key = endpoint_api_key
         self.deployment_name = deployment_name
 
-    def call(self, body: bytes) -> bytes:
+    def call(self, body: bytes, **kwargs: Any) -> bytes:
         """call."""
 
         # The azureml-model-deployment header will force the request to go to a
@@ -32,11 +36,12 @@ class AzureMLEndpointClient(object):
         headers = {
             "Content-Type": "application/json",
             "Authorization": ("Bearer " + self.endpoint_api_key),
-            "azureml-model-deployment": self.deployment_name,
         }
+        if self.deployment_name != "":
+            headers["azureml-model-deployment"] = self.deployment_name
 
         req = urllib.request.Request(self.endpoint_url, body, headers)
-        response = urllib.request.urlopen(req, timeout=50)
+        response = urllib.request.urlopen(req, timeout=kwargs.get("timeout", 50))
         result = response.read()
         return result
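As a hedged illustration of the routing change above: with `deployment_name` now defaulting to empty, the `azureml-model-deployment` header is only sent when a name is provided, so AzureML otherwise routes by the endpoint's own traffic configuration (URL, key, and deployment name are placeholders):

```python
# Sketch of the new optional-deployment behavior; all values are placeholders.
from langchain.llms.azureml_endpoint import AzureMLEndpointClient

# No deployment_name: the azureml-model-deployment header is omitted and
# the endpoint's configured traffic split decides which deployment serves.
client = AzureMLEndpointClient(
    endpoint_url="https://<your-endpoint>.<your-region>.inference.ml.azure.com/score",
    endpoint_api_key="<your-api-key>",
)

# With a deployment_name: requests are pinned to that specific deployment.
pinned = AzureMLEndpointClient(
    endpoint_url="https://<your-endpoint>.<your-region>.inference.ml.azure.com/score",
    endpoint_api_key="<your-api-key>",
    deployment_name="<my-deployment>",
)
```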
@@ -75,7 +80,26 @@ class ContentFormatterBase:
     """The MIME type of the input data passed to the endpoint"""
 
     accepts: Optional[str] = "application/json"
-    """The MIME type of the response data returned form the endpoint"""
+    """The MIME type of the response data returned from the endpoint"""
+
+    @staticmethod
+    def escape_special_characters(prompt: str) -> str:
+        """Escapes any special characters in `prompt`"""
+        escape_map = {
+            "\\": "\\\\",
+            '"': '\\"',
+            "\b": "\\b",
+            "\f": "\\f",
+            "\n": "\\n",
+            "\r": "\\r",
+            "\t": "\\t",
+        }
+
+        # Replace each occurrence of the specified characters with escaped versions
+        for escape_sequence, escaped_sequence in escape_map.items():
+            prompt = prompt.replace(escape_sequence, escaped_sequence)
+
+        return prompt
 
     @abstractmethod
     def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
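A quick sketch of what `escape_special_characters` does before a prompt is embedded, quoted, inside a hand-built JSON template (the input string is illustrative):

```python
# Illustrative: backslash-escapes characters that would otherwise break
# the JSON string the formatters below assemble around the prompt.
from langchain.llms.azureml_endpoint import ContentFormatterBase

raw = 'Line one\nShe said "hi"\tdone'
escaped = ContentFormatterBase.escape_special_characters(raw)
print(escaped)
# Line one\nShe said \"hi\"\tdone
```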
@@ -92,44 +116,86 @@ class ContentFormatterBase:
         """
 
 
-class OSSContentFormatter(ContentFormatterBase):
-    """Content handler for LLMs from the OSS catalog."""
+class GPT2ContentFormatter(ContentFormatterBase):
+    """Content handler for GPT2"""
 
     def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
-        input_str = json.dumps(
-            {"inputs": {"input_string": [prompt]}, "parameters": model_kwargs}
+        prompt = ContentFormatterBase.escape_special_characters(prompt)
+        request_payload = json.dumps(
+            {"inputs": {"input_string": [f'"{prompt}"']}, "parameters": model_kwargs}
         )
-        return str.encode(input_str)
+        return str.encode(request_payload)
 
     def format_response_payload(self, output: bytes) -> str:
-        response_json = json.loads(output)
-        return response_json[0]["0"]
+        return json.loads(output)[0]["0"]
+
+
+class OSSContentFormatter(GPT2ContentFormatter):
+    """Deprecated: Kept for backwards compatibility
+
+    Content handler for LLMs from the OSS catalog."""
+
+    content_formatter: Any = None
+
+    def __init__(self) -> None:
+        super().__init__()
+        warnings.warn(
+            """`OSSContentFormatter` will be deprecated in the future.
+            Please use `GPT2ContentFormatter` instead.
+            """
+        )
 
 
 class HFContentFormatter(ContentFormatterBase):
     """Content handler for LLMs from the HuggingFace catalog."""
 
     def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
-        input_str = json.dumps({"inputs": [prompt], "parameters": model_kwargs})
-        return str.encode(input_str)
+        prompt = ContentFormatterBase.escape_special_characters(prompt)
+        request_payload = json.dumps(
+            {"inputs": [f'"{prompt}"'], "parameters": model_kwargs}
+        )
+        return str.encode(request_payload)
 
     def format_response_payload(self, output: bytes) -> str:
-        response_json = json.loads(output)
-        return response_json[0][0]["generated_text"]
+        return json.loads(output)[0]["generated_text"]
 
 
 class DollyContentFormatter(ContentFormatterBase):
     """Content handler for the Dolly-v2-12b model"""
 
     def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
-        input_str = json.dumps(
-            {"input_data": {"input_string": [prompt]}, "parameters": model_kwargs}
+        prompt = ContentFormatterBase.escape_special_characters(prompt)
+        request_payload = json.dumps(
+            {
+                "input_data": {"input_string": [f'"{prompt}"']},
+                "parameters": model_kwargs,
+            }
         )
-        return str.encode(input_str)
+        return str.encode(request_payload)
 
     def format_response_payload(self, output: bytes) -> str:
-        response_json = json.loads(output)
-        return response_json[0]
+        return json.loads(output)[0]
+
+
+class LlamaContentFormatter(ContentFormatterBase):
+    """Content formatter for LLaMa"""
+
+    def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
+        """Formats the request according to the chosen API"""
+        prompt = ContentFormatterBase.escape_special_characters(prompt)
+        request_payload = json.dumps(
+            {
+                "input_data": {
+                    "input_string": [f'"{prompt}"'],
+                    "parameters": model_kwargs,
+                }
+            }
+        )
+        return str.encode(request_payload)
+
+    def format_response_payload(self, output: bytes) -> str:
+        """Formats response"""
+        return json.loads(output)[0]["0"]
 
 
 class AzureMLOnlineEndpoint(LLM, BaseModel):
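A small sketch of the backwards-compatibility path above: constructing the deprecated `OSSContentFormatter` emits a warning but inherits `GPT2ContentFormatter`'s formatting unchanged (`max_tokens` is an illustrative model kwarg):

```python
import warnings

from langchain.llms.azureml_endpoint import OSSContentFormatter

# The deprecated shim warns on construction, pointing at GPT2ContentFormatter.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    formatter = OSSContentFormatter()
assert any("GPT2ContentFormatter" in str(w.message) for w in caught)

# Formatting behavior is inherited from GPT2ContentFormatter unchanged.
payload = formatter.format_request_payload("Hello", {"max_tokens": 16})
print(payload)
# b'{"inputs": {"input_string": ["\\"Hello\\""]}, "parameters": {"max_tokens": 16}}'
```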
@@ -138,10 +204,9 @@ class AzureMLOnlineEndpoint(LLM, BaseModel):
     Example:
         .. code-block:: python
 
-            azure_llm = AzureMLModel(
+            azure_llm = AzureMLOnlineEndpoint(
                 endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/score",
                 endpoint_api_key="my-api-key",
-                deployment_name="my-deployment-name",
                 content_formatter=content_formatter,
             )
     """  # noqa: E501
@@ -155,8 +220,8 @@ class AzureMLOnlineEndpoint(LLM, BaseModel):
     env var `AZUREML_ENDPOINT_API_KEY`."""
 
     deployment_name: str = ""
-    """Deployment Name for Endpoint. Should be passed to constructor or specified as
-    env var `AZUREML_DEPLOYMENT_NAME`."""
+    """Deployment Name for Endpoint. NOT REQUIRED to call endpoint. Should be passed
+    to constructor or specified as env var `AZUREML_DEPLOYMENT_NAME`."""
 
     http_client: Any = None  #: :meta private:
@@ -179,7 +244,7 @@ class AzureMLOnlineEndpoint(LLM, BaseModel):
             values, "endpoint_url", "AZUREML_ENDPOINT_URL"
         )
         deployment_name = get_from_dict_or_env(
-            values, "deployment_name", "AZUREML_DEPLOYMENT_NAME"
+            values, "deployment_name", "AZUREML_DEPLOYMENT_NAME", ""
         )
         http_client = AzureMLEndpointClient(endpoint_url, endpoint_key, deployment_name)
         return http_client
@@ -203,7 +268,7 @@ class AzureMLOnlineEndpoint(LLM, BaseModel):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
-        **kwargs: Any
+        **kwargs: Any,
     ) -> str:
         """Call out to an AzureML Managed Online endpoint.
 
         Args:
@@ -217,7 +282,11 @@ class AzureMLOnlineEndpoint(LLM, BaseModel):
         """
         _model_kwargs = self.model_kwargs or {}
 
-        body = self.content_formatter.format_request_payload(prompt, _model_kwargs)
-        endpoint_response = self.http_client.call(body)
-        response = self.content_formatter.format_response_payload(endpoint_response)
-        return response
+        request_payload = self.content_formatter.format_request_payload(
+            prompt, _model_kwargs
+        )
+        response_payload = self.http_client.call(request_payload, **kwargs)
+        generated_text = self.content_formatter.format_response_payload(
+            response_payload
+        )
+        return generated_text
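Finally, a hedged sketch of the new kwargs pass-through on the LLM side: a `timeout` supplied at call time now reaches `urllib.request.urlopen` instead of the hard-coded 50 seconds, mirroring what the chat tests below exercise (endpoint values are placeholders, and this assumes the base LLM forwards extra kwargs to `_call`):

```python
# Sketch only: URL and key are placeholders for a real GPT2 deployment.
from langchain.llms.azureml_endpoint import (
    AzureMLOnlineEndpoint,
    GPT2ContentFormatter,
)

llm = AzureMLOnlineEndpoint(
    endpoint_url="https://<your-endpoint>.<your-region>.inference.ml.azure.com/score",
    endpoint_api_key="<your-api-key>",
    content_formatter=GPT2ContentFormatter(),
)

# timeout (seconds) flows _call -> http_client.call -> urlopen; defaults to 50.
completion = llm("Tell me a joke.", timeout=60)
```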
New integration test file for the chat endpoint (58 lines):

@@ -0,0 +1,58 @@
"""Test AzureML Chat Endpoint wrapper."""

from langchain.chat_models.azureml_endpoint import (
    AzureMLChatOnlineEndpoint,
    LlamaContentFormatter,
)
from langchain.schema import (
    AIMessage,
    BaseMessage,
    ChatGeneration,
    HumanMessage,
    LLMResult,
)


def test_llama_call() -> None:
    """Test valid call to Open Source Foundation Model."""
    chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaContentFormatter())
    response = chat(messages=[HumanMessage(content="Foo")])
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)


def test_timeout_kwargs() -> None:
    """Test that timeout kwarg works."""
    chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaContentFormatter())
    response = chat(messages=[HumanMessage(content="FOO")], timeout=60)
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)


def test_message_history() -> None:
    """Test that multiple messages work."""
    chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaContentFormatter())
    response = chat(
        messages=[
            HumanMessage(content="Hello."),
            AIMessage(content="Hello!"),
            HumanMessage(content="How are you doing?"),
        ]
    )
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)


def test_multiple_messages() -> None:
    chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaContentFormatter())
    message = HumanMessage(content="Hi!")
    response = chat.generate([[message], [message]])

    assert isinstance(response, LLMResult)
    assert len(response.generations) == 2
    for generations in response.generations:
        assert len(generations) == 1
        for generation in generations:
            assert isinstance(generation, ChatGeneration)
            assert isinstance(generation.text, str)
            assert generation.text == generation.message.content
Updated integration tests for the LLM endpoint:

@@ -18,8 +18,8 @@ from langchain.llms.azureml_endpoint import (
 from langchain.llms.loading import load_llm
 
 
-def test_oss_call() -> None:
-    """Test valid call to Open Source Foundation Model."""
+def test_gpt2_call() -> None:
+    """Test valid call to GPT2."""
     llm = AzureMLOnlineEndpoint(
         endpoint_api_key=os.getenv("OSS_ENDPOINT_API_KEY"),
         endpoint_url=os.getenv("OSS_ENDPOINT_URL"),

@@ -43,7 +43,7 @@ def test_hf_call() -> None:
 
 
 def test_dolly_call() -> None:
-    """Test valid call to dolly-v2-12b."""
+    """Test valid call to dolly-v2."""
     llm = AzureMLOnlineEndpoint(
         endpoint_api_key=os.getenv("DOLLY_ENDPOINT_API_KEY"),
         endpoint_url=os.getenv("DOLLY_ENDPOINT_URL"),