From a566a15930f15d6b117c2f3517ef0f8687520c3b Mon Sep 17 00:00:00 2001 From: maang-h <55082429+maang-h@users.noreply.github.com> Date: Tue, 27 Aug 2024 01:33:22 +0800 Subject: [PATCH] Fix MoonshotChat instantiate with alias (#25755) - **Description:** - Fix `MoonshotChat` instantiate with alias - Add `MoonshotChat` to `__init__.py` --------- Co-authored-by: Chester Curme --- .../langchain_community/chat_models/__init__.py | 5 +++++ .../langchain_community/chat_models/moonshot.py | 6 +++++- .../integration_tests/chat_models/test_moonshot.py | 10 +++++++++- .../tests/unit_tests/chat_models/test_imports.py | 1 + 4 files changed, 20 insertions(+), 2 deletions(-) diff --git a/libs/community/langchain_community/chat_models/__init__.py b/libs/community/langchain_community/chat_models/__init__.py index 9658c707e67..b29554a5d00 100644 --- a/libs/community/langchain_community/chat_models/__init__.py +++ b/libs/community/langchain_community/chat_models/__init__.py @@ -122,6 +122,9 @@ if TYPE_CHECKING: from langchain_community.chat_models.mlx import ( ChatMLX, ) + from langchain_community.chat_models.moonshot import ( + MoonshotChat, + ) from langchain_community.chat_models.oci_generative_ai import ( ChatOCIGenAI, # noqa: F401 ) @@ -224,6 +227,7 @@ __all__ = [ "JinaChat", "LlamaEdgeChatService", "MiniMaxChat", + "MoonshotChat", "PaiEasChatEndpoint", "PromptLayerChatOpenAI", "QianfanChatEndpoint", @@ -280,6 +284,7 @@ _module_lookup = { "JinaChat": "langchain_community.chat_models.jinachat", "LlamaEdgeChatService": "langchain_community.chat_models.llama_edge", "MiniMaxChat": "langchain_community.chat_models.minimax", + "MoonshotChat": "langchain_community.chat_models.moonshot", "PaiEasChatEndpoint": "langchain_community.chat_models.pai_eas_endpoint", "PromptLayerChatOpenAI": "langchain_community.chat_models.promptlayer_openai", "SolarChat": "langchain_community.chat_models.solar", diff --git a/libs/community/langchain_community/chat_models/moonshot.py b/libs/community/langchain_community/chat_models/moonshot.py index fd8a455f888..cd9de1c51dd 100644 --- a/libs/community/langchain_community/chat_models/moonshot.py +++ b/libs/community/langchain_community/chat_models/moonshot.py @@ -33,7 +33,11 @@ class MoonshotChat(MoonshotCommon, ChatOpenAI): # type: ignore[misc] def validate_environment(cls, values: Dict) -> Dict: """Validate that the environment is set up correctly.""" values["moonshot_api_key"] = convert_to_secret_str( - get_from_dict_or_env(values, "moonshot_api_key", "MOONSHOT_API_KEY") + get_from_dict_or_env( + values, + ["moonshot_api_key", "api_key", "openai_api_key"], + "MOONSHOT_API_KEY", + ) ) try: diff --git a/libs/community/tests/integration_tests/chat_models/test_moonshot.py b/libs/community/tests/integration_tests/chat_models/test_moonshot.py index 58ea6de5e02..bb29175a2d0 100644 --- a/libs/community/tests/integration_tests/chat_models/test_moonshot.py +++ b/libs/community/tests/integration_tests/chat_models/test_moonshot.py @@ -1,9 +1,10 @@ """Test Moonshot Chat Model.""" -from typing import Type +from typing import Type, cast import pytest from langchain_core.language_models import BaseChatModel +from langchain_core.pydantic_v1 import SecretStr from langchain_standard_tests.integration_tests import ChatModelIntegrationTests from langchain_community.chat_models.moonshot import MoonshotChat @@ -21,3 +22,10 @@ class TestMoonshotChat(ChatModelIntegrationTests): @pytest.mark.xfail(reason="Not yet implemented.") def test_usage_metadata(self, model: BaseChatModel) -> None: super().test_usage_metadata(model) + + +def test_chat_moonshot_instantiate_with_alias() -> None: + """Test MoonshotChat instantiate when using alias.""" + api_key = "your-api-key" + chat = MoonshotChat(api_key=api_key) # type: ignore[call-arg] + assert cast(SecretStr, chat.moonshot_api_key).get_secret_value() == api_key diff --git a/libs/community/tests/unit_tests/chat_models/test_imports.py b/libs/community/tests/unit_tests/chat_models/test_imports.py index 9e97444484b..50a61d4e51a 100644 --- a/libs/community/tests/unit_tests/chat_models/test_imports.py +++ b/libs/community/tests/unit_tests/chat_models/test_imports.py @@ -48,6 +48,7 @@ EXPECTED_ALL = [ "JinaChat", "LlamaEdgeChatService", "MiniMaxChat", + "MoonshotChat", "PaiEasChatEndpoint", "PromptLayerChatOpenAI", "SolarChat",