community[patch]: Use get_fields adapter for pydantic (#25191)

Change all usages of __fields__ with get_fields adapter merged into
langchain_core.

Code mod generated using the following grit pattern:

```
engine marzano(0.1)
language python


`$X.__fields__` => `get_fields($X)` where {
    add_import(source="langchain_core.utils.pydantic", name="get_fields")
}
```
This commit is contained in:
Eugene Yurtsev 2024-08-08 14:43:09 -04:00 committed by GitHub
parent 663638d6a8
commit 98779797fe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
24 changed files with 51 additions and 28 deletions

View File

@ -4,6 +4,7 @@ from typing import Dict, List, Optional, Type
from langchain_core.pydantic_v1 import root_validator
from langchain_core.tools import BaseToolkit
from langchain_core.utils.pydantic import get_fields
from langchain_community.tools import BaseTool
from langchain_community.tools.file_management.copy import CopyFileTool
@ -24,7 +25,7 @@ _FILE_TOOLS: List[Type[BaseTool]] = [
ListDirectoryTool,
]
_FILE_TOOLS_MAP: Dict[str, Type[BaseTool]] = {
tool_cls.__fields__["name"].default: tool_cls for tool_cls in _FILE_TOOLS
get_fields(tool_cls)["name"].default: tool_cls for tool_cls in _FILE_TOOLS
}

View File

@ -51,7 +51,7 @@ from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.utils.pydantic import is_basemodel_subclass
from langchain_core.utils.pydantic import get_fields, is_basemodel_subclass
logger = logging.getLogger(__name__)
@ -397,7 +397,7 @@ class QianfanChatEndpoint(BaseChatModel):
default_values = {
name: field.default
for name, field in cls.__fields__.items()
for name, field in get_fields(cls).items()
if field.default is not None
}
default_values.update(values)

View File

@ -49,6 +49,7 @@ from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.utils.pydantic import get_fields
logger = logging.getLogger(__name__)
@ -399,7 +400,7 @@ class MiniMaxChat(BaseChatModel):
default_values = {
name: field.default
for name, field in cls.__fields__.items()
for name, field in get_fields(cls).items()
if field.default is not None
}
default_values.update(values)

View File

@ -40,6 +40,7 @@ from langchain_core.utils import (
get_from_dict_or_env,
get_pydantic_field_names,
)
from langchain_core.utils.pydantic import get_fields
logger = logging.getLogger(__name__)
@ -308,7 +309,7 @@ class ChatSparkLLM(BaseChatModel):
# put extra params into model_kwargs
default_values = {
name: field.default
for name, field in cls.__fields__.items()
for name, field in get_fields(cls).items()
if field.default is not None
}
values["model_kwargs"]["temperature"] = default_values.get("temperature")

View File

@ -7,6 +7,7 @@ from typing import Any, Dict, List, Optional
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.utils import pre_init
from langchain_core.utils.pydantic import get_fields
logger = logging.getLogger(__name__)
@ -88,7 +89,7 @@ class GigaChatEmbeddings(BaseModel, Embeddings):
"Could not import gigachat python package. "
"Please install it with `pip install gigachat`."
)
fields = set(cls.__fields__.keys())
fields = set(get_fields(cls).keys())
diff = set(values.keys()) - fields
if diff:
logger.warning(f"Extra fields {diff} in GigaChat class")

View File

@ -5,6 +5,7 @@ from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
from langchain_core.utils.pydantic import get_fields
from langchain_community.llms.utils import enforce_stop_tokens
@ -46,7 +47,7 @@ class Banana(LLM):
@root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = {field.alias for field in cls.__fields__.values()}
all_required_field_names = {field.alias for field in get_fields(cls).values()}
extra = values.get("model_kwargs", {})
for field_name in list(values):

View File

@ -11,6 +11,7 @@ from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.utils import get_from_dict_or_env, pre_init
from langchain_core.utils.pydantic import get_fields
logger = logging.getLogger(__name__)
@ -78,7 +79,7 @@ class Beam(LLM):
@root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = {field.alias for field in cls.__fields__.values()}
all_required_field_names = {field.alias for field in get_fields(cls).values()}
extra = values.get("model_kwargs", {})
for field_name in list(values):

View File

@ -6,6 +6,7 @@ from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
from langchain_core.utils.pydantic import get_fields
from langchain_community.llms.utils import enforce_stop_tokens
@ -45,7 +46,7 @@ class CerebriumAI(LLM):
@root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = {field.alias for field in cls.__fields__.values()}
all_required_field_names = {field.alias for field in get_fields(cls).values()}
extra = values.get("model_kwargs", {})
for field_name in list(values):

View File

