mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-02 01:23:07 +00:00
fix(llms): improve the ernie chat model (#9289)
- Description: improve the ernie chat model. - fix missing kwargs to payload - new test cases - add some debug level log - improve description - Issue: None - Dependencies: None - Tag maintainer: @baskaryan
This commit is contained in:
parent
1d55141c50
commit
63601551b1
@ -6,7 +6,8 @@
|
||||
"source": [
|
||||
"# ERNIE-Bot Chat\n",
|
||||
"\n",
|
||||
"This notebook covers how to get started with Ernie chat models."
|
||||
"[ERNIE-Bot](https://cloud.baidu.com/doc/WENXINWORKSHOP/s/jlil56u11) is a large language model developed by Baidu, covering a huge amount of Chinese data.\n",
|
||||
"This notebook covers how to get started with ErnieBot chat models."
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -16,7 +17,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.chat_models import ErnieBotChat\n",
|
||||
"from langchain.schema import AIMessage, HumanMessage, SystemMessage"
|
||||
"from langchain.schema import HumanMessage"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -33,23 +33,26 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
|
||||
|
||||
class ErnieBotChat(BaseChatModel):
|
||||
"""ErnieBot Chat large language model.
|
||||
"""`ERNIE-Bot` large language model.
|
||||
|
||||
ERNIE-Bot is a large language model developed by Baidu,
|
||||
covering a huge amount of Chinese data.
|
||||
|
||||
To use, you should have the `ernie_client_id` and `ernie_client_secret` set.
|
||||
|
||||
To use, you should have the `ernie_client_id` and `ernie_client_secret` set,
|
||||
or set the environment variable `ERNIE_CLIENT_ID` and `ERNIE_CLIENT_SECRET`.
|
||||
|
||||
Note:
|
||||
access_token will be automatically generated based on client_id and client_secret,
|
||||
and will be regenerated after expiration.
|
||||
and will be regenerated after expiration (30 days).
|
||||
|
||||
Default model is `ERNIE-Bot-turbo`,
|
||||
currently supported models are `ERNIE-Bot-turbo`, `ERNIE-Bot`
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.chat_models import ErnieBotChat
|
||||
chat = ErnieBotChat()
|
||||
chat = ErnieBotChat(model_name='ERNIE-Bot')
|
||||
|
||||
"""
|
||||
|
||||
@ -133,10 +136,13 @@ class ErnieBotChat(BaseChatModel):
|
||||
"top_p": self.top_p,
|
||||
"temperature": self.temperature,
|
||||
"penalty_score": self.penalty_score,
|
||||
**kwargs,
|
||||
}
|
||||
logger.debug(f"Payload for ernie api is {payload}")
|
||||
resp = self._chat(payload)
|
||||
if resp.get("error_code"):
|
||||
if resp.get("error_code") == 111:
|
||||
logger.debug("access_token expired, refresh it")
|
||||
self._refresh_access_token_with_lock()
|
||||
resp = self._chat(payload)
|
||||
else:
|
||||
@ -153,4 +159,4 @@ class ErnieBotChat(BaseChatModel):
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "ernie-chat"
|
||||
return "ernie-bot-chat"
|
||||
|
@ -1,3 +1,5 @@
|
||||
import pytest
|
||||
|
||||
from langchain.chat_models.ernie import ErnieBotChat
|
||||
from langchain.schema.messages import AIMessage, HumanMessage
|
||||
|
||||
@ -24,3 +26,31 @@ def test_chat_ernie_bot_with_temperature() -> None:
|
||||
response = chat([message])
|
||||
assert isinstance(response, AIMessage)
|
||||
assert isinstance(response.content, str)
|
||||
|
||||
|
||||
def test_chat_ernie_bot_with_kwargs() -> None:
|
||||
chat = ErnieBotChat()
|
||||
message = HumanMessage(content="Hello")
|
||||
response = chat([message], temperature=0.88, top_p=0.7)
|
||||
assert isinstance(response, AIMessage)
|
||||
assert isinstance(response.content, str)
|
||||
|
||||
|
||||
def test_extra_kwargs() -> None:
|
||||
chat = ErnieBotChat(temperature=0.88, top_p=0.7)
|
||||
assert chat.temperature == 0.88
|
||||
assert chat.top_p == 0.7
|
||||
|
||||
|
||||
def test_wrong_temperature_1() -> None:
|
||||
chat = ErnieBotChat()
|
||||
message = HumanMessage(content="Hello")
|
||||
with pytest.raises(ValueError):
|
||||
chat([message], temperature=1.2)
|
||||
|
||||
|
||||
def test_wrong_temperature_2() -> None:
|
||||
chat = ErnieBotChat()
|
||||
message = HumanMessage(content="Hello")
|
||||
with pytest.raises(ValueError):
|
||||
chat([message], temperature=0)
|
||||
|
Loading…
Reference in New Issue
Block a user