mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-30 02:13:23 +00:00
feat(llms): support ernie chat (#9114)
Description: support ernie (文心一言) chat model Related issue: #7990 Dependencies: None Tag maintainer: @baskaryan
This commit is contained in:
parent
08a8363fc6
commit
664ff28cba
87
docs/extras/integrations/chat/ernie.ipynb
Normal file
87
docs/extras/integrations/chat/ernie.ipynb
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# ERNIE-Bot Chat\n",
|
||||||
|
"\n",
|
||||||
|
"This notebook covers how to get started with Ernie chat models."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 7,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain.chat_models import ErnieBotChat\n",
|
||||||
|
"from langchain.schema import AIMessage, HumanMessage, SystemMessage"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 13,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"chat = ErnieBotChat(ernie_client_id='YOUR_CLIENT_ID', ernie_client_secret='YOUR_CLIENT_SECRET')"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"or you can set `client_id` and `client_secret` in your environment variables\n",
|
||||||
|
"```bash\n",
|
||||||
|
"export ERNIE_CLIENT_ID=YOUR_CLIENT_ID\n",
|
||||||
|
"export ERNIE_CLIENT_SECRET=YOUR_CLIENT_SECRET\n",
|
||||||
|
"```"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 15,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"AIMessage(content='Hello, I am an artificial intelligence language model. My purpose is to help users answer questions or provide information. What can I do for you?', additional_kwargs={}, example=False)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 15,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"chat([\n",
|
||||||
|
" HumanMessage(content='hello there, who are you?')\n",
|
||||||
|
"])"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.11.4"
|
||||||
|
},
|
||||||
|
"orig_nbformat": 4
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
|
@ -20,6 +20,7 @@ an interface where "chat messages" are the inputs and outputs.
|
|||||||
from langchain.chat_models.anthropic import ChatAnthropic
|
from langchain.chat_models.anthropic import ChatAnthropic
|
||||||
from langchain.chat_models.anyscale import ChatAnyscale
|
from langchain.chat_models.anyscale import ChatAnyscale
|
||||||
from langchain.chat_models.azure_openai import AzureChatOpenAI
|
from langchain.chat_models.azure_openai import AzureChatOpenAI
|
||||||
|
from langchain.chat_models.ernie import ErnieBotChat
|
||||||
from langchain.chat_models.fake import FakeListChatModel
|
from langchain.chat_models.fake import FakeListChatModel
|
||||||
from langchain.chat_models.google_palm import ChatGooglePalm
|
from langchain.chat_models.google_palm import ChatGooglePalm
|
||||||
from langchain.chat_models.human import HumanInputChatModel
|
from langchain.chat_models.human import HumanInputChatModel
|
||||||
@ -43,4 +44,5 @@ __all__ = [
|
|||||||
"HumanInputChatModel",
|
"HumanInputChatModel",
|
||||||
"ChatAnyscale",
|
"ChatAnyscale",
|
||||||
"ChatLiteLLM",
|
"ChatLiteLLM",
|
||||||
|
"ErnieBotChat",
|
||||||
]
|
]
|
||||||
|
156
libs/langchain/langchain/chat_models/ernie.py
Normal file
156
libs/langchain/langchain/chat_models/ernie.py
Normal file
@ -0,0 +1,156 @@
|
|||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
from typing import Any, Dict, List, Mapping, Optional
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from pydantic import root_validator
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||||
|
from langchain.chat_models.base import BaseChatModel
|
||||||
|
from langchain.schema import (
|
||||||
|
AIMessage,
|
||||||
|
BaseMessage,
|
||||||
|
ChatGeneration,
|
||||||
|
ChatMessage,
|
||||||
|
ChatResult,
|
||||||
|
HumanMessage,
|
||||||
|
)
|
||||||
|
from langchain.utils import get_from_dict_or_env
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||||
|
if isinstance(message, ChatMessage):
|
||||||
|
message_dict = {"role": message.role, "content": message.content}
|
||||||
|
elif isinstance(message, HumanMessage):
|
||||||
|
message_dict = {"role": "user", "content": message.content}
|
||||||
|
elif isinstance(message, AIMessage):
|
||||||
|
message_dict = {"role": "assistant", "content": message.content}
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Got unknown type {message}")
|
||||||
|
return message_dict
|
||||||
|
|
||||||
|
|
||||||
|
class ErnieBotChat(BaseChatModel):
|
||||||
|
"""ErnieBot Chat large language model.
|
||||||
|
|
||||||
|
ERNIE-Bot is a large language model developed by Baidu,
|
||||||
|
covering a huge amount of Chinese data.
|
||||||
|
|
||||||
|
To use, you should have the `ernie_client_id` and `ernie_client_secret` set.
|
||||||
|
|
||||||
|
or set the environment variable `ERNIE_CLIENT_ID` and `ERNIE_CLIENT_SECRET`.
|
||||||
|
|
||||||
|
access_token will be automatically generated based on client_id and client_secret,
|
||||||
|
and will be regenerated after expiration.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from langchain.chat_models import ErnieBotChat
|
||||||
|
chat = ErnieBotChat()
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
ernie_client_id: Optional[str] = None
|
||||||
|
ernie_client_secret: Optional[str] = None
|
||||||
|
access_token: Optional[str] = None
|
||||||
|
|
||||||
|
model_name: str = "ERNIE-Bot-turbo"
|
||||||
|
|
||||||
|
streaming: Optional[bool] = False
|
||||||
|
top_p: Optional[float] = 0.8
|
||||||
|
temperature: Optional[float] = 0.95
|
||||||
|
penalty_score: Optional[float] = 1
|
||||||
|
|
||||||
|
_lock = threading.Lock()
|
||||||
|
|
||||||
|
@root_validator()
|
||||||
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
|
values["ernie_client_id"] = get_from_dict_or_env(
|
||||||
|
values,
|
||||||
|
"ernie_client_id",
|
||||||
|
"ERNIE_CLIENT_ID",
|
||||||
|
)
|
||||||
|
values["ernie_client_secret"] = get_from_dict_or_env(
|
||||||
|
values,
|
||||||
|
"ernie_client_secret",
|
||||||
|
"ERNIE_CLIENT_SECRET",
|
||||||
|
)
|
||||||
|
return values
|
||||||
|
|
||||||
|
def _chat(self, payload: object) -> dict:
|
||||||
|
base_url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat"
|
||||||
|
if self.model_name == "ERNIE-Bot-turbo":
|
||||||
|
url = f"{base_url}/eb-instant"
|
||||||
|
elif self.model_name == "ERNIE-Bot":
|
||||||
|
url = f"{base_url}/completions"
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Got unknown model_name {self.model_name}")
|
||||||
|
resp = requests.post(
|
||||||
|
url,
|
||||||
|
headers={
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
params={"access_token": self.access_token},
|
||||||
|
json=payload,
|
||||||
|
)
|
||||||
|
return resp.json()
|
||||||
|
|
||||||
|
def _refresh_access_token_with_lock(self) -> None:
|
||||||
|
with self._lock:
|
||||||
|
logger.debug("Refreshing access token")
|
||||||
|
base_url: str = "https://aip.baidubce.com/oauth/2.0/token"
|
||||||
|
resp = requests.post(
|
||||||
|
base_url,
|
||||||
|
headers={
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Accept": "application/json",
|
||||||
|
},
|
||||||
|
params={
|
||||||
|
"grant_type": "client_credentials",
|
||||||
|
"client_id": self.ernie_client_id,
|
||||||
|
"client_secret": self.ernie_client_secret,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.access_token = str(resp.json().get("access_token"))
|
||||||
|
|
||||||
|
def _generate(
|
||||||
|
self,
|
||||||
|
messages: List[BaseMessage],
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> ChatResult:
|
||||||
|
if self.streaming:
|
||||||
|
raise ValueError("`streaming` option currently unsupported.")
|
||||||
|
|
||||||
|
if not self.access_token:
|
||||||
|
self._refresh_access_token_with_lock()
|
||||||
|
payload = {
|
||||||
|
"messages": [_convert_message_to_dict(m) for m in messages],
|
||||||
|
"top_p": self.top_p,
|
||||||
|
"temperature": self.temperature,
|
||||||
|
"penalty_score": self.penalty_score,
|
||||||
|
}
|
||||||
|
resp = self._chat(payload)
|
||||||
|
if resp.get("error_code"):
|
||||||
|
if resp.get("error_code") == 111:
|
||||||
|
self._refresh_access_token_with_lock()
|
||||||
|
resp = self._chat(payload)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Error from ErnieChat api response: {resp}")
|
||||||
|
return self._create_chat_result(resp)
|
||||||
|
|
||||||
|
def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
|
||||||
|
generations = [
|
||||||
|
ChatGeneration(message=AIMessage(content=response.get("result")))
|
||||||
|
]
|
||||||
|
token_usage = response.get("usage", {})
|
||||||
|
llm_output = {"token_usage": token_usage, "model_name": self.model_name}
|
||||||
|
return ChatResult(generations=generations, llm_output=llm_output)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _llm_type(self) -> str:
|
||||||
|
return "ernie-chat"
|
@ -0,0 +1,26 @@
|
|||||||
|
from langchain.chat_models.ernie import ErnieBotChat
|
||||||
|
from langchain.schema.messages import AIMessage, HumanMessage
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_ernie_bot() -> None:
|
||||||
|
chat = ErnieBotChat()
|
||||||
|
message = HumanMessage(content="Hello")
|
||||||
|
response = chat([message])
|
||||||
|
assert isinstance(response, AIMessage)
|
||||||
|
assert isinstance(response.content, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_ernie_bot_with_model_name() -> None:
|
||||||
|
chat = ErnieBotChat(model_name="ERNIE-Bot")
|
||||||
|
message = HumanMessage(content="Hello")
|
||||||
|
response = chat([message])
|
||||||
|
assert isinstance(response, AIMessage)
|
||||||
|
assert isinstance(response.content, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_ernie_bot_with_temperature() -> None:
|
||||||
|
chat = ErnieBotChat(model_name="ERNIE-Bot", temperature=1.0)
|
||||||
|
message = HumanMessage(content="Hello")
|
||||||
|
response = chat([message])
|
||||||
|
assert isinstance(response, AIMessage)
|
||||||
|
assert isinstance(response.content, str)
|
16
libs/langchain/tests/unit_tests/chat_models/test_ernie.py
Normal file
16
libs/langchain/tests/unit_tests/chat_models/test_ernie.py
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
from langchain.chat_models.ernie import _convert_message_to_dict
|
||||||
|
from langchain.schema.messages import AIMessage, HumanMessage
|
||||||
|
|
||||||
|
|
||||||
|
def test__convert_dict_to_message_human() -> None:
|
||||||
|
message = HumanMessage(content="foo")
|
||||||
|
result = _convert_message_to_dict(message)
|
||||||
|
expected_output = {"role": "user", "content": "foo"}
|
||||||
|
assert result == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
def test__convert_dict_to_message_ai() -> None:
|
||||||
|
message = AIMessage(content="foo")
|
||||||
|
result = _convert_message_to_dict(message)
|
||||||
|
expected_output = {"role": "assistant", "content": "foo"}
|
||||||
|
assert result == expected_output
|
Loading…
Reference in New Issue
Block a user