@ -11,6 +11,7 @@ from langchain_core.callbacks import (
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.utils import get_from_dict_or_env, pre_init
from langchain_core.utils.pydantic import get_fields
from langchain_community.llms.utils import enforce_stop_tokens
from langchain_community.utilities.requests import Requests
@ -82,7 +83,7 @@ class EdenAI(LLM):
@root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = {field.alias for field in cls.__fields__.values()}
all_required_field_names = {field.alias for field in get_fields(cls).values()}
extra = values.get("model_kwargs", {})
for field_name in list(values):

View File

@ -12,6 +12,7 @@ from langchain_core.language_models.llms import BaseLLM
from langchain_core.load.serializable import Serializable
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
from langchain_core.utils import pre_init
from langchain_core.utils.pydantic import get_fields
if TYPE_CHECKING:
import gigachat
@ -123,7 +124,7 @@ class _BaseGigaChat(Serializable):
"Could not import gigachat python package. "
"Please install it with `pip install gigachat`."
)
fields = set(cls.__fields__.keys())
fields = set(get_fields(cls).keys())
diff = set(values.keys()) - fields
if diff:
logger.warning(f"Extra fields {diff} in GigaChat class")

View File

@ -5,6 +5,7 @@ from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
from langchain_core.utils.pydantic import get_fields
logger = logging.getLogger(__name__)
@ -68,7 +69,7 @@ class GooseAI(LLM):
@root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = {field.alias for field in cls.__fields__.values()}
all_required_field_names = {field.alias for field in get_fields(cls).values()}
extra = values.get("model_kwargs", {})
for field_name in list(values):

View File

@ -6,6 +6,7 @@ from typing import Any, Union
import yaml
from langchain_core.language_models.llms import BaseLLM
from langchain_core.utils.pydantic import get_fields
from langchain_community.llms import get_type_to_cls_dict
@ -26,7 +27,7 @@ def load_llm_from_config(config: dict, **kwargs: Any) -> BaseLLM:
llm_cls = type_to_cls_dict[config_type]()
load_kwargs = {}
if _ALLOW_DANGEROUS_DESERIALIZATION_ARG in llm_cls.__fields__:
if _ALLOW_DANGEROUS_DESERIALIZATION_ARG in get_fields(llm_cls):
load_kwargs[_ALLOW_DANGEROUS_DESERIALIZATION_ARG] = kwargs.get(
_ALLOW_DANGEROUS_DESERIALIZATION_ARG, False
)

View File

@ -5,6 +5,7 @@ import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.utils.pydantic import get_fields
from langchain_community.llms.utils import enforce_stop_tokens
@ -40,7 +41,7 @@ class Modal(LLM):
@root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = {field.alias for field in cls.__fields__.values()}
all_required_field_names = {field.alias for field in get_fields(cls).values()}
extra = values.get("model_kwargs", {})
for field_name in list(values):

View File

@ -34,6 +34,7 @@ from langchain_core.utils import (
get_pydantic_field_names,
pre_init,
)
from langchain_core.utils.pydantic import get_fields
from langchain_core.utils.utils import build_extra_kwargs
from langchain_community.utils.openai import is_openai_v1
@ -1016,7 +1017,7 @@ class OpenAIChat(BaseLLM):
@root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = {field.alias for field in cls.__fields__.values()}
all_required_field_names = {field.alias for field in get_fields(cls).values()}
extra = values.get("model_kwargs", {})
for field_name in list(values):

View File

@ -5,6 +5,7 @@ from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
from langchain_core.utils.pydantic import get_fields
from langchain_community.llms.utils import enforce_stop_tokens
@ -68,7 +69,7 @@ class Petals(LLM):
@root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = {field.alias for field in cls.__fields__.values()}
all_required_field_names = {field.alias for field in get_fields(cls).values()}
extra = values.get("model_kwargs", {})
for field_name in list(values):

View File

@ -10,6 +10,7 @@ from langchain_core.pydantic_v1 import (
root_validator,
)
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
from langchain_core.utils.pydantic import get_fields
from langchain_community.llms.utils import enforce_stop_tokens
@ -47,7 +48,7 @@ class PipelineAI(LLM, BaseModel):
@root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = {field.alias for field in cls.__fields__.values()}
all_required_field_names = {field.alias for field in get_fields(cls).values()}
extra = values.get("pipeline_kwargs", {})
for field_name in list(values):

View File

@ -8,6 +8,7 @@ from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.utils import get_from_dict_or_env, pre_init
from langchain_core.utils.pydantic import get_fields
if TYPE_CHECKING:
from replicate.prediction import Prediction
@ -75,7 +76,7 @@ class Replicate(LLM):
@root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = {field.alias for field in cls.__fields__.values()}
all_required_field_names = {field.alias for field in get_fields(cls).values()}
input = values.pop("input", {})
if input:

View File

@ -7,6 +7,7 @@ from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
from langchain_core.utils.pydantic import get_fields
from langchain_community.llms.utils import enforce_stop_tokens
@ -41,7 +42,7 @@ class StochasticAI(LLM):
@root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = {field.alias for field in cls.__fields__.values()}
all_required_field_names = {field.alias for field in get_fields(cls).values()}
extra = values.get("model_kwargs", {})
for field_name in list(values):

View File

