From f355a98bb6b09675ed50d45a6935968dc227f3df Mon Sep 17 00:00:00 2001 From: Austin Burdette Date: Fri, 23 Aug 2024 13:56:19 -0400 Subject: [PATCH] community:yuan2[patch]: standardize init args (#21462) updated stop and request_timeout so they aliased to stop_sequences, and timeout respectively. Added test that both continue to set the same underlying attributes. Related to [20085](https://github.com/langchain-ai/langchain/issues/20085) Co-authored-by: ccurme --- .../langchain_community/chat_models/yuan2.py | 6 ++++-- .../tests/unit_tests/chat_models/test_yuan2.py | 16 ++++++++++++++++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/libs/community/langchain_community/chat_models/yuan2.py b/libs/community/langchain_community/chat_models/yuan2.py index df95e9902dc..dc63eb83d24 100644 --- a/libs/community/langchain_community/chat_models/yuan2.py +++ b/libs/community/langchain_community/chat_models/yuan2.py @@ -93,7 +93,9 @@ class ChatYuan2(BaseChatModel): ) """Base URL path for API requests, an OpenAI compatible API server.""" - request_timeout: Optional[Union[float, Tuple[float, float]]] = None + request_timeout: Optional[Union[float, Tuple[float, float]]] = Field( + default=None, alias="timeout" + ) """Timeout for requests to yuan2 completion API. Default is 600 seconds.""" max_retries: int = 6 @@ -111,7 +113,7 @@ class ChatYuan2(BaseChatModel): top_p: Optional[float] = 0.9 """The top-p value to use for sampling.""" - stop: Optional[List[str]] = [""] + stop: Optional[List[str]] = Field(default=[""], alias="stop_sequences") """A list of strings to stop generation when encountered.""" repeat_last_n: Optional[int] = 64 diff --git a/libs/community/tests/unit_tests/chat_models/test_yuan2.py b/libs/community/tests/unit_tests/chat_models/test_yuan2.py index 74b2fb84cf5..683b2a013c7 100644 --- a/libs/community/tests/unit_tests/chat_models/test_yuan2.py +++ b/libs/community/tests/unit_tests/chat_models/test_yuan2.py @@ -22,6 +22,22 @@ def test_yuan2_model_param() -> None: assert chat.model_name == "foo" +@pytest.mark.requires("openai") +def test_yuan2_timeout_param() -> None: + chat = ChatYuan2(request_timeout=5) # type: ignore[call-arg] + assert chat.request_timeout == 5 + chat = ChatYuan2(timeout=10) # type: ignore[call-arg] + assert chat.request_timeout == 10 + + +@pytest.mark.requires("openai") +def test_yuan2_stop_sequences_param() -> None: + chat = ChatYuan2(stop=[""]) # type: ignore[call-arg] + assert chat.stop == [""] + chat = ChatYuan2(stop_sequences=[""]) # type: ignore[call-arg] + assert chat.stop == [""] + + def test__convert_message_to_dict_human() -> None: message = HumanMessage(content="foo") result = _convert_message_to_dict(message)