mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-11 13:55:03 +00:00
fix(llms): improve the ernie chat model (#9289)
- Description: improve the ernie chat model. - fix missing kwargs to payload - new test cases - add some debug level log - improve description - Issue: None - Dependencies: None - Tag maintainer: @baskaryan
This commit is contained in:
parent
1d55141c50
commit
63601551b1
@ -6,7 +6,8 @@
|
|||||||
"source": [
|
"source": [
|
||||||
"# ERNIE-Bot Chat\n",
|
"# ERNIE-Bot Chat\n",
|
||||||
"\n",
|
"\n",
|
||||||
"This notebook covers how to get started with Ernie chat models."
|
"[ERNIE-Bot](https://cloud.baidu.com/doc/WENXINWORKSHOP/s/jlil56u11) is a large language model developed by Baidu, covering a huge amount of Chinese data.\n",
|
||||||
|
"This notebook covers how to get started with ErnieBot chat models."
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -16,7 +17,7 @@
|
|||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"from langchain.chat_models import ErnieBotChat\n",
|
"from langchain.chat_models import ErnieBotChat\n",
|
||||||
"from langchain.schema import AIMessage, HumanMessage, SystemMessage"
|
"from langchain.schema import HumanMessage"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -33,23 +33,26 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
|
|||||||
|
|
||||||
|
|
||||||
class ErnieBotChat(BaseChatModel):
|
class ErnieBotChat(BaseChatModel):
|
||||||
"""ErnieBot Chat large language model.
|
"""`ERNIE-Bot` large language model.
|
||||||
|
|
||||||
ERNIE-Bot is a large language model developed by Baidu,
|
ERNIE-Bot is a large language model developed by Baidu,
|
||||||
covering a huge amount of Chinese data.
|
covering a huge amount of Chinese data.
|
||||||
|
|
||||||
To use, you should have the `ernie_client_id` and `ernie_client_secret` set.
|
To use, you should have the `ernie_client_id` and `ernie_client_secret` set,
|
||||||
|
|
||||||
or set the environment variable `ERNIE_CLIENT_ID` and `ERNIE_CLIENT_SECRET`.
|
or set the environment variable `ERNIE_CLIENT_ID` and `ERNIE_CLIENT_SECRET`.
|
||||||
|
|
||||||
|
Note:
|
||||||
access_token will be automatically generated based on client_id and client_secret,
|
access_token will be automatically generated based on client_id and client_secret,
|
||||||
and will be regenerated after expiration.
|
and will be regenerated after expiration (30 days).
|
||||||
|
|
||||||
|
Default model is `ERNIE-Bot-turbo`,
|
||||||
|
currently supported models are `ERNIE-Bot-turbo`, `ERNIE-Bot`
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
from langchain.chat_models import ErnieBotChat
|
from langchain.chat_models import ErnieBotChat
|
||||||
chat = ErnieBotChat()
|
chat = ErnieBotChat(model_name='ERNIE-Bot')
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -133,10 +136,13 @@ class ErnieBotChat(BaseChatModel):
|
|||||||
"top_p": self.top_p,
|
"top_p": self.top_p,
|
||||||
"temperature": self.temperature,
|
"temperature": self.temperature,
|
||||||
"penalty_score": self.penalty_score,
|
"penalty_score": self.penalty_score,
|
||||||
|
**kwargs,
|
||||||
}
|
}
|
||||||
|
logger.debug(f"Payload for ernie api is {payload}")
|
||||||
resp = self._chat(payload)
|
resp = self._chat(payload)
|
||||||
if resp.get("error_code"):
|
if resp.get("error_code"):
|
||||||
if resp.get("error_code") == 111:
|
if resp.get("error_code") == 111:
|
||||||
|
logger.debug("access_token expired, refresh it")
|
||||||
self._refresh_access_token_with_lock()
|
self._refresh_access_token_with_lock()
|
||||||
resp = self._chat(payload)
|
resp = self._chat(payload)
|
||||||
else:
|
else:
|
||||||
@ -153,4 +159,4 @@ class ErnieBotChat(BaseChatModel):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def _llm_type(self) -> str:
|
def _llm_type(self) -> str:
|
||||||
return "ernie-chat"
|
return "ernie-bot-chat"
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
from langchain.chat_models.ernie import ErnieBotChat
|
from langchain.chat_models.ernie import ErnieBotChat
|
||||||
from langchain.schema.messages import AIMessage, HumanMessage
|
from langchain.schema.messages import AIMessage, HumanMessage
|
||||||
|
|
||||||
@ -24,3 +26,31 @@ def test_chat_ernie_bot_with_temperature() -> None:
|
|||||||
response = chat([message])
|
response = chat([message])
|
||||||
assert isinstance(response, AIMessage)
|
assert isinstance(response, AIMessage)
|
||||||
assert isinstance(response.content, str)
|
assert isinstance(response.content, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_ernie_bot_with_kwargs() -> None:
|
||||||
|
chat = ErnieBotChat()
|
||||||
|
message = HumanMessage(content="Hello")
|
||||||
|
response = chat([message], temperature=0.88, top_p=0.7)
|
||||||
|
assert isinstance(response, AIMessage)
|
||||||
|
assert isinstance(response.content, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_extra_kwargs() -> None:
|
||||||
|
chat = ErnieBotChat(temperature=0.88, top_p=0.7)
|
||||||
|
assert chat.temperature == 0.88
|
||||||
|
assert chat.top_p == 0.7
|
||||||
|
|
||||||
|
|
||||||
|
def test_wrong_temperature_1() -> None:
|
||||||
|
chat = ErnieBotChat()
|
||||||
|
message = HumanMessage(content="Hello")
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
chat([message], temperature=1.2)
|
||||||
|
|
||||||
|
|
||||||
|
def test_wrong_temperature_2() -> None:
|
||||||
|
chat = ErnieBotChat()
|
||||||
|
message = HumanMessage(content="Hello")
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
chat([message], temperature=0)
|
||||||
|
Loading…
Reference in New Issue
Block a user