Mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-20 13:54:48 +00:00)
community[patch]: Use get_fields adapter for pydantic (#25191)
Change all usages of `__fields__` to the `get_fields` adapter merged into langchain_core.

Code mod generated using the following grit pattern:

```
engine marzano(0.1)
language python

`$X.__fields__` => `get_fields($X)` where {
    add_import(source="langchain_core.utils.pydantic", name="get_fields")
}
```
This commit is contained in:
parent 663638d6a8
commit 98779797fe
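For readers unfamiliar with the adapter: `get_fields` exists so call sites do not need to care whether a model is a pydantic v1-style class (fields exposed on `__fields__`, as with `langchain_core.pydantic_v1` models) or a pydantic v2 class (fields exposed on `model_fields`). A minimal sketch of that kind of adapter, written here for illustration rather than copied from `langchain_core`:

```python
from typing import Any, Dict


def get_fields(model: Any) -> Dict[str, Any]:
    """Return the field map of a pydantic model class or instance.

    Sketch only: pydantic v2 models expose ``model_fields``, while v1-style
    models expose ``__fields__``; the adapter hides that difference.
    """
    if hasattr(model, "model_fields"):  # pydantic v2
        return model.model_fields
    if hasattr(model, "__fields__"):  # pydantic v1 / langchain_core.pydantic_v1
        return model.__fields__
    raise TypeError(f"Expected a pydantic model, got {type(model)!r}")
```

With such a shim in place, the grit rule above is a mechanical rewrite: every `$X.__fields__` becomes `get_fields($X)` plus the matching import, as the hunks below show.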
@@ -4,6 +4,7 @@ from typing import Dict, List, Optional, Type
 from langchain_core.pydantic_v1 import root_validator
 from langchain_core.tools import BaseToolkit
+from langchain_core.utils.pydantic import get_fields
 
 from langchain_community.tools import BaseTool
 from langchain_community.tools.file_management.copy import CopyFileTool
@@ -24,7 +25,7 @@ _FILE_TOOLS: List[Type[BaseTool]] = [
     ListDirectoryTool,
 ]
 _FILE_TOOLS_MAP: Dict[str, Type[BaseTool]] = {
-    tool_cls.__fields__["name"].default: tool_cls for tool_cls in _FILE_TOOLS
+    get_fields(tool_cls)["name"].default: tool_cls for tool_cls in _FILE_TOOLS
 }
 
 
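The `_FILE_TOOLS_MAP` hunk above is representative of the simplest call sites: a tool's default `name` is now read through the adapter instead of touching `__fields__` directly. A small check of the equivalence on a v1-style model (the `FooTool` class below is a hypothetical stand-in, not part of the diff):

```python
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.utils.pydantic import get_fields


class FooTool(BaseModel):
    """Hypothetical stand-in for a BaseTool subclass with a default name."""

    name: str = Field(default="foo_tool")


# The old and new spellings resolve to the same default on v1-style models.
assert FooTool.__fields__["name"].default == "foo_tool"
assert get_fields(FooTool)["name"].default == "foo_tool"
```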
@@ -51,7 +51,7 @@ from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
 from langchain_core.tools import BaseTool
 from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
 from langchain_core.utils.function_calling import convert_to_openai_tool
-from langchain_core.utils.pydantic import is_basemodel_subclass
+from langchain_core.utils.pydantic import get_fields, is_basemodel_subclass
 
 logger = logging.getLogger(__name__)
 
@@ -397,7 +397,7 @@ class QianfanChatEndpoint(BaseChatModel):
 
         default_values = {
             name: field.default
-            for name, field in cls.__fields__.items()
+            for name, field in get_fields(cls).items()
             if field.default is not None
         }
         default_values.update(values)
@@ -49,6 +49,7 @@ from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
 from langchain_core.tools import BaseTool
 from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
 from langchain_core.utils.function_calling import convert_to_openai_tool
+from langchain_core.utils.pydantic import get_fields
 
 logger = logging.getLogger(__name__)
 
@@ -399,7 +400,7 @@ class MiniMaxChat(BaseChatModel):
 
         default_values = {
             name: field.default
-            for name, field in cls.__fields__.items()
+            for name, field in get_fields(cls).items()
             if field.default is not None
         }
         default_values.update(values)
@@ -40,6 +40,7 @@ from langchain_core.utils import (
     get_from_dict_or_env,
     get_pydantic_field_names,
 )
+from langchain_core.utils.pydantic import get_fields
 
 logger = logging.getLogger(__name__)
 
@@ -308,7 +309,7 @@ class ChatSparkLLM(BaseChatModel):
         # put extra params into model_kwargs
         default_values = {
             name: field.default
-            for name, field in cls.__fields__.items()
+            for name, field in get_fields(cls).items()
             if field.default is not None
         }
         values["model_kwargs"]["temperature"] = default_values.get("temperature")
@@ -7,6 +7,7 @@ from typing import Any, Dict, List, Optional
 from langchain_core.embeddings import Embeddings
 from langchain_core.pydantic_v1 import BaseModel
 from langchain_core.utils import pre_init
+from langchain_core.utils.pydantic import get_fields
 
 logger = logging.getLogger(__name__)
 
@@ -88,7 +89,7 @@ class GigaChatEmbeddings(BaseModel, Embeddings):
                 "Could not import gigachat python package. "
                 "Please install it with `pip install gigachat`."
             )
-        fields = set(cls.__fields__.keys())
+        fields = set(get_fields(cls).keys())
         diff = set(values.keys()) - fields
         if diff:
             logger.warning(f"Extra fields {diff} in GigaChat class")
@@ -5,6 +5,7 @@ from langchain_core.callbacks import CallbackManagerForLLMRun
 from langchain_core.language_models.llms import LLM
 from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
 from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
+from langchain_core.utils.pydantic import get_fields
 
 from langchain_community.llms.utils import enforce_stop_tokens
 
@@ -46,7 +47,7 @@ class Banana(LLM):
     @root_validator(pre=True)
     def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         """Build extra kwargs from additional params that were passed in."""
-        all_required_field_names = {field.alias for field in cls.__fields__.values()}
+        all_required_field_names = {field.alias for field in get_fields(cls).values()}
 
         extra = values.get("model_kwargs", {})
         for field_name in list(values):
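The `build_extra` hunks in this commit (Banana above, and Beam, CerebriumAI, EdenAI, GooseAI, Modal, OpenAIChat, Petals, PipelineAI, Replicate, StochasticAI below) all apply the same one-line change inside the same validator pattern: collect the declared field aliases via the adapter, then divert any unrecognized keyword into an extra-kwargs field (`model_kwargs`, `pipeline_kwargs`, or `input`, depending on the class). A condensed sketch of that pattern, simplified for illustration rather than copied from any one of these classes:

```python
from typing import Any, Dict

from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain_core.utils.pydantic import get_fields


class ExampleLLM(BaseModel):
    """Simplified stand-in for the LLM classes touched in this commit."""

    temperature: float = 0.7
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)

    @root_validator(pre=True)
    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Move params that are not declared fields into ``model_kwargs``."""
        all_required_field_names = {field.alias for field in get_fields(cls).values()}
        extra = values.get("model_kwargs", {})
        for field_name in list(values):
            if field_name not in all_required_field_names:
                extra[field_name] = values.pop(field_name)
        values["model_kwargs"] = extra
        return values
```

Constructing `ExampleLLM(temperature=0.2, foo="bar")` then leaves `foo` under `model_kwargs` instead of discarding it, exactly as before the rewrite.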
@@ -11,6 +11,7 @@ from langchain_core.callbacks import CallbackManagerForLLMRun
 from langchain_core.language_models.llms import LLM
 from langchain_core.pydantic_v1 import Field, root_validator
 from langchain_core.utils import get_from_dict_or_env, pre_init
+from langchain_core.utils.pydantic import get_fields
 
 logger = logging.getLogger(__name__)
 
@@ -78,7 +79,7 @@ class Beam(LLM):
     @root_validator(pre=True)
     def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         """Build extra kwargs from additional params that were passed in."""
-        all_required_field_names = {field.alias for field in cls.__fields__.values()}
+        all_required_field_names = {field.alias for field in get_fields(cls).values()}
 
         extra = values.get("model_kwargs", {})
         for field_name in list(values):
@@ -6,6 +6,7 @@ from langchain_core.callbacks import CallbackManagerForLLMRun
 from langchain_core.language_models.llms import LLM
 from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
 from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
+from langchain_core.utils.pydantic import get_fields
 
 from langchain_community.llms.utils import enforce_stop_tokens
 
@@ -45,7 +46,7 @@ class CerebriumAI(LLM):
     @root_validator(pre=True)
     def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         """Build extra kwargs from additional params that were passed in."""
-        all_required_field_names = {field.alias for field in cls.__fields__.values()}
+        all_required_field_names = {field.alias for field in get_fields(cls).values()}
 
         extra = values.get("model_kwargs", {})
         for field_name in list(values):
@@ -11,6 +11,7 @@ from langchain_core.callbacks import (
 from langchain_core.language_models.llms import LLM
 from langchain_core.pydantic_v1 import Field, root_validator
 from langchain_core.utils import get_from_dict_or_env, pre_init
+from langchain_core.utils.pydantic import get_fields
 
 from langchain_community.llms.utils import enforce_stop_tokens
 from langchain_community.utilities.requests import Requests
@@ -82,7 +83,7 @@ class EdenAI(LLM):
     @root_validator(pre=True)
     def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         """Build extra kwargs from additional params that were passed in."""
-        all_required_field_names = {field.alias for field in cls.__fields__.values()}
+        all_required_field_names = {field.alias for field in get_fields(cls).values()}
 
         extra = values.get("model_kwargs", {})
         for field_name in list(values):
@@ -12,6 +12,7 @@ from langchain_core.language_models.llms import BaseLLM
 from langchain_core.load.serializable import Serializable
 from langchain_core.outputs import Generation, GenerationChunk, LLMResult
 from langchain_core.utils import pre_init
+from langchain_core.utils.pydantic import get_fields
 
 if TYPE_CHECKING:
     import gigachat
@@ -123,7 +124,7 @@ class _BaseGigaChat(Serializable):
                 "Could not import gigachat python package. "
                 "Please install it with `pip install gigachat`."
             )
-        fields = set(cls.__fields__.keys())
+        fields = set(get_fields(cls).keys())
         diff = set(values.keys()) - fields
         if diff:
             logger.warning(f"Extra fields {diff} in GigaChat class")
@@ -5,6 +5,7 @@ from langchain_core.callbacks import CallbackManagerForLLMRun
 from langchain_core.language_models.llms import LLM
 from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
 from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
+from langchain_core.utils.pydantic import get_fields
 
 logger = logging.getLogger(__name__)
 
@@ -68,7 +69,7 @@ class GooseAI(LLM):
     @root_validator(pre=True)
     def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         """Build extra kwargs from additional params that were passed in."""
-        all_required_field_names = {field.alias for field in cls.__fields__.values()}
+        all_required_field_names = {field.alias for field in get_fields(cls).values()}
 
         extra = values.get("model_kwargs", {})
         for field_name in list(values):
@@ -6,6 +6,7 @@ from typing import Any, Union
 
 import yaml
 from langchain_core.language_models.llms import BaseLLM
+from langchain_core.utils.pydantic import get_fields
 
 from langchain_community.llms import get_type_to_cls_dict
 
@@ -26,7 +27,7 @@ def load_llm_from_config(config: dict, **kwargs: Any) -> BaseLLM:
     llm_cls = type_to_cls_dict[config_type]()
 
     load_kwargs = {}
-    if _ALLOW_DANGEROUS_DESERIALIZATION_ARG in llm_cls.__fields__:
+    if _ALLOW_DANGEROUS_DESERIALIZATION_ARG in get_fields(llm_cls):
         load_kwargs[_ALLOW_DANGEROUS_DESERIALIZATION_ARG] = kwargs.get(
             _ALLOW_DANGEROUS_DESERIALIZATION_ARG, False
         )
@@ -5,6 +5,7 @@ import requests
 from langchain_core.callbacks import CallbackManagerForLLMRun
 from langchain_core.language_models.llms import LLM
 from langchain_core.pydantic_v1 import Field, root_validator
+from langchain_core.utils.pydantic import get_fields
 
 from langchain_community.llms.utils import enforce_stop_tokens
 
@@ -40,7 +41,7 @@ class Modal(LLM):
     @root_validator(pre=True)
     def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         """Build extra kwargs from additional params that were passed in."""
-        all_required_field_names = {field.alias for field in cls.__fields__.values()}
+        all_required_field_names = {field.alias for field in get_fields(cls).values()}
 
         extra = values.get("model_kwargs", {})
         for field_name in list(values):
@@ -34,6 +34,7 @@ from langchain_core.utils import (
     get_pydantic_field_names,
     pre_init,
 )
+from langchain_core.utils.pydantic import get_fields
 from langchain_core.utils.utils import build_extra_kwargs
 
 from langchain_community.utils.openai import is_openai_v1
@@ -1016,7 +1017,7 @@ class OpenAIChat(BaseLLM):
     @root_validator(pre=True)
     def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         """Build extra kwargs from additional params that were passed in."""
-        all_required_field_names = {field.alias for field in cls.__fields__.values()}
+        all_required_field_names = {field.alias for field in get_fields(cls).values()}
 
         extra = values.get("model_kwargs", {})
         for field_name in list(values):
@@ -5,6 +5,7 @@ from langchain_core.callbacks import CallbackManagerForLLMRun
 from langchain_core.language_models.llms import LLM
 from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
 from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
+from langchain_core.utils.pydantic import get_fields
 
 from langchain_community.llms.utils import enforce_stop_tokens
 
@@ -68,7 +69,7 @@ class Petals(LLM):
     @root_validator(pre=True)
     def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         """Build extra kwargs from additional params that were passed in."""
-        all_required_field_names = {field.alias for field in cls.__fields__.values()}
+        all_required_field_names = {field.alias for field in get_fields(cls).values()}
 
         extra = values.get("model_kwargs", {})
         for field_name in list(values):
@@ -10,6 +10,7 @@ from langchain_core.pydantic_v1 import (
     root_validator,
 )
 from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
+from langchain_core.utils.pydantic import get_fields
 
 from langchain_community.llms.utils import enforce_stop_tokens
 
@@ -47,7 +48,7 @@ class PipelineAI(LLM, BaseModel):
     @root_validator(pre=True)
     def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         """Build extra kwargs from additional params that were passed in."""
-        all_required_field_names = {field.alias for field in cls.__fields__.values()}
+        all_required_field_names = {field.alias for field in get_fields(cls).values()}
 
         extra = values.get("pipeline_kwargs", {})
         for field_name in list(values):
@@ -8,6 +8,7 @@ from langchain_core.language_models.llms import LLM
 from langchain_core.outputs import GenerationChunk
 from langchain_core.pydantic_v1 import Field, root_validator
 from langchain_core.utils import get_from_dict_or_env, pre_init
+from langchain_core.utils.pydantic import get_fields
 
 if TYPE_CHECKING:
     from replicate.prediction import Prediction
@@ -75,7 +76,7 @@ class Replicate(LLM):
     @root_validator(pre=True)
     def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         """Build extra kwargs from additional params that were passed in."""
-        all_required_field_names = {field.alias for field in cls.__fields__.values()}
+        all_required_field_names = {field.alias for field in get_fields(cls).values()}
 
         input = values.pop("input", {})
         if input:
@@ -7,6 +7,7 @@ from langchain_core.callbacks import CallbackManagerForLLMRun
 from langchain_core.language_models.llms import LLM
 from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
 from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
+from langchain_core.utils.pydantic import get_fields
 
 from langchain_community.llms.utils import enforce_stop_tokens
 
@@ -41,7 +42,7 @@ class StochasticAI(LLM):
     @root_validator(pre=True)
     def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         """Build extra kwargs from additional params that were passed in."""
-        all_required_field_names = {field.alias for field in cls.__fields__.values()}
+        all_required_field_names = {field.alias for field in get_fields(cls).values()}
 
         extra = values.get("model_kwargs", {})
         for field_name in list(values):
@@ -6,6 +6,7 @@ from langchain_core.callbacks import CallbackManagerForRetrieverRun
 from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings
 from langchain_core.retrievers import BaseRetriever
+from langchain_core.utils.pydantic import get_fields
 
 from langchain_community.vectorstores.utils import maximal_marginal_relevance
 
@@ -182,7 +183,7 @@ class DocArrayRetriever(BaseRetriever):
             ValueError: If the document doesn't contain the content field
         """
 
-        fields = doc.keys() if isinstance(doc, dict) else doc.__fields__
+        fields = doc.keys() if isinstance(doc, dict) else get_fields(doc)
 
         if self.content_field not in fields:
             raise ValueError(
@@ -8,6 +8,7 @@ from typing import Any, Dict, List, Optional, Union
 import numpy as np
 import yaml
 from langchain_core.pydantic_v1 import BaseModel, Field, validator
+from langchain_core.utils.pydantic import get_fields
 from typing_extensions import TYPE_CHECKING, Literal
 
 from langchain_community.vectorstores.redis.constants import REDIS_VECTOR_DTYPE_MAP
@@ -255,7 +256,7 @@ class RedisModel(BaseModel):
         if self.is_empty:
             return redis_fields
 
-        for field_name in self.__fields__.keys():
+        for field_name in get_fields(self).keys():
             if field_name not in ["content_key", "content_vector_key", "extra"]:
                 field_group = getattr(self, field_name)
                 if field_group is not None:
@@ -269,7 +270,7 @@ class RedisModel(BaseModel):
         if self.is_empty:
             return keys
 
-        for field_name in self.__fields__.keys():
+        for field_name in get_fields(self).keys():
             field_group = getattr(self, field_name)
             if field_group is not None:
                 for field in field_group:
@@ -1,6 +1,7 @@
 """Utils for LLM Tests."""
 
 from langchain_core.language_models.llms import BaseLLM
+from langchain_core.utils.pydantic import get_fields
 
 
 def assert_llm_equality(llm: BaseLLM, loaded_llm: BaseLLM) -> None:
@@ -9,7 +10,7 @@ def assert_llm_equality(llm: BaseLLM, loaded_llm: BaseLLM) -> None:
     assert type(llm) is type(loaded_llm)
     # Client field can be session based, so hash is different despite
     # all other values being the same, so just assess all other fields
-    for field in llm.__fields__.keys():
+    for field in get_fields(llm).keys():
         if field != "client" and field != "pipeline":
            val = getattr(llm, field)
            new_val = getattr(loaded_llm, field)
@@ -6,6 +6,7 @@ import pytest
 from langchain_core.callbacks.manager import CallbackManager, trace_as_chain_group
 from langchain_core.outputs import LLMResult
 from langchain_core.tracers.langchain import LangChainTracer, wait_for_all_tracers
+from langchain_core.utils.pydantic import get_fields
 
 from langchain_community.callbacks import get_openai_callback
 from langchain_community.callbacks.manager import get_bedrock_anthropic_callback
@@ -44,7 +45,7 @@ def test_callback_manager_configure_context_vars(
                         "completion_tokens": 1,
                         "total_tokens": 3,
                     },
-                    "model_name": BaseOpenAI.__fields__["model_name"].default,
+                    "model_name": get_fields(BaseOpenAI)["model_name"].default,
                 },
             )
             mngr.on_llm_start({}, ["prompt"])[0].on_llm_end(response)
@@ -74,7 +75,7 @@ def test_callback_manager_configure_context_vars(
                         "completion_tokens": 1,
                         "total_tokens": 3,
                     },
-                    "model_name": BaseOpenAI.__fields__["model_name"].default,
+                    "model_name": get_fields(BaseOpenAI)["model_name"].default,
                 },
             )
             mngr.on_llm_start({}, ["prompt"])[0].on_llm_end(response)
@@ -4,6 +4,7 @@ from uuid import uuid4
 import numpy as np
 import pytest
 from langchain_core.outputs import LLMResult
+from langchain_core.utils.pydantic import get_fields
 
 from langchain_community.callbacks import OpenAICallbackHandler
 from langchain_community.llms.openai import BaseOpenAI
@@ -23,7 +24,7 @@ def test_on_llm_end(handler: OpenAICallbackHandler) -> None:
                 "completion_tokens": 1,
                 "total_tokens": 3,
             },
-            "model_name": BaseOpenAI.__fields__["model_name"].default,
+            "model_name": get_fields(BaseOpenAI)["model_name"].default,
         },
     )
     handler.on_llm_end(response)
@@ -1,6 +1,7 @@
 from typing import List, Type
 
 from langchain_core.tools import BaseTool, StructuredTool
+from langchain_core.utils.pydantic import get_fields
 
 import langchain_community.tools
 from langchain_community.tools import _DEPRECATED_TOOLS
@@ -22,7 +23,7 @@ def _get_tool_classes(skip_tools_without_default_names: bool) -> List[Type[BaseT
         if isinstance(tool_class, type) and issubclass(tool_class, BaseTool):
             if tool_class in _EXCLUDE:
                 continue
-            if skip_tools_without_default_names and tool_class.__fields__[
+            if skip_tools_without_default_names and get_fields(tool_class)[
                 "name"
             ].default in [  # type: ignore
                 None,
@@ -36,6 +37,6 @@ def _get_tool_classes(skip_tools_without_default_names: bool) -> List[Type[BaseT
 def test_tool_names_unique() -> None:
     """Test that the default names for our core tools are unique."""
     tool_classes = _get_tool_classes(skip_tools_without_default_names=True)
-    names = sorted([tool_cls.__fields__["name"].default for tool_cls in tool_classes])
+    names = sorted([get_fields(tool_cls)["name"].default for tool_cls in tool_classes])
     duplicated_names = [name for name in names if names.count(name) > 1]
     assert not duplicated_names