From f203229b513aa0f04704388081a501a7d8773eeb Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Thu, 24 Oct 2024 09:17:10 +0800 Subject: [PATCH] community: Fix the failure of ChatSparkLLM after upgrading to Pydantic V2 (#27418) **Description:** The test_sparkllm.py can reproduce this issue. https://github.com/langchain-ai/langchain/blob/master/libs/community/tests/integration_tests/chat_models/test_sparkllm.py#L66 ``` Testing started at 18:27 ... Launching pytest with arguments test_sparkllm.py::test_chat_spark_llm --no-header --no-summary -q in /Users/zhanglei/Work/github/langchain/libs/community/tests/integration_tests/chat_models ============================= test session starts ============================== collecting ... collected 1 item test_sparkllm.py::test_chat_spark_llm ============================== 1 failed in 0.45s =============================== FAILED [100%] tests/integration_tests/chat_models/test_sparkllm.py:65 (test_chat_spark_llm) def test_chat_spark_llm() -> None: > chat = ChatSparkLLM( spark_app_id="your spark_app_id", spark_api_key="your spark_api_key", spark_api_secret="your spark_api_secret", ) # type: ignore[call-arg] test_sparkllm.py:67: _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ ../../../../core/langchain_core/load/serializable.py:111: in __init__ super().__init__(*args, **kwargs) _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ cls = values = {'spark_api_key': 'your spark_api_key', 'spark_api_secret': 'your spark_api_secret', 'spark_api_url': 'wss://spark-api.xf-yun.com/v3.5/chat', 'spark_app_id': 'your spark_app_id', ...} @model_validator(mode="before") @classmethod def validate_environment(cls, values: Dict) -> Any: values["spark_app_id"] = get_from_dict_or_env( values, ["spark_app_id", "app_id"], "IFLYTEK_SPARK_APP_ID", ) values["spark_api_key"] = get_from_dict_or_env( values, ["spark_api_key", "api_key"], "IFLYTEK_SPARK_API_KEY", ) values["spark_api_secret"] = 
get_from_dict_or_env( values, ["spark_api_secret", "api_secret"], "IFLYTEK_SPARK_API_SECRET", ) values["spark_api_url"] = get_from_dict_or_env( values, "spark_api_url", "IFLYTEK_SPARK_API_URL", SPARK_API_URL, ) values["spark_llm_domain"] = get_from_dict_or_env( values, "spark_llm_domain", "IFLYTEK_SPARK_LLM_DOMAIN", SPARK_LLM_DOMAIN, ) # put extra params into model_kwargs default_values = { name: field.default for name, field in get_fields(cls).items() if field.default is not None } > values["model_kwargs"]["temperature"] = default_values.get("temperature") E KeyError: 'model_kwargs' ../../../langchain_community/chat_models/sparkllm.py:368: KeyError ``` I found that when upgrading to Pydantic v2, @root_validator was changed to @model_validator. When a class declares multiple @model_validator(mode="before") validators, the execution order in V1 and V2 is opposite. This is the reason for ChatSparkLLM's failure. The correct execution order is to execute build_extra first. https://github.com/langchain-ai/langchain/blob/langchain%3D%3D0.2.16/libs/community/langchain_community/chat_models/sparkllm.py#L302 And then execute validate_environment. https://github.com/langchain-ai/langchain/blob/langchain%3D%3D0.2.16/libs/community/langchain_community/chat_models/sparkllm.py#L329 The Pydantic community also discusses it, but there hasn't been a conclusion yet. 
https://github.com/pydantic/pydantic/discussions/7434 **Issue:** #27416 **Twitter handle:** coolbeevip --------- Co-authored-by: vbarda --- libs/community/extended_testing_deps.txt | 1 + .../chat_models/sparkllm.py | 60 ++++++++++--------- .../unit_tests/chat_models/test_sparkllm.py | 24 ++++++++ 3 files changed, 57 insertions(+), 28 deletions(-) diff --git a/libs/community/extended_testing_deps.txt b/libs/community/extended_testing_deps.txt index b2548b22193..56caca04381 100644 --- a/libs/community/extended_testing_deps.txt +++ b/libs/community/extended_testing_deps.txt @@ -95,3 +95,4 @@ xmltodict>=0.13.0,<0.14 nanopq==0.2.1 mlflow[genai]>=2.14.0 databricks-sdk>=0.30.0 +websocket>=0.2.1,<1 \ No newline at end of file diff --git a/libs/community/langchain_community/chat_models/sparkllm.py b/libs/community/langchain_community/chat_models/sparkllm.py index 996c8c2a2f6..3b7d1d47d29 100644 --- a/libs/community/langchain_community/chat_models/sparkllm.py +++ b/libs/community/langchain_community/chat_models/sparkllm.py @@ -300,34 +300,6 @@ class ChatSparkLLM(BaseChatModel): populate_by_name=True, ) - @model_validator(mode="before") - @classmethod - def build_extra(cls, values: Dict[str, Any]) -> Any: - """Build extra kwargs from additional params that were passed in.""" - all_required_field_names = get_pydantic_field_names(cls) - extra = values.get("model_kwargs", {}) - for field_name in list(values): - if field_name in extra: - raise ValueError(f"Found {field_name} supplied twice.") - if field_name not in all_required_field_names: - logger.warning( - f"""WARNING! {field_name} is not default parameter. - {field_name} was transferred to model_kwargs. - Please confirm that {field_name} is what you intended.""" - ) - extra[field_name] = values.pop(field_name) - - invalid_model_kwargs = all_required_field_names.intersection(extra.keys()) - if invalid_model_kwargs: - raise ValueError( - f"Parameters {invalid_model_kwargs} should be specified explicitly. 
" - f"Instead they were passed in as part of `model_kwargs` parameter." - ) - - values["model_kwargs"] = extra - - return values - @model_validator(mode="before") @classmethod def validate_environment(cls, values: Dict) -> Any: @@ -378,6 +350,38 @@ class ChatSparkLLM(BaseChatModel): ) return values + # When using Pydantic V2 + # The execution order of multiple @model_validator decorators is opposite to + # their declaration order. https://github.com/pydantic/pydantic/discussions/7434 + + @model_validator(mode="before") + @classmethod + def build_extra(cls, values: Dict[str, Any]) -> Any: + """Build extra kwargs from additional params that were passed in.""" + all_required_field_names = get_pydantic_field_names(cls) + extra = values.get("model_kwargs", {}) + for field_name in list(values): + if field_name in extra: + raise ValueError(f"Found {field_name} supplied twice.") + if field_name not in all_required_field_names: + logger.warning( + f"""WARNING! {field_name} is not default parameter. + {field_name} was transferred to model_kwargs. + Please confirm that {field_name} is what you intended.""" + ) + extra[field_name] = values.pop(field_name) + + invalid_model_kwargs = all_required_field_names.intersection(extra.keys()) + if invalid_model_kwargs: + raise ValueError( + f"Parameters {invalid_model_kwargs} should be specified explicitly. " + f"Instead they were passed in as part of `model_kwargs` parameter." 
+ ) + + values["model_kwargs"] = extra + + return values + def _stream( self, messages: List[BaseMessage], diff --git a/libs/community/tests/unit_tests/chat_models/test_sparkllm.py b/libs/community/tests/unit_tests/chat_models/test_sparkllm.py index 6d7e4cf6aa8..f4c768a75d2 100644 --- a/libs/community/tests/unit_tests/chat_models/test_sparkllm.py +++ b/libs/community/tests/unit_tests/chat_models/test_sparkllm.py @@ -1,3 +1,4 @@ +import pytest from langchain_core.messages import ( AIMessage, HumanMessage, @@ -8,6 +9,7 @@ from langchain_core.output_parsers.openai_tools import ( ) from langchain_community.chat_models.sparkllm import ( + ChatSparkLLM, convert_dict_to_message, convert_message_to_dict, ) @@ -83,3 +85,25 @@ def test__convert_message_to_dict_system() -> None: result = convert_message_to_dict(message) expected_output = {"role": "system", "content": "foo"} assert result == expected_output + + +@pytest.mark.requires("websocket") +def test__chat_spark_llm_initialization() -> None: + chat = ChatSparkLLM( + app_id="IFLYTEK_SPARK_APP_ID", + api_key="IFLYTEK_SPARK_API_KEY", + api_secret="IFLYTEK_SPARK_API_SECRET", + api_url="IFLYTEK_SPARK_API_URL", + model="IFLYTEK_SPARK_LLM_DOMAIN", + timeout=40, + temperature=0.1, + top_k=3, + ) + assert chat.spark_app_id == "IFLYTEK_SPARK_APP_ID" + assert chat.spark_api_key == "IFLYTEK_SPARK_API_KEY" + assert chat.spark_api_secret == "IFLYTEK_SPARK_API_SECRET" + assert chat.spark_api_url == "IFLYTEK_SPARK_API_URL" + assert chat.spark_llm_domain == "IFLYTEK_SPARK_LLM_DOMAIN" + assert chat.request_timeout == 40 + assert chat.temperature == 0.1 + assert chat.top_k == 3