mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-13 14:50:00 +00:00
feat(llms): improve ERNIE-Bot chat model (#9833)
- Description: improve ERNIE-Bot chat model, add request timeout and more testcases. - Issue: None - Dependencies: None - Tag maintainer: @baskaryan --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
bdccb1215a
commit
ffa5625134
@ -57,12 +57,25 @@ class ErnieBotChat(BaseChatModel):
|
||||
"""
|
||||
|
||||
ernie_client_id: Optional[str] = None
|
||||
"""Baidu application client id"""
|
||||
|
||||
ernie_client_secret: Optional[str] = None
|
||||
"""Baidu application client secret"""
|
||||
|
||||
access_token: Optional[str] = None
|
||||
"""access token is generated by client id and client secret,
|
||||
setting this value directly will cause an error"""
|
||||
|
||||
model_name: str = "ERNIE-Bot-turbo"
|
||||
"""model name of ernie, default is `ERNIE-Bot-turbo`.
|
||||
Currently supported `ERNIE-Bot-turbo`, `ERNIE-Bot`"""
|
||||
|
||||
request_timeout: Optional[int] = 60
|
||||
"""request timeout for chat http requests"""
|
||||
|
||||
streaming: Optional[bool] = False
|
||||
"""streaming mode. not supported yet."""
|
||||
|
||||
top_p: Optional[float] = 0.8
|
||||
temperature: Optional[float] = 0.95
|
||||
penalty_score: Optional[float] = 1
|
||||
@ -93,6 +106,7 @@ class ErnieBotChat(BaseChatModel):
|
||||
raise ValueError(f"Got unknown model_name {self.model_name}")
|
||||
resp = requests.post(
|
||||
url,
|
||||
timeout=self.request_timeout,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
@ -107,6 +121,7 @@ class ErnieBotChat(BaseChatModel):
|
||||
base_url: str = "https://aip.baidubce.com/oauth/2.0/token"
|
||||
resp = requests.post(
|
||||
base_url,
|
||||
timeout=10,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
|
@ -45,12 +45,14 @@ def test_extra_kwargs() -> None:
|
||||
def test_wrong_temperature_1() -> None:
|
||||
chat = ErnieBotChat()
|
||||
message = HumanMessage(content="Hello")
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(ValueError) as e:
|
||||
chat([message], temperature=1.2)
|
||||
assert "parameter check failed, temperature range is (0, 1.0]" in str(e)
|
||||
|
||||
|
||||
def test_wrong_temperature_2() -> None:
|
||||
chat = ErnieBotChat()
|
||||
message = HumanMessage(content="Hello")
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(ValueError) as e:
|
||||
chat([message], temperature=0)
|
||||
assert "parameter check failed, temperature range is (0, 1.0]" in str(e)
|
||||
|
@ -1,5 +1,12 @@
|
||||
import pytest
|
||||
|
||||
from langchain.chat_models.ernie import _convert_message_to_dict
|
||||
from langchain.schema.messages import AIMessage, HumanMessage
|
||||
from langchain.schema.messages import (
|
||||
AIMessage,
|
||||
FunctionMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
|
||||
|
||||
def test__convert_dict_to_message_human() -> None:
|
||||
@ -14,3 +21,17 @@ def test__convert_dict_to_message_ai() -> None:
|
||||
result = _convert_message_to_dict(message)
|
||||
expected_output = {"role": "assistant", "content": "foo"}
|
||||
assert result == expected_output
|
||||
|
||||
|
||||
def test__convert_dict_to_message_system() -> None:
|
||||
message = SystemMessage(content="foo")
|
||||
with pytest.raises(ValueError) as e:
|
||||
_convert_message_to_dict(message)
|
||||
assert "Got unknown type" in str(e)
|
||||
|
||||
|
||||
def test__convert_dict_to_message_function() -> None:
|
||||
message = FunctionMessage(name="foo", content="bar")
|
||||
with pytest.raises(ValueError) as e:
|
||||
_convert_message_to_dict(message)
|
||||
assert "Got unknown type" in str(e)
|
||||
|
Loading…
Reference in New Issue
Block a user