community: Fix the failure of ChatSparkLLM after upgrading to Pydantic V2 (#27418)

**Description:**

The integration test `test_sparkllm.py` reproduces this issue:


https://github.com/langchain-ai/langchain/blob/master/libs/community/tests/integration_tests/chat_models/test_sparkllm.py#L66

```
Testing started at 18:27 ...
Launching pytest with arguments test_sparkllm.py::test_chat_spark_llm --no-header --no-summary -q in /Users/zhanglei/Work/github/langchain/libs/community/tests/integration_tests/chat_models

============================= test session starts ==============================
collecting ... collected 1 item

test_sparkllm.py::test_chat_spark_llm 

============================== 1 failed in 0.45s ===============================
FAILED                             [100%]
tests/integration_tests/chat_models/test_sparkllm.py:65 (test_chat_spark_llm)
def test_chat_spark_llm() -> None:
>       chat = ChatSparkLLM(
            spark_app_id="your spark_app_id",
            spark_api_key="your spark_api_key",
            spark_api_secret="your spark_api_secret",
        )  # type: ignore[call-arg]

test_sparkllm.py:67: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
../../../../core/langchain_core/load/serializable.py:111: in __init__
    super().__init__(*args, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

cls = <class 'langchain_community.chat_models.sparkllm.ChatSparkLLM'>
values = {'spark_api_key': 'your spark_api_key', 'spark_api_secret': 'your spark_api_secret', 'spark_api_url': 'wss://spark-api.xf-yun.com/v3.5/chat', 'spark_app_id': 'your spark_app_id', ...}

    @model_validator(mode="before")
    @classmethod
    def validate_environment(cls, values: Dict) -> Any:
        values["spark_app_id"] = get_from_dict_or_env(
            values,
            ["spark_app_id", "app_id"],
            "IFLYTEK_SPARK_APP_ID",
        )
        values["spark_api_key"] = get_from_dict_or_env(
            values,
            ["spark_api_key", "api_key"],
            "IFLYTEK_SPARK_API_KEY",
        )
        values["spark_api_secret"] = get_from_dict_or_env(
            values,
            ["spark_api_secret", "api_secret"],
            "IFLYTEK_SPARK_API_SECRET",
        )
        values["spark_api_url"] = get_from_dict_or_env(
            values,
            "spark_api_url",
            "IFLYTEK_SPARK_API_URL",
            SPARK_API_URL,
        )
        values["spark_llm_domain"] = get_from_dict_or_env(
            values,
            "spark_llm_domain",
            "IFLYTEK_SPARK_LLM_DOMAIN",
            SPARK_LLM_DOMAIN,
        )
    
        # put extra params into model_kwargs
        default_values = {
            name: field.default
            for name, field in get_fields(cls).items()
            if field.default is not None
        }
>       values["model_kwargs"]["temperature"] = default_values.get("temperature")
E       KeyError: 'model_kwargs'

../../../langchain_community/chat_models/sparkllm.py:368: KeyError
``` 

I found that when upgrading to Pydantic V2, `@root_validator` was changed to `@model_validator`. When a class declares multiple `@model_validator(mode="before")` validators, their execution order in V2 is the opposite of the V1 order. Because of this, `validate_environment` now runs before `build_extra`, so `values["model_kwargs"]` has not been populated yet when it is accessed. This is the cause of the `KeyError` above and of ChatSparkLLM's failure.

The correct execution order is to run `build_extra` first:


https://github.com/langchain-ai/langchain/blob/langchain%3D%3D0.2.16/libs/community/langchain_community/chat_models/sparkllm.py#L302

Then run `validate_environment`:


https://github.com/langchain-ai/langchain/blob/langchain%3D%3D0.2.16/libs/community/langchain_community/chat_models/sparkllm.py#L329

The Pydantic community has discussed this behavior as well, but no conclusion has been reached yet: https://github.com/pydantic/pydantic/discussions/7434
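To make the ordering concrete, here is a minimal, self-contained sketch (the `Demo` model and its field are hypothetical, purely for illustration; the validator names mirror the real ones). Under Pydantic V2, the later-declared `mode="before"` validator runs first, which is exactly what the fix below relies on:

```python
from pydantic import BaseModel, ConfigDict, model_validator


class Demo(BaseModel):
    # Allow a field named with the "model_" prefix, mirroring model_kwargs.
    model_config = ConfigDict(protected_namespaces=())

    model_kwargs: dict = {}

    @model_validator(mode="before")
    @classmethod
    def validate_environment(cls, values: dict) -> dict:
        # Raises KeyError if build_extra has not populated model_kwargs yet.
        values["model_kwargs"]["temperature"] = 0.5
        return values

    @model_validator(mode="before")
    @classmethod
    def build_extra(cls, values: dict) -> dict:
        # Declared last, therefore executed first under Pydantic V2.
        values.setdefault("model_kwargs", {})
        return values


# Succeeds: build_extra runs first and creates "model_kwargs".
# With the two validators declared in the opposite (V1) order,
# validate_environment would run first and raise the KeyError shown above.
Demo()
```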

**Issue:** #27416 

**Twitter handle:** coolbeevip

---------

Co-authored-by: vbarda <vadym@langchain.dev>
Lei Zhang, 2024-10-24 09:17:10 +08:00, committed by GitHub
parent 8f151223ad
commit f203229b51
3 changed files with 57 additions and 28 deletions


```diff
@@ -95,3 +95,4 @@ xmltodict>=0.13.0,<0.14
 nanopq==0.2.1
 mlflow[genai]>=2.14.0
 databricks-sdk>=0.30.0
+websocket>=0.2.1,<1
```


```diff
@@ -300,34 +300,6 @@ class ChatSparkLLM(BaseChatModel):
         populate_by_name=True,
     )
 
-    @model_validator(mode="before")
-    @classmethod
-    def build_extra(cls, values: Dict[str, Any]) -> Any:
-        """Build extra kwargs from additional params that were passed in."""
-        all_required_field_names = get_pydantic_field_names(cls)
-        extra = values.get("model_kwargs", {})
-        for field_name in list(values):
-            if field_name in extra:
-                raise ValueError(f"Found {field_name} supplied twice.")
-            if field_name not in all_required_field_names:
-                logger.warning(
-                    f"""WARNING! {field_name} is not default parameter.
-                    {field_name} was transferred to model_kwargs.
-                    Please confirm that {field_name} is what you intended."""
-                )
-                extra[field_name] = values.pop(field_name)
-        invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
-        if invalid_model_kwargs:
-            raise ValueError(
-                f"Parameters {invalid_model_kwargs} should be specified explicitly. "
-                f"Instead they were passed in as part of `model_kwargs` parameter."
-            )
-        values["model_kwargs"] = extra
-        return values
-
     @model_validator(mode="before")
     @classmethod
     def validate_environment(cls, values: Dict) -> Any:
@@ -378,6 +350,38 @@
             )
         return values
 
+    # When using Pydantic V2
+    # The execution order of multiple @model_validator decorators is opposite to
+    # their declaration order. https://github.com/pydantic/pydantic/discussions/7434
+    @model_validator(mode="before")
+    @classmethod
+    def build_extra(cls, values: Dict[str, Any]) -> Any:
+        """Build extra kwargs from additional params that were passed in."""
+        all_required_field_names = get_pydantic_field_names(cls)
+        extra = values.get("model_kwargs", {})
+        for field_name in list(values):
+            if field_name in extra:
+                raise ValueError(f"Found {field_name} supplied twice.")
+            if field_name not in all_required_field_names:
+                logger.warning(
+                    f"""WARNING! {field_name} is not default parameter.
+                    {field_name} was transferred to model_kwargs.
+                    Please confirm that {field_name} is what you intended."""
+                )
+                extra[field_name] = values.pop(field_name)
+        invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
+        if invalid_model_kwargs:
+            raise ValueError(
+                f"Parameters {invalid_model_kwargs} should be specified explicitly. "
+                f"Instead they were passed in as part of `model_kwargs` parameter."
+            )
+        values["model_kwargs"] = extra
+        return values
+
     def _stream(
         self,
         messages: List[BaseMessage],
```


```diff
@@ -1,3 +1,4 @@
+import pytest
 from langchain_core.messages import (
     AIMessage,
     HumanMessage,
@@ -8,6 +9,7 @@ from langchain_core.output_parsers.openai_tools import (
 )
 
 from langchain_community.chat_models.sparkllm import (
     ChatSparkLLM,
+    convert_dict_to_message,
     convert_message_to_dict,
 )
@@ -83,3 +85,25 @@ def test__convert_message_to_dict_system() -> None:
     result = convert_message_to_dict(message)
     expected_output = {"role": "system", "content": "foo"}
     assert result == expected_output
+
+
+@pytest.mark.requires("websocket")
+def test__chat_spark_llm_initialization() -> None:
+    chat = ChatSparkLLM(
+        app_id="IFLYTEK_SPARK_APP_ID",
+        api_key="IFLYTEK_SPARK_API_KEY",
+        api_secret="IFLYTEK_SPARK_API_SECRET",
+        api_url="IFLYTEK_SPARK_API_URL",
+        model="IFLYTEK_SPARK_LLM_DOMAIN",
+        timeout=40,
+        temperature=0.1,
+        top_k=3,
+    )
+    assert chat.spark_app_id == "IFLYTEK_SPARK_APP_ID"
+    assert chat.spark_api_key == "IFLYTEK_SPARK_API_KEY"
+    assert chat.spark_api_secret == "IFLYTEK_SPARK_API_SECRET"
+    assert chat.spark_api_url == "IFLYTEK_SPARK_API_URL"
+    assert chat.spark_llm_domain == "IFLYTEK_SPARK_LLM_DOMAIN"
+    assert chat.request_timeout == 40
+    assert chat.temperature == 0.1
+    assert chat.top_k == 3
```
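For completeness, the validator chain also supports configuration via environment variables, as the traceback above shows (`get_from_dict_or_env` falls back to `IFLYTEK_SPARK_APP_ID` and friends). A minimal sketch with placeholder values; real iFLYTEK credentials are required for actual calls:

```python
import os

from langchain_community.chat_models import ChatSparkLLM

# Placeholder credentials, read by validate_environment at construction time.
os.environ["IFLYTEK_SPARK_APP_ID"] = "my-app-id"
os.environ["IFLYTEK_SPARK_API_KEY"] = "my-api-key"
os.environ["IFLYTEK_SPARK_API_SECRET"] = "my-api-secret"

# With the fix, build_extra runs before validate_environment, so construction
# no longer raises KeyError: 'model_kwargs'.
chat = ChatSparkLLM(temperature=0.1)
```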