@ -6,6 +6,7 @@ from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.retrievers import BaseRetriever
from langchain_core.utils.pydantic import get_fields
from langchain_community.vectorstores.utils import maximal_marginal_relevance
@ -182,7 +183,7 @@ class DocArrayRetriever(BaseRetriever):
ValueError: If the document doesn't contain the content field
"""
fields = doc.keys() if isinstance(doc, dict) else doc.__fields__
fields = doc.keys() if isinstance(doc, dict) else get_fields(doc)
if self.content_field not in fields:
raise ValueError(

View File

@ -8,6 +8,7 @@ from typing import Any, Dict, List, Optional, Union
import numpy as np
import yaml
from langchain_core.pydantic_v1 import BaseModel, Field, validator
from langchain_core.utils.pydantic import get_fields
from typing_extensions import TYPE_CHECKING, Literal
from langchain_community.vectorstores.redis.constants import REDIS_VECTOR_DTYPE_MAP
@ -255,7 +256,7 @@ class RedisModel(BaseModel):
if self.is_empty:
return redis_fields
for field_name in self.__fields__.keys():
for field_name in get_fields(self).keys():
if field_name not in ["content_key", "content_vector_key", "extra"]:
field_group = getattr(self, field_name)
if field_group is not None:
@ -269,7 +270,7 @@ class RedisModel(BaseModel):
if self.is_empty:
return keys
for field_name in self.__fields__.keys():
for field_name in get_fields(self).keys():
field_group = getattr(self, field_name)
if field_group is not None:
for field in field_group:

View File

@ -1,6 +1,7 @@
"""Utils for LLM Tests."""
from langchain_core.language_models.llms import BaseLLM
from langchain_core.utils.pydantic import get_fields
def assert_llm_equality(llm: BaseLLM, loaded_llm: BaseLLM) -> None:
@ -9,7 +10,7 @@ def assert_llm_equality(llm: BaseLLM, loaded_llm: BaseLLM) -> None:
assert type(llm) is type(loaded_llm)
# Client field can be session based, so hash is different despite
# all other values being the same, so just assess all other fields
for field in llm.__fields__.keys():
for field in get_fields(llm).keys():
if field != "client" and field != "pipeline":
val = getattr(llm, field)
new_val = getattr(loaded_llm, field)

View File

@ -6,6 +6,7 @@ import pytest
from langchain_core.callbacks.manager import CallbackManager, trace_as_chain_group
from langchain_core.outputs import LLMResult
from langchain_core.tracers.langchain import LangChainTracer, wait_for_all_tracers
from langchain_core.utils.pydantic import get_fields
from langchain_community.callbacks import get_openai_callback
from langchain_community.callbacks.manager import get_bedrock_anthropic_callback
@ -44,7 +45,7 @@ def test_callback_manager_configure_context_vars(
"completion_tokens": 1,
"total_tokens": 3,
},
"model_name": BaseOpenAI.__fields__["model_name"].default,
"model_name": get_fields(BaseOpenAI)["model_name"].default,
},
)
mngr.on_llm_start({}, ["prompt"])[0].on_llm_end(response)
@ -74,7 +75,7 @@ def test_callback_manager_configure_context_vars(
"completion_tokens": 1,
"total_tokens": 3,
},
"model_name": BaseOpenAI.__fields__["model_name"].default,
"model_name": get_fields(BaseOpenAI)["model_name"].default,
},
)
mngr.on_llm_start({}, ["prompt"])[0].on_llm_end(response)

View File

@ -4,6 +4,7 @@ from uuid import uuid4
import numpy as np
import pytest
from langchain_core.outputs import LLMResult
from langchain_core.utils.pydantic import get_fields
from langchain_community.callbacks import OpenAICallbackHandler
from langchain_community.llms.openai import BaseOpenAI
@ -23,7 +24,7 @@ def test_on_llm_end(handler: OpenAICallbackHandler) -> None:
"completion_tokens": 1,
"total_tokens": 3,
},
"model_name": BaseOpenAI.__fields__["model_name"].default,
"model_name": get_fields(BaseOpenAI)["model_name"].default,
},
)
handler.on_llm_end(response)

View File

@ -1,6 +1,7 @@
from typing import List, Type
from langchain_core.tools import BaseTool, StructuredTool
from langchain_core.utils.pydantic import get_fields
import langchain_community.tools
from langchain_community.tools import _DEPRECATED_TOOLS
@ -22,7 +23,7 @@ def _get_tool_classes(skip_tools_without_default_names: bool) -> List[Type[BaseT
if isinstance(tool_class, type) and issubclass(tool_class, BaseTool):
if tool_class in _EXCLUDE:
continue
if skip_tools_without_default_names and tool_class.__fields__[
if skip_tools_without_default_names and get_fields(tool_class)[
"name"
].default in [ # type: ignore
None,
@ -36,6 +37,6 @@ def _get_tool_classes(skip_tools_without_default_names: bool) -> List[Type[BaseT
def test_tool_names_unique() -> None:
"""Test that the default names for our core tools are unique."""
tool_classes = _get_tool_classes(skip_tools_without_default_names=True)
names = sorted([tool_cls.__fields__["name"].default for tool_cls in tool_classes])
names = sorted([get_fields(tool_cls)["name"].default for tool_cls in tool_classes])
duplicated_names = [name for name in names if names.count(name) > 1]
assert not duplicated_names