community[patch]: standardized sparkllm init args (#21633)

Related to #20085 
@baskaryan 

Thank you for contributing to LangChain!

community:sparkllm[patch]: standardized init args

updated `spark_api_key` so that aliased to `api_key`. Added integration
test for `sparkllm` to test that it continues to set the same underlying
attribute.

updated temperature with Pydantic Field, added to the integration test.

Ran `make format`,`make test`, `make lint`, `make spell_check`
This commit is contained in:
Param Singh 2024-05-20 17:11:36 -07:00 committed by GitHub
parent d4359d3de6
commit d07885f8b7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 14 additions and 4 deletions

View File

@ -135,14 +135,14 @@ class ChatSparkLLM(BaseChatModel):
client: Any = None #: :meta private:
spark_app_id: Optional[str] = None
spark_api_key: Optional[str] = None
spark_api_key: Optional[str] = Field(default=None, alias="api_key")
spark_api_secret: Optional[str] = None
spark_api_url: Optional[str] = None
spark_llm_domain: Optional[str] = None
spark_user_id: str = "lc_user"
streaming: bool = False
request_timeout: int = Field(30, alias="timeout")
temperature: float = 0.5
temperature: float = Field(default=0.5)
top_k: int = 4
model_kwargs: Dict[str, Any] = Field(default_factory=dict)

View File

@ -5,11 +5,21 @@ from langchain_community.chat_models.sparkllm import ChatSparkLLM
def test_initialization() -> None:
"""Test chat model initialization."""
for model in [
ChatSparkLLM(timeout=30),
ChatSparkLLM(request_timeout=30), # type: ignore[call-arg]
ChatSparkLLM(
api_key="secret",
temperature=0.5,
timeout=30,
),
ChatSparkLLM(
spark_api_key="secret",
request_timeout=30,
), # type: ignore[call-arg]
]:
assert model.request_timeout == 30
assert model.spark_api_key == "secret"
assert model.temperature == 0.5
def test_chat_spark_llm() -> None: