mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-19 13:23:35 +00:00
community[patch]: Use get_fields adapter for pydantic (#25191)
Change all usages of __fields__ with get_fields adapter merged into langchain_core. Code mod generated using the following grit pattern: ``` engine marzano(0.1) language python `$X.__fields__` => `get_fields($X)` where { add_import(source="langchain_core.utils.pydantic", name="get_fields") } ```
This commit is contained in:
parent
663638d6a8
commit
98779797fe
@ -4,6 +4,7 @@ from typing import Dict, List, Optional, Type
|
||||
|
||||
from langchain_core.pydantic_v1 import root_validator
|
||||
from langchain_core.tools import BaseToolkit
|
||||
from langchain_core.utils.pydantic import get_fields
|
||||
|
||||
from langchain_community.tools import BaseTool
|
||||
from langchain_community.tools.file_management.copy import CopyFileTool
|
||||
@ -24,7 +25,7 @@ _FILE_TOOLS: List[Type[BaseTool]] = [
|
||||
ListDirectoryTool,
|
||||
]
|
||||
_FILE_TOOLS_MAP: Dict[str, Type[BaseTool]] = {
|
||||
tool_cls.__fields__["name"].default: tool_cls for tool_cls in _FILE_TOOLS
|
||||
get_fields(tool_cls)["name"].default: tool_cls for tool_cls in _FILE_TOOLS
|
||||
}
|
||||
|
||||
|
||||
|
@ -51,7 +51,7 @@ from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
from langchain_core.utils.pydantic import get_fields, is_basemodel_subclass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -397,7 +397,7 @@ class QianfanChatEndpoint(BaseChatModel):
|
||||
|
||||
default_values = {
|
||||
name: field.default
|
||||
for name, field in cls.__fields__.items()
|
||||
for name, field in get_fields(cls).items()
|
||||
if field.default is not None
|
||||
}
|
||||
default_values.update(values)
|
||||
|
@ -49,6 +49,7 @@ from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
from langchain_core.utils.pydantic import get_fields
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -399,7 +400,7 @@ class MiniMaxChat(BaseChatModel):
|
||||
|
||||
default_values = {
|
||||
name: field.default
|
||||
for name, field in cls.__fields__.items()
|
||||
for name, field in get_fields(cls).items()
|
||||
if field.default is not None
|
||||
}
|
||||
default_values.update(values)
|
||||
|
@ -40,6 +40,7 @@ from langchain_core.utils import (
|
||||
get_from_dict_or_env,
|
||||
get_pydantic_field_names,
|
||||
)
|
||||
from langchain_core.utils.pydantic import get_fields
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -308,7 +309,7 @@ class ChatSparkLLM(BaseChatModel):
|
||||
# put extra params into model_kwargs
|
||||
default_values = {
|
||||
name: field.default
|
||||
for name, field in cls.__fields__.items()
|
||||
for name, field in get_fields(cls).items()
|
||||
if field.default is not None
|
||||
}
|
||||
values["model_kwargs"]["temperature"] = default_values.get("temperature")
|
||||
|
@ -7,6 +7,7 @@ from typing import Any, Dict, List, Optional
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.utils import pre_init
|
||||
from langchain_core.utils.pydantic import get_fields
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -88,7 +89,7 @@ class GigaChatEmbeddings(BaseModel, Embeddings):
|
||||
"Could not import gigachat python package. "
|
||||
"Please install it with `pip install gigachat`."
|
||||
)
|
||||
fields = set(cls.__fields__.keys())
|
||||
fields = set(get_fields(cls).keys())
|
||||
diff = set(values.keys()) - fields
|
||||
if diff:
|
||||
logger.warning(f"Extra fields {diff} in GigaChat class")
|
||||
|
@ -5,6 +5,7 @@ from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models.llms import LLM
|
||||
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
|
||||
from langchain_core.utils.pydantic import get_fields
|
||||
|
||||
from langchain_community.llms.utils import enforce_stop_tokens
|
||||
|
||||
@ -46,7 +47,7 @@ class Banana(LLM):
|
||||
@root_validator(pre=True)
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = {field.alias for field in cls.__fields__.values()}
|
||||
all_required_field_names = {field.alias for field in get_fields(cls).values()}
|
||||
|
||||
extra = values.get("model_kwargs", {})
|
||||
for field_name in list(values):
|
||||
|
@ -11,6 +11,7 @@ from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models.llms import LLM
|
||||
from langchain_core.pydantic_v1 import Field, root_validator
|
||||
from langchain_core.utils import get_from_dict_or_env, pre_init
|
||||
from langchain_core.utils.pydantic import get_fields
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -78,7 +79,7 @@ class Beam(LLM):
|
||||
@root_validator(pre=True)
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = {field.alias for field in cls.__fields__.values()}
|
||||
all_required_field_names = {field.alias for field in get_fields(cls).values()}
|
||||
|
||||
extra = values.get("model_kwargs", {})
|
||||
for field_name in list(values):
|
||||
|
@ -6,6 +6,7 @@ from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models.llms import LLM
|
||||
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
|
||||
from langchain_core.utils.pydantic import get_fields
|
||||
|
||||
from langchain_community.llms.utils import enforce_stop_tokens
|
||||
|
||||
@ -45,7 +46,7 @@ class CerebriumAI(LLM):
|
||||
@root_validator(pre=True)
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = {field.alias for field in cls.__fields__.values()}
|
||||
all_required_field_names = {field.alias for field in get_fields(cls).values()}
|
||||
|
||||
extra = values.get("model_kwargs", {})
|
||||
for field_name in list(values):
|
||||
|
@ -11,6 +11,7 @@ from langchain_core.callbacks import (
|
||||
from langchain_core.language_models.llms import LLM
|
||||
from langchain_core.pydantic_v1 import Field, root_validator
|
||||
from langchain_core.utils import get_from_dict_or_env, pre_init
|
||||
from langchain_core.utils.pydantic import get_fields
|
||||
|
||||
from langchain_community.llms.utils import enforce_stop_tokens
|
||||
from langchain_community.utilities.requests import Requests
|
||||
@ -82,7 +83,7 @@ class EdenAI(LLM):
|
||||
@root_validator(pre=True)
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = {field.alias for field in cls.__fields__.values()}
|
||||
all_required_field_names = {field.alias for field in get_fields(cls).values()}
|
||||
|
||||
extra = values.get("model_kwargs", {})
|
||||
for field_name in list(values):
|
||||
|
@ -12,6 +12,7 @@ from langchain_core.language_models.llms import BaseLLM
|
||||
from langchain_core.load.serializable import Serializable
|
||||
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
|
||||
from langchain_core.utils import pre_init
|
||||
from langchain_core.utils.pydantic import get_fields
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import gigachat
|
||||
@ -123,7 +124,7 @@ class _BaseGigaChat(Serializable):
|
||||
"Could not import gigachat python package. "
|
||||
"Please install it with `pip install gigachat`."
|
||||
)
|
||||
fields = set(cls.__fields__.keys())
|
||||
fields = set(get_fields(cls).keys())
|
||||
diff = set(values.keys()) - fields
|
||||
if diff:
|
||||
logger.warning(f"Extra fields {diff} in GigaChat class")
|
||||
|
@ -5,6 +5,7 @@ from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models.llms import LLM
|
||||
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
|
||||
from langchain_core.utils.pydantic import get_fields
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -68,7 +69,7 @@ class GooseAI(LLM):
|
||||
@root_validator(pre=True)
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = {field.alias for field in cls.__fields__.values()}
|
||||
all_required_field_names = {field.alias for field in get_fields(cls).values()}
|
||||
|
||||
extra = values.get("model_kwargs", {})
|
||||
for field_name in list(values):
|
||||
|
@ -6,6 +6,7 @@ from typing import Any, Union
|
||||
|
||||
import yaml
|
||||
from langchain_core.language_models.llms import BaseLLM
|
||||
from langchain_core.utils.pydantic import get_fields
|
||||
|
||||
from langchain_community.llms import get_type_to_cls_dict
|
||||
|
||||
@ -26,7 +27,7 @@ def load_llm_from_config(config: dict, **kwargs: Any) -> BaseLLM:
|
||||
llm_cls = type_to_cls_dict[config_type]()
|
||||
|
||||
load_kwargs = {}
|
||||
if _ALLOW_DANGEROUS_DESERIALIZATION_ARG in llm_cls.__fields__:
|
||||
if _ALLOW_DANGEROUS_DESERIALIZATION_ARG in get_fields(llm_cls):
|
||||
load_kwargs[_ALLOW_DANGEROUS_DESERIALIZATION_ARG] = kwargs.get(
|
||||
_ALLOW_DANGEROUS_DESERIALIZATION_ARG, False
|
||||
)
|
||||
|
@ -5,6 +5,7 @@ import requests
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models.llms import LLM
|
||||
from langchain_core.pydantic_v1 import Field, root_validator
|
||||
from langchain_core.utils.pydantic import get_fields
|
||||
|
||||
from langchain_community.llms.utils import enforce_stop_tokens
|
||||
|
||||
@ -40,7 +41,7 @@ class Modal(LLM):
|
||||
@root_validator(pre=True)
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = {field.alias for field in cls.__fields__.values()}
|
||||
all_required_field_names = {field.alias for field in get_fields(cls).values()}
|
||||
|
||||
extra = values.get("model_kwargs", {})
|
||||
for field_name in list(values):
|
||||
|
@ -34,6 +34,7 @@ from langchain_core.utils import (
|
||||
get_pydantic_field_names,
|
||||
pre_init,
|
||||
)
|
||||
from langchain_core.utils.pydantic import get_fields
|
||||
from langchain_core.utils.utils import build_extra_kwargs
|
||||
|
||||
from langchain_community.utils.openai import is_openai_v1
|
||||
@ -1016,7 +1017,7 @@ class OpenAIChat(BaseLLM):
|
||||
@root_validator(pre=True)
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = {field.alias for field in cls.__fields__.values()}
|
||||
all_required_field_names = {field.alias for field in get_fields(cls).values()}
|
||||
|
||||
extra = values.get("model_kwargs", {})
|
||||
for field_name in list(values):
|
||||
|
@ -5,6 +5,7 @@ from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models.llms import LLM
|
||||
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
|
||||
from langchain_core.utils.pydantic import get_fields
|
||||
|
||||
from langchain_community.llms.utils import enforce_stop_tokens
|
||||
|
||||
@ -68,7 +69,7 @@ class Petals(LLM):
|
||||
@root_validator(pre=True)
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = {field.alias for field in cls.__fields__.values()}
|
||||
all_required_field_names = {field.alias for field in get_fields(cls).values()}
|
||||
|
||||
extra = values.get("model_kwargs", {})
|
||||
for field_name in list(values):
|
||||
|
@ -10,6 +10,7 @@ from langchain_core.pydantic_v1 import (
|
||||
root_validator,
|
||||
)
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
|
||||
from langchain_core.utils.pydantic import get_fields
|
||||
|
||||
from langchain_community.llms.utils import enforce_stop_tokens
|
||||
|
||||
@ -47,7 +48,7 @@ class PipelineAI(LLM, BaseModel):
|
||||
@root_validator(pre=True)
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = {field.alias for field in cls.__fields__.values()}
|
||||
all_required_field_names = {field.alias for field in get_fields(cls).values()}
|
||||
|
||||
extra = values.get("pipeline_kwargs", {})
|
||||
for field_name in list(values):
|
||||
|
@ -8,6 +8,7 @@ from langchain_core.language_models.llms import LLM
|
||||
from langchain_core.outputs import GenerationChunk
|
||||
from langchain_core.pydantic_v1 import Field, root_validator
|
||||
from langchain_core.utils import get_from_dict_or_env, pre_init
|
||||
from langchain_core.utils.pydantic import get_fields
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from replicate.prediction import Prediction
|
||||
@ -75,7 +76,7 @@ class Replicate(LLM):
|
||||
@root_validator(pre=True)
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = {field.alias for field in cls.__fields__.values()}
|
||||
all_required_field_names = {field.alias for field in get_fields(cls).values()}
|
||||
|
||||
input = values.pop("input", {})
|
||||
if input:
|
||||
|
@ -7,6 +7,7 @@ from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models.llms import LLM
|
||||
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
|
||||
from langchain_core.utils.pydantic import get_fields
|
||||
|
||||
from langchain_community.llms.utils import enforce_stop_tokens
|
||||
|
||||
@ -41,7 +42,7 @@ class StochasticAI(LLM):
|
||||
@root_validator(pre=True)
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = {field.alias for field in cls.__fields__.values()}
|
||||
all_required_field_names = {field.alias for field in get_fields(cls).values()}
|
||||
|
||||
extra = values.get("model_kwargs", {})
|
||||
for field_name in list(values):
|
||||
|
@ -6,6 +6,7 @@ from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langchain_core.utils.pydantic import get_fields
|
||||
|
||||
from langchain_community.vectorstores.utils import maximal_marginal_relevance
|
||||
|
||||
@ -182,7 +183,7 @@ class DocArrayRetriever(BaseRetriever):
|
||||
ValueError: If the document doesn't contain the content field
|
||||
"""
|
||||
|
||||
fields = doc.keys() if isinstance(doc, dict) else doc.__fields__
|
||||
fields = doc.keys() if isinstance(doc, dict) else get_fields(doc)
|
||||
|
||||
if self.content_field not in fields:
|
||||
raise ValueError(
|
||||
|
@ -8,6 +8,7 @@ from typing import Any, Dict, List, Optional, Union
|
||||
import numpy as np
|
||||
import yaml
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, validator
|
||||
from langchain_core.utils.pydantic import get_fields
|
||||
from typing_extensions import TYPE_CHECKING, Literal
|
||||
|
||||
from langchain_community.vectorstores.redis.constants import REDIS_VECTOR_DTYPE_MAP
|
||||
@ -255,7 +256,7 @@ class RedisModel(BaseModel):
|
||||
if self.is_empty:
|
||||
return redis_fields
|
||||
|
||||
for field_name in self.__fields__.keys():
|
||||
for field_name in get_fields(self).keys():
|
||||
if field_name not in ["content_key", "content_vector_key", "extra"]:
|
||||
field_group = getattr(self, field_name)
|
||||
if field_group is not None:
|
||||
@ -269,7 +270,7 @@ class RedisModel(BaseModel):
|
||||
if self.is_empty:
|
||||
return keys
|
||||
|
||||
for field_name in self.__fields__.keys():
|
||||
for field_name in get_fields(self).keys():
|
||||
field_group = getattr(self, field_name)
|
||||
if field_group is not None:
|
||||
for field in field_group:
|
||||
|
@ -1,6 +1,7 @@
|
||||
"""Utils for LLM Tests."""
|
||||
|
||||
from langchain_core.language_models.llms import BaseLLM
|
||||
from langchain_core.utils.pydantic import get_fields
|
||||
|
||||
|
||||
def assert_llm_equality(llm: BaseLLM, loaded_llm: BaseLLM) -> None:
|
||||
@ -9,7 +10,7 @@ def assert_llm_equality(llm: BaseLLM, loaded_llm: BaseLLM) -> None:
|
||||
assert type(llm) is type(loaded_llm)
|
||||
# Client field can be session based, so hash is different despite
|
||||
# all other values being the same, so just assess all other fields
|
||||
for field in llm.__fields__.keys():
|
||||
for field in get_fields(llm).keys():
|
||||
if field != "client" and field != "pipeline":
|
||||
val = getattr(llm, field)
|
||||
new_val = getattr(loaded_llm, field)
|
||||
|
@ -6,6 +6,7 @@ import pytest
|
||||
from langchain_core.callbacks.manager import CallbackManager, trace_as_chain_group
|
||||
from langchain_core.outputs import LLMResult
|
||||
from langchain_core.tracers.langchain import LangChainTracer, wait_for_all_tracers
|
||||
from langchain_core.utils.pydantic import get_fields
|
||||
|
||||
from langchain_community.callbacks import get_openai_callback
|
||||
from langchain_community.callbacks.manager import get_bedrock_anthropic_callback
|
||||
@ -44,7 +45,7 @@ def test_callback_manager_configure_context_vars(
|
||||
"completion_tokens": 1,
|
||||
"total_tokens": 3,
|
||||
},
|
||||
"model_name": BaseOpenAI.__fields__["model_name"].default,
|
||||
"model_name": get_fields(BaseOpenAI)["model_name"].default,
|
||||
},
|
||||
)
|
||||
mngr.on_llm_start({}, ["prompt"])[0].on_llm_end(response)
|
||||
@ -74,7 +75,7 @@ def test_callback_manager_configure_context_vars(
|
||||
"completion_tokens": 1,
|
||||
"total_tokens": 3,
|
||||
},
|
||||
"model_name": BaseOpenAI.__fields__["model_name"].default,
|
||||
"model_name": get_fields(BaseOpenAI)["model_name"].default,
|
||||
},
|
||||
)
|
||||
mngr.on_llm_start({}, ["prompt"])[0].on_llm_end(response)
|
||||
|
@ -4,6 +4,7 @@ from uuid import uuid4
|
||||
import numpy as np
|
||||
import pytest
|
||||
from langchain_core.outputs import LLMResult
|
||||
from langchain_core.utils.pydantic import get_fields
|
||||
|
||||
from langchain_community.callbacks import OpenAICallbackHandler
|
||||
from langchain_community.llms.openai import BaseOpenAI
|
||||
@ -23,7 +24,7 @@ def test_on_llm_end(handler: OpenAICallbackHandler) -> None:
|
||||
"completion_tokens": 1,
|
||||
"total_tokens": 3,
|
||||
},
|
||||
"model_name": BaseOpenAI.__fields__["model_name"].default,
|
||||
"model_name": get_fields(BaseOpenAI)["model_name"].default,
|
||||
},
|
||||
)
|
||||
handler.on_llm_end(response)
|
||||
|
@ -1,6 +1,7 @@
|
||||
from typing import List, Type
|
||||
|
||||
from langchain_core.tools import BaseTool, StructuredTool
|
||||
from langchain_core.utils.pydantic import get_fields
|
||||
|
||||
import langchain_community.tools
|
||||
from langchain_community.tools import _DEPRECATED_TOOLS
|
||||
@ -22,7 +23,7 @@ def _get_tool_classes(skip_tools_without_default_names: bool) -> List[Type[BaseT
|
||||
if isinstance(tool_class, type) and issubclass(tool_class, BaseTool):
|
||||
if tool_class in _EXCLUDE:
|
||||
continue
|
||||
if skip_tools_without_default_names and tool_class.__fields__[
|
||||
if skip_tools_without_default_names and get_fields(tool_class)[
|
||||
"name"
|
||||
].default in [ # type: ignore
|
||||
None,
|
||||
@ -36,6 +37,6 @@ def _get_tool_classes(skip_tools_without_default_names: bool) -> List[Type[BaseT
|
||||
def test_tool_names_unique() -> None:
|
||||
"""Test that the default names for our core tools are unique."""
|
||||
tool_classes = _get_tool_classes(skip_tools_without_default_names=True)
|
||||
names = sorted([tool_cls.__fields__["name"].default for tool_cls in tool_classes])
|
||||
names = sorted([get_fields(tool_cls)["name"].default for tool_cls in tool_classes])
|
||||
duplicated_names = [name for name in names if names.count(name) > 1]
|
||||
assert not duplicated_names
|
||||
|
Loading…
Reference in New Issue
Block a user