From d07885f8b7a3632ace1eda7c760f53249f18f378 Mon Sep 17 00:00:00 2001 From: Param Singh Date: Mon, 20 May 2024 17:11:36 -0700 Subject: [PATCH] community[patch]: standardized sparkllm init args (#21633) Related to #20085 @baskaryan Thank you for contributing to LangChain! community:sparkllm[patch]: standardized init args updated `spark_api_key` so that aliased to `api_key`. Added integration test for `sparkllm` to test that it continues to set the same underlying attribute. updated temperature with Pydantic Field, added to the integration test. Ran `make format`,`make test`, `make lint`, `make spell_check` --- .../langchain_community/chat_models/sparkllm.py | 4 ++-- .../integration_tests/chat_models/test_sparkllm.py | 14 ++++++++++++-- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/libs/community/langchain_community/chat_models/sparkllm.py b/libs/community/langchain_community/chat_models/sparkllm.py index e9efafd0c80..7d122f0a0c8 100644 --- a/libs/community/langchain_community/chat_models/sparkllm.py +++ b/libs/community/langchain_community/chat_models/sparkllm.py @@ -135,14 +135,14 @@ class ChatSparkLLM(BaseChatModel): client: Any = None #: :meta private: spark_app_id: Optional[str] = None - spark_api_key: Optional[str] = None + spark_api_key: Optional[str] = Field(default=None, alias="api_key") spark_api_secret: Optional[str] = None spark_api_url: Optional[str] = None spark_llm_domain: Optional[str] = None spark_user_id: str = "lc_user" streaming: bool = False request_timeout: int = Field(30, alias="timeout") - temperature: float = 0.5 + temperature: float = Field(default=0.5) top_k: int = 4 model_kwargs: Dict[str, Any] = Field(default_factory=dict) diff --git a/libs/community/tests/integration_tests/chat_models/test_sparkllm.py b/libs/community/tests/integration_tests/chat_models/test_sparkllm.py index 1a94af7eb21..848dc487bb8 100644 --- a/libs/community/tests/integration_tests/chat_models/test_sparkllm.py +++ b/libs/community/tests/integration_tests/chat_models/test_sparkllm.py @@ -5,11 +5,21 @@ from langchain_community.chat_models.sparkllm import ChatSparkLLM def test_initialization() -> None: """Test chat model initialization.""" + for model in [ - ChatSparkLLM(timeout=30), - ChatSparkLLM(request_timeout=30), # type: ignore[call-arg] + ChatSparkLLM( + api_key="secret", + temperature=0.5, + timeout=30, + ), + ChatSparkLLM( + spark_api_key="secret", + request_timeout=30, + ), # type: ignore[call-arg] ]: assert model.request_timeout == 30 + assert model.spark_api_key == "secret" + assert model.temperature == 0.5 def test_chat_spark_llm() -> None: