diff --git a/libs/community/extended_testing_deps.txt b/libs/community/extended_testing_deps.txt index b2548b22193..56caca04381 100644 --- a/libs/community/extended_testing_deps.txt +++ b/libs/community/extended_testing_deps.txt @@ -95,3 +95,4 @@ xmltodict>=0.13.0,<0.14 nanopq==0.2.1 mlflow[genai]>=2.14.0 databricks-sdk>=0.30.0 +websocket>=0.2.1,<1 \ No newline at end of file diff --git a/libs/community/langchain_community/chat_models/sparkllm.py b/libs/community/langchain_community/chat_models/sparkllm.py index 996c8c2a2f6..3b7d1d47d29 100644 --- a/libs/community/langchain_community/chat_models/sparkllm.py +++ b/libs/community/langchain_community/chat_models/sparkllm.py @@ -300,34 +300,6 @@ class ChatSparkLLM(BaseChatModel): populate_by_name=True, ) - @model_validator(mode="before") - @classmethod - def build_extra(cls, values: Dict[str, Any]) -> Any: - """Build extra kwargs from additional params that were passed in.""" - all_required_field_names = get_pydantic_field_names(cls) - extra = values.get("model_kwargs", {}) - for field_name in list(values): - if field_name in extra: - raise ValueError(f"Found {field_name} supplied twice.") - if field_name not in all_required_field_names: - logger.warning( - f"""WARNING! {field_name} is not default parameter. - {field_name} was transferred to model_kwargs. - Please confirm that {field_name} is what you intended.""" - ) - extra[field_name] = values.pop(field_name) - - invalid_model_kwargs = all_required_field_names.intersection(extra.keys()) - if invalid_model_kwargs: - raise ValueError( - f"Parameters {invalid_model_kwargs} should be specified explicitly. " - f"Instead they were passed in as part of `model_kwargs` parameter." - ) - - values["model_kwargs"] = extra - - return values - @model_validator(mode="before") @classmethod def validate_environment(cls, values: Dict) -> Any: @@ -378,6 +350,38 @@ class ChatSparkLLM(BaseChatModel): ) return values + # When using Pydantic V2 + # The execution order of multiple @model_validator decorators is opposite to + # their declaration order. https://github.com/pydantic/pydantic/discussions/7434 + + @model_validator(mode="before") + @classmethod + def build_extra(cls, values: Dict[str, Any]) -> Any: + """Build extra kwargs from additional params that were passed in.""" + all_required_field_names = get_pydantic_field_names(cls) + extra = values.get("model_kwargs", {}) + for field_name in list(values): + if field_name in extra: + raise ValueError(f"Found {field_name} supplied twice.") + if field_name not in all_required_field_names: + logger.warning( + f"""WARNING! {field_name} is not default parameter. + {field_name} was transferred to model_kwargs. + Please confirm that {field_name} is what you intended.""" + ) + extra[field_name] = values.pop(field_name) + + invalid_model_kwargs = all_required_field_names.intersection(extra.keys()) + if invalid_model_kwargs: + raise ValueError( + f"Parameters {invalid_model_kwargs} should be specified explicitly. " + f"Instead they were passed in as part of `model_kwargs` parameter." + ) + + values["model_kwargs"] = extra + + return values + def _stream( self, messages: List[BaseMessage], diff --git a/libs/community/tests/unit_tests/chat_models/test_sparkllm.py b/libs/community/tests/unit_tests/chat_models/test_sparkllm.py index 6d7e4cf6aa8..f4c768a75d2 100644 --- a/libs/community/tests/unit_tests/chat_models/test_sparkllm.py +++ b/libs/community/tests/unit_tests/chat_models/test_sparkllm.py @@ -1,3 +1,4 @@ +import pytest from langchain_core.messages import ( AIMessage, HumanMessage, @@ -8,6 +9,7 @@ from langchain_core.output_parsers.openai_tools import ( ) from langchain_community.chat_models.sparkllm import ( + ChatSparkLLM, convert_dict_to_message, convert_message_to_dict, ) @@ -83,3 +85,25 @@ def test__convert_message_to_dict_system() -> None: result = convert_message_to_dict(message) expected_output = {"role": "system", "content": "foo"} assert result == expected_output + + +@pytest.mark.requires("websocket") +def test__chat_spark_llm_initialization() -> None: + chat = ChatSparkLLM( + app_id="IFLYTEK_SPARK_APP_ID", + api_key="IFLYTEK_SPARK_API_KEY", + api_secret="IFLYTEK_SPARK_API_SECRET", + api_url="IFLYTEK_SPARK_API_URL", + model="IFLYTEK_SPARK_LLM_DOMAIN", + timeout=40, + temperature=0.1, + top_k=3, + ) + assert chat.spark_app_id == "IFLYTEK_SPARK_APP_ID" + assert chat.spark_api_key == "IFLYTEK_SPARK_API_KEY" + assert chat.spark_api_secret == "IFLYTEK_SPARK_API_SECRET" + assert chat.spark_api_url == "IFLYTEK_SPARK_API_URL" + assert chat.spark_llm_domain == "IFLYTEK_SPARK_LLM_DOMAIN" + assert chat.request_timeout == 40 + assert chat.temperature == 0.1 + assert chat.top_k == 3