From ffa56251347c1a4e417a2f7299f73e299b63bace Mon Sep 17 00:00:00 2001 From: axiangcoding Date: Wed, 30 Aug 2023 09:20:06 +0800 Subject: [PATCH] feat(llms): improve ERNIE-Bot chat model (#9833) - Description: improve ERNIE-Bot chat model, add request timeout and more testcases. - Issue: None - Dependencies: None - Tag maintainer: @baskaryan --------- Co-authored-by: Bagatur --- libs/langchain/langchain/chat_models/ernie.py | 15 ++++++++++++ .../chat_models/test_ernie.py | 6 +++-- .../unit_tests/chat_models/test_ernie.py | 23 ++++++++++++++++++- 3 files changed, 41 insertions(+), 3 deletions(-) diff --git a/libs/langchain/langchain/chat_models/ernie.py b/libs/langchain/langchain/chat_models/ernie.py index 1ccecf36fe6..d3fdce5c31f 100644 --- a/libs/langchain/langchain/chat_models/ernie.py +++ b/libs/langchain/langchain/chat_models/ernie.py @@ -57,12 +57,25 @@ class ErnieBotChat(BaseChatModel): """ ernie_client_id: Optional[str] = None + """Baidu application client id""" + ernie_client_secret: Optional[str] = None + """Baidu application client secret""" + access_token: Optional[str] = None + """access token is generated by client id and client secret, + setting this value directly will cause an error""" model_name: str = "ERNIE-Bot-turbo" + """model name of ernie, default is `ERNIE-Bot-turbo`. + Currently supported `ERNIE-Bot-turbo`, `ERNIE-Bot`""" + + request_timeout: Optional[int] = 60 + """request timeout for chat http requests""" streaming: Optional[bool] = False + """streaming mode. not supported yet.""" + top_p: Optional[float] = 0.8 temperature: Optional[float] = 0.95 penalty_score: Optional[float] = 1 @@ -93,6 +106,7 @@ class ErnieBotChat(BaseChatModel): raise ValueError(f"Got unknown model_name {self.model_name}") resp = requests.post( url, + timeout=self.request_timeout, headers={ "Content-Type": "application/json", }, @@ -107,6 +121,7 @@ class ErnieBotChat(BaseChatModel): base_url: str = "https://aip.baidubce.com/oauth/2.0/token" resp = requests.post( base_url, + timeout=10, headers={ "Content-Type": "application/json", "Accept": "application/json", diff --git a/libs/langchain/tests/integration_tests/chat_models/test_ernie.py b/libs/langchain/tests/integration_tests/chat_models/test_ernie.py index a8a80ed9bbd..4b79f40d093 100644 --- a/libs/langchain/tests/integration_tests/chat_models/test_ernie.py +++ b/libs/langchain/tests/integration_tests/chat_models/test_ernie.py @@ -45,12 +45,14 @@ def test_extra_kwargs() -> None: def test_wrong_temperature_1() -> None: chat = ErnieBotChat() message = HumanMessage(content="Hello") - with pytest.raises(ValueError): + with pytest.raises(ValueError) as e: chat([message], temperature=1.2) + assert "parameter check failed, temperature range is (0, 1.0]" in str(e) def test_wrong_temperature_2() -> None: chat = ErnieBotChat() message = HumanMessage(content="Hello") - with pytest.raises(ValueError): + with pytest.raises(ValueError) as e: chat([message], temperature=0) + assert "parameter check failed, temperature range is (0, 1.0]" in str(e) diff --git a/libs/langchain/tests/unit_tests/chat_models/test_ernie.py b/libs/langchain/tests/unit_tests/chat_models/test_ernie.py index 06eb1019615..a8417017f5a 100644 --- a/libs/langchain/tests/unit_tests/chat_models/test_ernie.py +++ b/libs/langchain/tests/unit_tests/chat_models/test_ernie.py @@ -1,5 +1,12 @@ +import pytest + from langchain.chat_models.ernie import _convert_message_to_dict -from langchain.schema.messages import AIMessage, HumanMessage +from langchain.schema.messages import ( + AIMessage, + FunctionMessage, + HumanMessage, + SystemMessage, +) def test__convert_dict_to_message_human() -> None: @@ -14,3 +21,17 @@ def test__convert_dict_to_message_ai() -> None: result = _convert_message_to_dict(message) expected_output = {"role": "assistant", "content": "foo"} assert result == expected_output + + +def test__convert_dict_to_message_system() -> None: + message = SystemMessage(content="foo") + with pytest.raises(ValueError) as e: + _convert_message_to_dict(message) + assert "Got unknown type" in str(e) + + +def test__convert_dict_to_message_function() -> None: + message = FunctionMessage(name="foo", content="bar") + with pytest.raises(ValueError) as e: + _convert_message_to_dict(message) + assert "Got unknown type" in str(e)