mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-08 20:41:52 +00:00
pydantic initial changes
This commit is contained in:
parent
275e3b6710
commit
4e98c8aa47
@ -15,7 +15,7 @@ test tests:
|
||||
-u LANGCHAIN_API_KEY \
|
||||
-u LANGSMITH_TRACING \
|
||||
-u LANGCHAIN_PROJECT \
|
||||
uv run --group test pytest -n auto --disable-socket --allow-unix-socket $(TEST_FILE)
|
||||
uv run --group test pytest -vv -s -n auto --disable-socket --allow-unix-socket $(TEST_FILE)
|
||||
|
||||
test_watch:
|
||||
env \
|
||||
|
@ -23,6 +23,7 @@ from typing import (
|
||||
cast,
|
||||
)
|
||||
|
||||
from pydantic.fields import FieldInfo
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from langchain_core._api.internal import is_caller_internal
|
||||
@ -39,8 +40,7 @@ class LangChainPendingDeprecationWarning(PendingDeprecationWarning):
|
||||
# PUBLIC API
|
||||
|
||||
|
||||
# Last Any should be FieldInfoV1 but this leads to circular imports
|
||||
T = TypeVar("T", bound=Union[type, Callable[..., Any], Any])
|
||||
T = TypeVar("T", bound=Union[type, Callable[..., Any], FieldInfo])
|
||||
|
||||
|
||||
def _validate_deprecation_params(
|
||||
@ -152,10 +152,6 @@ def deprecated(
|
||||
_package: str = package,
|
||||
) -> T:
|
||||
"""Implementation of the decorator returned by `deprecated`."""
|
||||
from langchain_core.utils.pydantic import ( # type: ignore[attr-defined]
|
||||
FieldInfoV1,
|
||||
FieldInfoV2,
|
||||
)
|
||||
|
||||
def emit_warning() -> None:
|
||||
"""Emit the warning."""
|
||||
@ -228,7 +224,7 @@ def deprecated(
|
||||
)
|
||||
return cast("T", obj)
|
||||
|
||||
elif isinstance(obj, FieldInfoV1):
|
||||
elif isinstance(obj, FieldInfo):
|
||||
wrapped = None
|
||||
if not _obj_type:
|
||||
_obj_type = "attribute"
|
||||
@ -240,28 +236,7 @@ def deprecated(
|
||||
def finalize(wrapper: Callable[..., Any], new_doc: str) -> T: # noqa: ARG001
|
||||
return cast(
|
||||
"T",
|
||||
FieldInfoV1(
|
||||
default=obj.default,
|
||||
default_factory=obj.default_factory,
|
||||
description=new_doc,
|
||||
alias=obj.alias,
|
||||
exclude=obj.exclude,
|
||||
),
|
||||
)
|
||||
|
||||
elif isinstance(obj, FieldInfoV2):
|
||||
wrapped = None
|
||||
if not _obj_type:
|
||||
_obj_type = "attribute"
|
||||
if not _name:
|
||||
msg = f"Field {obj} must have a name to be deprecated."
|
||||
raise ValueError(msg)
|
||||
old_doc = obj.description
|
||||
|
||||
def finalize(wrapper: Callable[..., Any], new_doc: str) -> T: # noqa: ARG001
|
||||
return cast(
|
||||
"T",
|
||||
FieldInfoV2(
|
||||
FieldInfo(
|
||||
default=obj.default,
|
||||
default_factory=obj.default_factory,
|
||||
description=new_doc,
|
||||
|
@ -75,7 +75,6 @@ from langchain_core.utils.function_calling import (
|
||||
convert_to_json_schema,
|
||||
convert_to_openai_tool,
|
||||
)
|
||||
from langchain_core.utils.pydantic import TypeBaseModel, is_basemodel_subclass
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import uuid
|
||||
@ -625,7 +624,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
stop: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> dict:
|
||||
params = self.dict()
|
||||
params = self.model_dump()
|
||||
params["stop"] = stop
|
||||
return {**params, **kwargs}
|
||||
|
||||
@ -1298,7 +1297,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
"""Return type of chat model."""
|
||||
|
||||
@override
|
||||
def dict(self, **kwargs: Any) -> dict:
|
||||
def model_dump(self, **kwargs: Any) -> dict:
|
||||
"""Return a dictionary of the LLM."""
|
||||
starter_dict = dict(self._identifying_params)
|
||||
starter_dict["_type"] = self._llm_type
|
||||
@ -1456,9 +1455,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
"schema": schema,
|
||||
},
|
||||
)
|
||||
if isinstance(schema, type) and is_basemodel_subclass(schema):
|
||||
if isinstance(schema, type) and issubclass(schema, BaseModel):
|
||||
output_parser: OutputParserLike = PydanticToolsParser(
|
||||
tools=[cast("TypeBaseModel", schema)], first_tool_only=True
|
||||
tools=[schema], first_tool_only=True
|
||||
)
|
||||
else:
|
||||
key_name = convert_to_openai_tool(schema)["function"]["name"]
|
||||
|
@ -528,7 +528,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
else:
|
||||
prompt = self._convert_input(input).to_string()
|
||||
config = ensure_config(config)
|
||||
params = self.dict()
|
||||
params = self.model_dump()
|
||||
params["stop"] = stop
|
||||
params = {**params, **kwargs}
|
||||
options = {"stop": stop}
|
||||
@ -598,7 +598,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
|
||||
prompt = self._convert_input(input).to_string()
|
||||
config = ensure_config(config)
|
||||
params = self.dict()
|
||||
params = self.model_dump()
|
||||
params["stop"] = stop
|
||||
params = {**params, **kwargs}
|
||||
options = {"stop": stop}
|
||||
@ -941,7 +941,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
] * len(prompts)
|
||||
run_name_list = [cast("Optional[str]", run_name)] * len(prompts)
|
||||
run_ids_list = self._get_run_ids_list(run_id, prompts)
|
||||
params = self.dict()
|
||||
params = self.model_dump()
|
||||
params["stop"] = stop
|
||||
options = {"stop": stop}
|
||||
(
|
||||
@ -1193,7 +1193,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
] * len(prompts)
|
||||
run_name_list = [cast("Optional[str]", run_name)] * len(prompts)
|
||||
run_ids_list = self._get_run_ids_list(run_id, prompts)
|
||||
params = self.dict()
|
||||
params = self.model_dump()
|
||||
params["stop"] = stop
|
||||
options = {"stop": stop}
|
||||
(
|
||||
@ -1400,7 +1400,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
"""Return type of llm."""
|
||||
|
||||
@override
|
||||
def dict(self, **kwargs: Any) -> dict:
|
||||
def model_dump(self, **kwargs: Any) -> dict:
|
||||
"""Return a dictionary of the LLM."""
|
||||
starter_dict = dict(self._identifying_params)
|
||||
starter_dict["_type"] = self._llm_type
|
||||
@ -1427,7 +1427,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
directory_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Fetch dictionary to save
|
||||
prompt_dict = self.dict()
|
||||
prompt_dict = self.model_dump()
|
||||
|
||||
if save_path.suffix == ".json":
|
||||
with save_path.open("w") as f:
|
||||
|
@ -324,9 +324,9 @@ class BaseOutputParser(
|
||||
)
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
def dict(self, **kwargs: Any) -> dict:
|
||||
def model_dump(self, **kwargs: Any) -> dict:
|
||||
"""Return dictionary representation of output parser."""
|
||||
output_parser_dict = super().dict(**kwargs)
|
||||
output_parser_dict = super().model_dump(**kwargs)
|
||||
with contextlib.suppress(NotImplementedError):
|
||||
output_parser_dict["_type"] = self._type
|
||||
return output_parser_dict
|
||||
|
@ -4,11 +4,10 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
from json import JSONDecodeError
|
||||
from typing import Annotated, Any, Optional, TypeVar, Union
|
||||
from typing import Any, Optional
|
||||
|
||||
import jsonpatch # type: ignore[import-untyped]
|
||||
import pydantic
|
||||
from pydantic import SkipValidation
|
||||
from pydantic import BaseModel, SkipValidation
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.output_parsers.format_instructions import JSON_FORMAT_INSTRUCTIONS
|
||||
@ -19,18 +18,6 @@ from langchain_core.utils.json import (
|
||||
parse_json_markdown,
|
||||
parse_partial_json,
|
||||
)
|
||||
from langchain_core.utils.pydantic import IS_PYDANTIC_V1
|
||||
|
||||
if IS_PYDANTIC_V1:
|
||||
PydanticBaseModel = pydantic.BaseModel
|
||||
|
||||
else:
|
||||
from pydantic.v1 import BaseModel
|
||||
|
||||
# Union type needs to be last assignment to PydanticBaseModel to make mypy happy.
|
||||
PydanticBaseModel = Union[BaseModel, pydantic.BaseModel] # type: ignore[assignment,misc]
|
||||
|
||||
TBaseModel = TypeVar("TBaseModel", bound=PydanticBaseModel)
|
||||
|
||||
|
||||
class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
|
||||
@ -43,18 +30,16 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
|
||||
describing the difference between the previous and the current object.
|
||||
"""
|
||||
|
||||
pydantic_object: Annotated[Optional[type[TBaseModel]], SkipValidation()] = None # type: ignore[valid-type]
|
||||
pydantic_object: SkipValidation[Optional[type[BaseModel]]] = None
|
||||
"""The Pydantic object to use for validation.
|
||||
If None, no validation is performed."""
|
||||
|
||||
def _diff(self, prev: Optional[Any], next: Any) -> Any:
|
||||
return jsonpatch.make_patch(prev, next).patch
|
||||
|
||||
def _get_schema(self, pydantic_object: type[TBaseModel]) -> dict[str, Any]:
|
||||
if issubclass(pydantic_object, pydantic.BaseModel):
|
||||
def _get_schema(self, pydantic_object: type[BaseModel]) -> dict[str, Any]:
|
||||
if issubclass(pydantic_object, BaseModel):
|
||||
return pydantic_object.model_json_schema()
|
||||
if issubclass(pydantic_object, pydantic.v1.BaseModel):
|
||||
return pydantic_object.schema()
|
||||
return None
|
||||
|
||||
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
||||
|
@ -274,10 +274,7 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
|
||||
pydantic_schema = self.pydantic_schema[fn_name]
|
||||
else:
|
||||
pydantic_schema = self.pydantic_schema
|
||||
if hasattr(pydantic_schema, "model_validate_json"):
|
||||
pydantic_args = pydantic_schema.model_validate_json(_args)
|
||||
else:
|
||||
pydantic_args = pydantic_schema.parse_raw(_args)
|
||||
return pydantic_args
|
||||
|
||||
|
||||
|
@ -4,9 +4,9 @@ import copy
|
||||
import json
|
||||
import logging
|
||||
from json import JSONDecodeError
|
||||
from typing import Annotated, Any, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import SkipValidation, ValidationError
|
||||
from pydantic import BaseModel, SkipValidation, ValidationError
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.messages import AIMessage, InvalidToolCall
|
||||
@ -15,7 +15,6 @@ from langchain_core.messages.tool import tool_call as create_tool_call
|
||||
from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser
|
||||
from langchain_core.outputs import ChatGeneration, Generation
|
||||
from langchain_core.utils.json import parse_partial_json
|
||||
from langchain_core.utils.pydantic import TypeBaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -264,7 +263,7 @@ _MAX_TOKENS_ERROR = (
|
||||
class PydanticToolsParser(JsonOutputToolsParser):
|
||||
"""Parse tools from OpenAI response."""
|
||||
|
||||
tools: Annotated[list[TypeBaseModel], SkipValidation()]
|
||||
tools: SkipValidation[list[type[BaseModel]]]
|
||||
"""The tools to parse."""
|
||||
|
||||
# TODO: Support more granular streaming of objects. Currently only streams once all
|
||||
|
@ -1,44 +1,32 @@
|
||||
"""Output parsers using Pydantic."""
|
||||
|
||||
import json
|
||||
from typing import Annotated, Generic, Optional
|
||||
from typing import Generic, Optional, TypeVar
|
||||
|
||||
import pydantic
|
||||
from pydantic import SkipValidation
|
||||
from pydantic import BaseModel, SkipValidation, ValidationError
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.output_parsers import JsonOutputParser
|
||||
from langchain_core.outputs import Generation
|
||||
from langchain_core.utils.pydantic import (
|
||||
IS_PYDANTIC_V2,
|
||||
PydanticBaseModel,
|
||||
TBaseModel,
|
||||
)
|
||||
|
||||
BaseModelT = TypeVar("BaseModelT", bound=BaseModel)
|
||||
|
||||
|
||||
class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
|
||||
class PydanticOutputParser(JsonOutputParser, Generic[BaseModelT]):
|
||||
"""Parse an output using a pydantic model."""
|
||||
|
||||
pydantic_object: Annotated[type[TBaseModel], SkipValidation()]
|
||||
pydantic_object: SkipValidation[type[BaseModelT]]
|
||||
"""The pydantic model to parse."""
|
||||
|
||||
def _parse_obj(self, obj: dict) -> TBaseModel:
|
||||
if IS_PYDANTIC_V2:
|
||||
def _parse_obj(self, obj: dict) -> BaseModelT:
|
||||
try:
|
||||
if issubclass(self.pydantic_object, pydantic.BaseModel):
|
||||
if issubclass(self.pydantic_object, BaseModel):
|
||||
return self.pydantic_object.model_validate(obj)
|
||||
if issubclass(self.pydantic_object, pydantic.v1.BaseModel):
|
||||
return self.pydantic_object.parse_obj(obj)
|
||||
msg = f"Unsupported model version for PydanticOutputParser: \
|
||||
{self.pydantic_object.__class__}"
|
||||
raise OutputParserException(msg)
|
||||
except (pydantic.ValidationError, pydantic.v1.ValidationError) as e:
|
||||
raise self._parser_exception(e, obj) from e
|
||||
else: # pydantic v1
|
||||
try:
|
||||
return self.pydantic_object.parse_obj(obj)
|
||||
except pydantic.ValidationError as e:
|
||||
except ValidationError as e:
|
||||
raise self._parser_exception(e, obj) from e
|
||||
|
||||
def _parser_exception(
|
||||
@ -51,7 +39,7 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
|
||||
|
||||
def parse_result(
|
||||
self, result: list[Generation], *, partial: bool = False
|
||||
) -> Optional[TBaseModel]:
|
||||
) -> Optional[BaseModelT]:
|
||||
"""Parse the result of an LLM call to a pydantic object.
|
||||
|
||||
Args:
|
||||
@ -72,7 +60,7 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
|
||||
return None
|
||||
raise
|
||||
|
||||
def parse(self, text: str) -> TBaseModel:
|
||||
def parse(self, text: str) -> BaseModelT:
|
||||
"""Parse the output of an LLM call to a pydantic object.
|
||||
|
||||
Args:
|
||||
@ -109,7 +97,7 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
|
||||
|
||||
@property
|
||||
@override
|
||||
def OutputType(self) -> type[TBaseModel]:
|
||||
def OutputType(self) -> type[BaseModelT]:
|
||||
"""Return the pydantic model."""
|
||||
return self.pydantic_object
|
||||
|
||||
@ -126,7 +114,5 @@ Here is the output schema:
|
||||
|
||||
# Re-exporting types for backwards compatibility
|
||||
__all__ = [
|
||||
"PydanticBaseModel",
|
||||
"PydanticOutputParser",
|
||||
"TBaseModel",
|
||||
]
|
||||
|
@ -125,7 +125,7 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
|
||||
chunk_gen = ChatGenerationChunk(message=chunk)
|
||||
elif isinstance(chunk, BaseMessage):
|
||||
chunk_gen = ChatGenerationChunk(
|
||||
message=BaseMessageChunk(**chunk.dict())
|
||||
message=BaseMessageChunk(**chunk.model_dump())
|
||||
)
|
||||
else:
|
||||
chunk_gen = GenerationChunk(text=chunk)
|
||||
@ -151,7 +151,7 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
|
||||
chunk_gen = ChatGenerationChunk(message=chunk)
|
||||
elif isinstance(chunk, BaseMessage):
|
||||
chunk_gen = ChatGenerationChunk(
|
||||
message=BaseMessageChunk(**chunk.dict())
|
||||
message=BaseMessageChunk(**chunk.model_dump())
|
||||
)
|
||||
else:
|
||||
chunk_gen = GenerationChunk(text=chunk)
|
||||
|
@ -331,7 +331,7 @@ class BasePromptTemplate(
|
||||
"""Return the prompt type key."""
|
||||
raise NotImplementedError
|
||||
|
||||
def dict(self, **kwargs: Any) -> dict:
|
||||
def model_dump(self, **kwargs: Any) -> dict:
|
||||
"""Return dictionary representation of prompt.
|
||||
|
||||
Args:
|
||||
@ -369,7 +369,7 @@ class BasePromptTemplate(
|
||||
raise ValueError(msg)
|
||||
|
||||
# Fetch dictionary to save
|
||||
prompt_dict = self.dict()
|
||||
prompt_dict = self.model_dump()
|
||||
if "_type" not in prompt_dict:
|
||||
msg = f"Prompt {self} does not support saving."
|
||||
raise NotImplementedError(msg)
|
||||
|
@ -6,7 +6,6 @@ from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Annotated,
|
||||
Any,
|
||||
Optional,
|
||||
TypedDict,
|
||||
@ -886,7 +885,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
messages: Annotated[list[MessageLike], SkipValidation()]
|
||||
messages: SkipValidation[list[MessageLike]]
|
||||
"""List of messages consisting of either message prompt templates or messages."""
|
||||
validate_template: bool = False
|
||||
"""Whether or not to try validating the template."""
|
||||
|
@ -98,7 +98,7 @@ def _get_jinja2_variables_from_template(template: str) -> set[str]:
|
||||
raise ImportError(msg) from e
|
||||
env = Environment() # noqa: S701
|
||||
ast = env.parse(template)
|
||||
return meta.find_undeclared_variables(ast)
|
||||
return meta.find_undeclared_variables(ast) # type: ignore[no-untyped-call]
|
||||
|
||||
|
||||
def mustache_formatter(template: str, /, **kwargs: Any) -> str:
|
||||
|
@ -1,45 +0,0 @@
|
||||
"""Pydantic v1 compatibility shim."""
|
||||
|
||||
from importlib import metadata
|
||||
|
||||
from langchain_core._api.deprecation import warn_deprecated
|
||||
|
||||
# Create namespaces for pydantic v1 and v2.
|
||||
# This code must stay at the top of the file before other modules may
|
||||
# attempt to import pydantic since it adds pydantic_v1 and pydantic_v2 to sys.modules.
|
||||
#
|
||||
# This hack is done for the following reasons:
|
||||
# * Langchain will attempt to remain compatible with both pydantic v1 and v2 since
|
||||
# both dependencies and dependents may be stuck on either version of v1 or v2.
|
||||
# * Creating namespaces for pydantic v1 and v2 should allow us to write code that
|
||||
# unambiguously uses either v1 or v2 API.
|
||||
# * This change is easier to roll out and roll back.
|
||||
|
||||
try:
|
||||
from pydantic.v1 import * # noqa: F403
|
||||
except ImportError:
|
||||
from pydantic import * # type: ignore[assignment,no-redef] # noqa: F403
|
||||
|
||||
|
||||
try:
|
||||
_PYDANTIC_MAJOR_VERSION: int = int(metadata.version("pydantic").split(".")[0])
|
||||
except metadata.PackageNotFoundError:
|
||||
_PYDANTIC_MAJOR_VERSION = 0
|
||||
|
||||
warn_deprecated(
|
||||
"0.3.0",
|
||||
removal="1.0.0",
|
||||
alternative="pydantic.v1 or pydantic",
|
||||
message=(
|
||||
"As of langchain-core 0.3.0, LangChain uses pydantic v2 internally. "
|
||||
"The langchain_core.pydantic_v1 module was a "
|
||||
"compatibility shim for pydantic v1, and should no longer be used. "
|
||||
"Please update the code to import from Pydantic directly.\n\n"
|
||||
"For example, replace imports like: "
|
||||
"`from langchain_core.pydantic_v1 import BaseModel`\n"
|
||||
"with: `from pydantic import BaseModel`\n"
|
||||
"or the v1 compatibility namespace if you are working in a code base "
|
||||
"that has not been fully upgraded to pydantic 2 yet. "
|
||||
"\tfrom pydantic.v1 import BaseModel\n"
|
||||
),
|
||||
)
|
@ -1,26 +0,0 @@
|
||||
"""Pydantic v1 compatibility shim."""
|
||||
|
||||
from langchain_core._api import warn_deprecated
|
||||
|
||||
try:
|
||||
from pydantic.v1.dataclasses import * # noqa: F403
|
||||
except ImportError:
|
||||
from pydantic.dataclasses import * # type: ignore[no-redef] # noqa: F403
|
||||
|
||||
warn_deprecated(
|
||||
"0.3.0",
|
||||
removal="1.0.0",
|
||||
alternative="pydantic.v1 or pydantic",
|
||||
message=(
|
||||
"As of langchain-core 0.3.0, LangChain uses pydantic v2 internally. "
|
||||
"The langchain_core.pydantic_v1 module was a "
|
||||
"compatibility shim for pydantic v1, and should no longer be used. "
|
||||
"Please update the code to import from Pydantic directly.\n\n"
|
||||
"For example, replace imports like: "
|
||||
"`from langchain_core.pydantic_v1 import BaseModel`\n"
|
||||
"with: `from pydantic import BaseModel`\n"
|
||||
"or the v1 compatibility namespace if you are working in a code base "
|
||||
"that has not been fully upgraded to pydantic 2 yet. "
|
||||
"\tfrom pydantic.v1 import BaseModel\n"
|
||||
),
|
||||
)
|
@ -1,26 +0,0 @@
|
||||
"""Pydantic v1 compatibility shim."""
|
||||
|
||||
from langchain_core._api import warn_deprecated
|
||||
|
||||
try:
|
||||
from pydantic.v1.main import * # noqa: F403
|
||||
except ImportError:
|
||||
from pydantic.main import * # type: ignore[assignment,no-redef] # noqa: F403
|
||||
|
||||
warn_deprecated(
|
||||
"0.3.0",
|
||||
removal="1.0.0",
|
||||
alternative="pydantic.v1 or pydantic",
|
||||
message=(
|
||||
"As of langchain-core 0.3.0, LangChain uses pydantic v2 internally. "
|
||||
"The langchain_core.pydantic_v1 module was a "
|
||||
"compatibility shim for pydantic v1, and should no longer be used. "
|
||||
"Please update the code to import from Pydantic directly.\n\n"
|
||||
"For example, replace imports like: "
|
||||
"`from langchain_core.pydantic_v1 import BaseModel`\n"
|
||||
"with: `from pydantic import BaseModel`\n"
|
||||
"or the v1 compatibility namespace if you are working in a code base "
|
||||
"that has not been fully upgraded to pydantic 2 yet. "
|
||||
"\tfrom pydantic.v1 import BaseModel\n"
|
||||
),
|
||||
)
|
@ -2,7 +2,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
@ -19,13 +18,13 @@ from typing import (
|
||||
)
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from langchain_core.utils.pydantic import _IgnoreUnserializable, is_basemodel_subclass
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain_core.utils.pydantic import _IgnoreUnserializable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain_core.runnables.base import Runnable as RunnableType
|
||||
|
||||
|
||||
@ -233,7 +232,7 @@ def node_data_json(
|
||||
"name": node_data_str(node.id, node.data),
|
||||
},
|
||||
}
|
||||
elif inspect.isclass(node.data) and is_basemodel_subclass(node.data):
|
||||
elif isinstance(node.data, type) and issubclass(node.data, BaseModel):
|
||||
json = (
|
||||
{
|
||||
"type": "schema",
|
||||
|
@ -33,9 +33,6 @@ from pydantic import (
|
||||
model_validator,
|
||||
validate_arguments,
|
||||
)
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
from pydantic.v1 import ValidationError as ValidationErrorV1
|
||||
from pydantic.v1 import validate_arguments as validate_arguments_v1
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
@ -59,14 +56,7 @@ from langchain_core.utils.function_calling import (
|
||||
_parse_google_docstring,
|
||||
_py_38_safe_origin,
|
||||
)
|
||||
from langchain_core.utils.pydantic import (
|
||||
TypeBaseModel,
|
||||
_create_subset_model,
|
||||
get_fields,
|
||||
is_basemodel_subclass,
|
||||
is_pydantic_v1_subclass,
|
||||
is_pydantic_v2_subclass,
|
||||
)
|
||||
from langchain_core.utils.pydantic import _create_subset_model, get_fields
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import uuid
|
||||
@ -165,36 +155,6 @@ def _infer_arg_descriptions(
|
||||
return description, arg_descriptions
|
||||
|
||||
|
||||
def _is_pydantic_annotation(annotation: Any, pydantic_version: str = "v2") -> bool:
|
||||
"""Determine if a type annotation is a Pydantic model."""
|
||||
base_model_class = BaseModelV1 if pydantic_version == "v1" else BaseModel
|
||||
try:
|
||||
return issubclass(annotation, base_model_class)
|
||||
except TypeError:
|
||||
return False
|
||||
|
||||
|
||||
def _function_annotations_are_pydantic_v1(
|
||||
signature: inspect.Signature, func: Callable
|
||||
) -> bool:
|
||||
"""Determine if all Pydantic annotations in a function signature are from V1."""
|
||||
any_v1_annotations = any(
|
||||
_is_pydantic_annotation(parameter.annotation, pydantic_version="v1")
|
||||
for parameter in signature.parameters.values()
|
||||
)
|
||||
any_v2_annotations = any(
|
||||
_is_pydantic_annotation(parameter.annotation, pydantic_version="v2")
|
||||
for parameter in signature.parameters.values()
|
||||
)
|
||||
if any_v1_annotations and any_v2_annotations:
|
||||
msg = (
|
||||
f"Function {func} contains a mix of Pydantic v1 and v2 annotations. "
|
||||
"Only one version of Pydantic annotations per function is supported."
|
||||
)
|
||||
raise NotImplementedError(msg)
|
||||
return any_v1_annotations and not any_v2_annotations
|
||||
|
||||
|
||||
class _SchemaConfig:
|
||||
"""Configuration for the pydantic model.
|
||||
|
||||
@ -241,9 +201,6 @@ def create_schema_from_function(
|
||||
"""
|
||||
sig = inspect.signature(func)
|
||||
|
||||
if _function_annotations_are_pydantic_v1(sig, func):
|
||||
validated = validate_arguments_v1(func, config=_SchemaConfig) # type: ignore[call-overload]
|
||||
else:
|
||||
# https://docs.pydantic.dev/latest/usage/validation_decorator/
|
||||
with warnings.catch_warnings():
|
||||
# We are using deprecated functionality here.
|
||||
@ -321,7 +278,7 @@ class ToolException(Exception): # noqa: N818
|
||||
"""
|
||||
|
||||
|
||||
ArgsSchema = Union[TypeBaseModel, dict[str, Any]]
|
||||
ArgsSchema = Union[type[BaseModel], dict[str, Any]]
|
||||
|
||||
|
||||
class BaseTool(RunnableSerializable[Union[str, dict, ToolCall], Any]):
|
||||
@ -361,7 +318,7 @@ class ChildTool(BaseTool):
|
||||
You can provide few-shot examples as a part of the description.
|
||||
"""
|
||||
|
||||
args_schema: Annotated[Optional[ArgsSchema], SkipValidation()] = Field(
|
||||
args_schema: SkipValidation[Optional[ArgsSchema]] = Field(
|
||||
default=None, description="The tool schema."
|
||||
)
|
||||
"""Pydantic model class to validate and parse the tool's input arguments.
|
||||
@ -370,8 +327,6 @@ class ChildTool(BaseTool):
|
||||
|
||||
- A subclass of pydantic.BaseModel.
|
||||
or
|
||||
- A subclass of pydantic.v1.BaseModel if accessing v1 namespace in pydantic 2
|
||||
or
|
||||
- a JSON schema dict
|
||||
"""
|
||||
return_direct: bool = False
|
||||
@ -414,7 +369,7 @@ class ChildTool(BaseTool):
|
||||
"""Handle the content of the ToolException thrown."""
|
||||
|
||||
handle_validation_error: Optional[
|
||||
Union[bool, str, Callable[[Union[ValidationError, ValidationErrorV1]], str]]
|
||||
Union[bool, str, Callable[[ValidationError], str]]
|
||||
] = False
|
||||
"""Handle the content of the ValidationError thrown."""
|
||||
|
||||
@ -431,7 +386,7 @@ class ChildTool(BaseTool):
|
||||
if (
|
||||
"args_schema" in kwargs
|
||||
and kwargs["args_schema"] is not None
|
||||
and not is_basemodel_subclass(kwargs["args_schema"])
|
||||
and not issubclass(kwargs["args_schema"], BaseModel)
|
||||
and not isinstance(kwargs["args_schema"], dict)
|
||||
):
|
||||
msg = (
|
||||
@ -543,10 +498,7 @@ class ChildTool(BaseTool):
|
||||
)
|
||||
raise ValueError(msg)
|
||||
key_ = next(iter(get_fields(input_args).keys()))
|
||||
if hasattr(input_args, "model_validate"):
|
||||
input_args.model_validate({key_: tool_input})
|
||||
else:
|
||||
input_args.parse_obj({key_: tool_input})
|
||||
return tool_input
|
||||
if input_args is not None:
|
||||
if isinstance(input_args, dict):
|
||||
@ -569,24 +521,6 @@ class ChildTool(BaseTool):
|
||||
tool_input[k] = tool_call_id
|
||||
result = input_args.model_validate(tool_input)
|
||||
result_dict = result.model_dump()
|
||||
elif issubclass(input_args, BaseModelV1):
|
||||
for k, v in get_all_basemodel_annotations(input_args).items():
|
||||
if (
|
||||
_is_injected_arg_type(v, injected_type=InjectedToolCallId)
|
||||
and k not in tool_input
|
||||
):
|
||||
if tool_call_id is None:
|
||||
msg = (
|
||||
"When tool includes an InjectedToolCallId "
|
||||
"argument, tool must always be invoked with a full "
|
||||
"model ToolCall of the form: {'args': {...}, "
|
||||
"'name': '...', 'type': 'tool_call', "
|
||||
"'tool_call_id': '...'}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
tool_input[k] = tool_call_id
|
||||
result = input_args.parse_obj(tool_input)
|
||||
result_dict = result.dict()
|
||||
else:
|
||||
msg = (
|
||||
f"args_schema must be a Pydantic BaseModel, got {self.args_schema}"
|
||||
@ -643,7 +577,7 @@ class ChildTool(BaseTool):
|
||||
if (
|
||||
self.args_schema is not None
|
||||
and isinstance(self.args_schema, type)
|
||||
and is_basemodel_subclass(self.args_schema)
|
||||
and issubclass(self.args_schema, BaseModel)
|
||||
and not get_fields(self.args_schema)
|
||||
):
|
||||
# StructuredTool with no args
|
||||
@ -754,7 +688,7 @@ class ChildTool(BaseTool):
|
||||
content, artifact = response
|
||||
else:
|
||||
content = response
|
||||
except (ValidationError, ValidationErrorV1) as e:
|
||||
except ValidationError as e:
|
||||
if not self.handle_validation_error:
|
||||
error_to_raise = e
|
||||
else:
|
||||
@ -901,11 +835,9 @@ def _is_tool_call(x: Any) -> bool:
|
||||
|
||||
|
||||
def _handle_validation_error(
|
||||
e: Union[ValidationError, ValidationErrorV1],
|
||||
e: ValidationError,
|
||||
*,
|
||||
flag: Union[
|
||||
Literal[True], str, Callable[[Union[ValidationError, ValidationErrorV1]], str]
|
||||
],
|
||||
flag: Union[Literal[True], str, Callable[[ValidationError], str]],
|
||||
) -> str:
|
||||
if isinstance(flag, bool):
|
||||
content = "Tool input validation error"
|
||||
@ -1067,7 +999,7 @@ def _is_injected_arg_type(
|
||||
|
||||
|
||||
def get_all_basemodel_annotations(
|
||||
cls: Union[TypeBaseModel, Any], *, default_to_bound: bool = True
|
||||
cls: type[BaseModel], *, default_to_bound: bool = True
|
||||
) -> dict[str, type]:
|
||||
"""Get all annotations from a Pydantic BaseModel and its parents.
|
||||
|
||||
@ -1075,58 +1007,17 @@ def get_all_basemodel_annotations(
|
||||
cls: The Pydantic BaseModel class.
|
||||
default_to_bound: Whether to default to the bound of a TypeVar if it exists.
|
||||
"""
|
||||
# cls has no subscript: cls = FooBar
|
||||
if isinstance(cls, type):
|
||||
# Gather pydantic field objects (v2: model_fields / v1: __fields__)
|
||||
fields = getattr(cls, "model_fields", {}) or getattr(cls, "__fields__", {})
|
||||
fields = cls.model_fields
|
||||
alias_map = {field.alias: name for name, field in fields.items() if field.alias}
|
||||
|
||||
annotations: dict[str, type] = {}
|
||||
for name, param in inspect.signature(cls).parameters.items():
|
||||
# Exclude hidden init args added by pydantic Config. For example if
|
||||
# Exclude hidden init args added by pydantic's ConfigDict. For example if
|
||||
# BaseModel(extra="allow") then "extra_data" will part of init sig.
|
||||
if fields and name not in fields and name not in alias_map:
|
||||
continue
|
||||
field_name = alias_map.get(name, name)
|
||||
annotations[field_name] = param.annotation
|
||||
orig_bases: tuple = getattr(cls, "__orig_bases__", ())
|
||||
# cls has subscript: cls = FooBar[int]
|
||||
else:
|
||||
annotations = get_all_basemodel_annotations(
|
||||
get_origin(cls), default_to_bound=False
|
||||
)
|
||||
orig_bases = (cls,)
|
||||
|
||||
# Pydantic v2 automatically resolves inherited generics, Pydantic v1 does not.
|
||||
if not (isinstance(cls, type) and is_pydantic_v2_subclass(cls)):
|
||||
# if cls = FooBar inherits from Baz[str], orig_bases will contain Baz[str]
|
||||
# if cls = FooBar inherits from Baz, orig_bases will contain Baz
|
||||
# if cls = FooBar[int], orig_bases will contain FooBar[int]
|
||||
for parent in orig_bases:
|
||||
# if class = FooBar inherits from Baz, parent = Baz
|
||||
if isinstance(parent, type) and is_pydantic_v1_subclass(parent):
|
||||
annotations.update(
|
||||
get_all_basemodel_annotations(parent, default_to_bound=False)
|
||||
)
|
||||
continue
|
||||
|
||||
parent_origin = get_origin(parent)
|
||||
|
||||
# if class = FooBar inherits from non-pydantic class
|
||||
if not parent_origin:
|
||||
continue
|
||||
|
||||
# if class = FooBar inherits from Baz[str]:
|
||||
# parent = Baz[str],
|
||||
# parent_origin = Baz,
|
||||
# generic_type_vars = (type vars in Baz)
|
||||
# generic_map = {type var in Baz: str}
|
||||
generic_type_vars: tuple = getattr(parent_origin, "__parameters__", ())
|
||||
generic_map = dict(zip(generic_type_vars, get_args(parent)))
|
||||
for field in getattr(parent_origin, "__annotations__", {}):
|
||||
annotations[field] = _replace_type_vars(
|
||||
annotations[field], generic_map, default_to_bound=default_to_bound
|
||||
)
|
||||
|
||||
return {
|
||||
k: _replace_type_vars(v, default_to_bound=default_to_bound)
|
||||
|
@ -7,7 +7,6 @@ from collections.abc import Awaitable
|
||||
from inspect import signature
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Annotated,
|
||||
Any,
|
||||
Callable,
|
||||
Literal,
|
||||
@ -15,7 +14,7 @@ from typing import (
|
||||
Union,
|
||||
)
|
||||
|
||||
from pydantic import Field, SkipValidation
|
||||
from pydantic import BaseModel, Field, SkipValidation
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
@ -30,7 +29,6 @@ from langchain_core.tools.base import (
|
||||
_get_runnable_config_param,
|
||||
create_schema_from_function,
|
||||
)
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.messages import ToolCall
|
||||
@ -40,9 +38,7 @@ class StructuredTool(BaseTool):
|
||||
"""Tool that can operate on any number of inputs."""
|
||||
|
||||
description: str = ""
|
||||
args_schema: Annotated[ArgsSchema, SkipValidation()] = Field(
|
||||
..., description="The tool schema."
|
||||
)
|
||||
args_schema: SkipValidation[ArgsSchema] = Field(..., description="The tool schema.")
|
||||
"""The input arguments' schema."""
|
||||
func: Optional[Callable[..., Any]] = None
|
||||
"""The function to run when the tool is called."""
|
||||
@ -196,7 +192,7 @@ class StructuredTool(BaseTool):
|
||||
if description is None and not parse_docstring:
|
||||
description_ = source_function.__doc__ or None
|
||||
if description_ is None and args_schema:
|
||||
if isinstance(args_schema, type) and is_basemodel_subclass(args_schema):
|
||||
if isinstance(args_schema, type) and issubclass(args_schema, BaseModel):
|
||||
description_ = args_schema.__doc__ or None
|
||||
elif isinstance(args_schema, dict):
|
||||
description_ = args_schema.get("description")
|
||||
|
@ -2,8 +2,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import warnings
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Optional
|
||||
from uuid import UUID
|
||||
|
||||
@ -32,7 +32,7 @@ def RunTypeEnum() -> type[RunTypeEnumDep]: # noqa: N802
|
||||
class TracerSessionV1Base(BaseModelV1):
|
||||
"""Base class for TracerSessionV1."""
|
||||
|
||||
start_time: datetime.datetime = FieldV1(default_factory=datetime.datetime.utcnow)
|
||||
start_time: datetime = FieldV1(default_factory=lambda: datetime.now(timezone.utc))
|
||||
name: Optional[str] = None
|
||||
extra: Optional[dict[str, Any]] = None
|
||||
|
||||
@ -69,8 +69,8 @@ class BaseRun(BaseModelV1):
|
||||
|
||||
uuid: str
|
||||
parent_uuid: Optional[str] = None
|
||||
start_time: datetime.datetime = FieldV1(default_factory=datetime.datetime.utcnow)
|
||||
end_time: datetime.datetime = FieldV1(default_factory=datetime.datetime.utcnow)
|
||||
start_time: datetime = FieldV1(default_factory=lambda: datetime.now(timezone.utc))
|
||||
end_time: datetime = FieldV1(default_factory=lambda: datetime.now(timezone.utc))
|
||||
extra: Optional[dict[str, Any]] = None
|
||||
execution_order: int
|
||||
child_execution_order: int
|
||||
|
@ -20,13 +20,11 @@ from typing import (
|
||||
)
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
from typing_extensions import TypedDict, get_args, get_origin, is_typeddict
|
||||
|
||||
from langchain_core._api import beta, deprecated
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
|
||||
from langchain_core.utils.json_schema import dereference_refs
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.tools import BaseTool
|
||||
@ -150,9 +148,7 @@ def _convert_pydantic_to_openai_function(
|
||||
The function description.
|
||||
"""
|
||||
if hasattr(model, "model_json_schema"):
|
||||
schema = model.model_json_schema() # Pydantic 2
|
||||
elif hasattr(model, "schema"):
|
||||
schema = model.schema() # Pydantic 1
|
||||
schema = model.model_json_schema()
|
||||
else:
|
||||
msg = "Model must be a Pydantic model."
|
||||
raise TypeError(msg)
|
||||
@ -249,6 +245,7 @@ def _convert_typed_dict_to_openai_function(typed_dict: type) -> FunctionDescript
|
||||
"type[BaseModel]",
|
||||
_convert_any_typed_dicts_to_pydantic(typed_dict, visited=visited),
|
||||
)
|
||||
print(model)
|
||||
return _convert_pydantic_to_openai_function(model)
|
||||
|
||||
|
||||
@ -336,7 +333,7 @@ def _format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
|
||||
return _convert_json_schema_to_openai_function(
|
||||
tool.tool_call_schema, name=tool.name, description=tool.description
|
||||
)
|
||||
if issubclass(tool.tool_call_schema, (BaseModel, BaseModelV1)):
|
||||
if issubclass(tool.tool_call_schema, BaseModel):
|
||||
return _convert_pydantic_to_openai_function(
|
||||
tool.tool_call_schema, name=tool.name, description=tool.description
|
||||
)
|
||||
@ -466,7 +463,7 @@ def convert_to_openai_function(
|
||||
oai_function["description"] = function_copy.pop("description")
|
||||
if function_copy and "properties" in function_copy:
|
||||
oai_function["parameters"] = function_copy
|
||||
elif isinstance(function, type) and is_basemodel_subclass(function):
|
||||
elif isinstance(function, type) and issubclass(function, BaseModel):
|
||||
oai_function = cast("dict", _convert_pydantic_to_openai_function(function))
|
||||
elif is_typeddict(function):
|
||||
oai_function = cast(
|
||||
|
@ -2,173 +2,39 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import textwrap
|
||||
import warnings
|
||||
from contextlib import nullcontext
|
||||
from functools import lru_cache, wraps
|
||||
from types import GenericAlias
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Optional,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
|
||||
import pydantic
|
||||
from packaging import version
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
PydanticDeprecationWarning,
|
||||
RootModel,
|
||||
root_validator,
|
||||
)
|
||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, model_validator
|
||||
from pydantic import (
|
||||
create_model as _create_model_base,
|
||||
)
|
||||
from pydantic.fields import FieldInfo as FieldInfoV2
|
||||
from pydantic.fields import FieldInfo
|
||||
from pydantic.json_schema import (
|
||||
DEFAULT_REF_TEMPLATE,
|
||||
GenerateJsonSchema,
|
||||
JsonSchemaMode,
|
||||
JsonSchemaValue,
|
||||
)
|
||||
from pydantic.version import VERSION as PYDANTIC_VERSION_STRING
|
||||
from typing_extensions import override
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic_core import core_schema
|
||||
|
||||
try:
|
||||
import pydantic
|
||||
|
||||
PYDANTIC_VERSION = version.parse(pydantic.__version__)
|
||||
except ImportError:
|
||||
PYDANTIC_VERSION = version.parse("0.0.0")
|
||||
|
||||
|
||||
def get_pydantic_major_version() -> int:
|
||||
"""DEPRECATED - Get the major version of Pydantic.
|
||||
|
||||
Use PYDANTIC_VERSION.major instead.
|
||||
"""
|
||||
warnings.warn(
|
||||
"get_pydantic_major_version is deprecated. Use PYDANTIC_VERSION.major instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return PYDANTIC_VERSION.major
|
||||
|
||||
|
||||
PYDANTIC_MAJOR_VERSION = PYDANTIC_VERSION.major
|
||||
PYDANTIC_VERSION = version.parse(PYDANTIC_VERSION_STRING)
|
||||
PYDANTIC_MINOR_VERSION = PYDANTIC_VERSION.minor
|
||||
|
||||
IS_PYDANTIC_V1 = PYDANTIC_VERSION.major == 1
|
||||
IS_PYDANTIC_V2 = PYDANTIC_VERSION.major == 2
|
||||
|
||||
if IS_PYDANTIC_V1:
|
||||
from pydantic.fields import FieldInfo as FieldInfoV1
|
||||
|
||||
PydanticBaseModel = pydantic.BaseModel
|
||||
TypeBaseModel = type[BaseModel]
|
||||
elif IS_PYDANTIC_V2:
|
||||
from pydantic.v1.fields import FieldInfo as FieldInfoV1 # type: ignore[assignment]
|
||||
|
||||
# Union type needs to be last assignment to PydanticBaseModel to make mypy happy.
|
||||
PydanticBaseModel = Union[BaseModel, pydantic.BaseModel] # type: ignore[assignment,misc]
|
||||
TypeBaseModel = Union[type[BaseModel], type[pydantic.BaseModel]] # type: ignore[misc]
|
||||
else:
|
||||
msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}"
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
TBaseModel = TypeVar("TBaseModel", bound=PydanticBaseModel)
|
||||
|
||||
|
||||
def is_pydantic_v1_subclass(cls: type) -> bool:
|
||||
"""Check if the installed Pydantic version is 1.x-like."""
|
||||
if IS_PYDANTIC_V1:
|
||||
return True
|
||||
if IS_PYDANTIC_V2:
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
|
||||
if issubclass(cls, BaseModelV1):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def is_pydantic_v2_subclass(cls: type) -> bool:
|
||||
"""Check if the installed Pydantic version is 1.x-like."""
|
||||
from pydantic import BaseModel
|
||||
|
||||
return IS_PYDANTIC_V2 and issubclass(cls, BaseModel)
|
||||
|
||||
|
||||
def is_basemodel_subclass(cls: type) -> bool:
|
||||
"""Check if the given class is a subclass of Pydantic BaseModel.
|
||||
|
||||
Check if the given class is a subclass of any of the following:
|
||||
|
||||
* pydantic.BaseModel in Pydantic 1.x
|
||||
* pydantic.BaseModel in Pydantic 2.x
|
||||
* pydantic.v1.BaseModel in Pydantic 2.x
|
||||
"""
|
||||
# Before we can use issubclass on the cls we need to check if it is a class
|
||||
if not inspect.isclass(cls) or isinstance(cls, GenericAlias):
|
||||
return False
|
||||
|
||||
if IS_PYDANTIC_V1:
|
||||
from pydantic import BaseModel as BaseModelV1Proper
|
||||
|
||||
if issubclass(cls, BaseModelV1Proper):
|
||||
return True
|
||||
elif IS_PYDANTIC_V2:
|
||||
from pydantic import BaseModel as BaseModelV2
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
|
||||
if issubclass(cls, BaseModelV2):
|
||||
return True
|
||||
|
||||
if issubclass(cls, BaseModelV1):
|
||||
return True
|
||||
else:
|
||||
msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}"
|
||||
raise ValueError(msg)
|
||||
return False
|
||||
|
||||
|
||||
def is_basemodel_instance(obj: Any) -> bool:
|
||||
"""Check if the given class is an instance of Pydantic BaseModel.
|
||||
|
||||
Check if the given class is an instance of any of the following:
|
||||
|
||||
* pydantic.BaseModel in Pydantic 1.x
|
||||
* pydantic.BaseModel in Pydantic 2.x
|
||||
* pydantic.v1.BaseModel in Pydantic 2.x
|
||||
"""
|
||||
if IS_PYDANTIC_V1:
|
||||
from pydantic import BaseModel as BaseModelV1Proper
|
||||
|
||||
if isinstance(obj, BaseModelV1Proper):
|
||||
return True
|
||||
elif IS_PYDANTIC_V2:
|
||||
from pydantic import BaseModel as BaseModelV2
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
|
||||
if isinstance(obj, BaseModelV2):
|
||||
return True
|
||||
|
||||
if isinstance(obj, BaseModelV1):
|
||||
return True
|
||||
else:
|
||||
msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}"
|
||||
raise ValueError(msg)
|
||||
return False
|
||||
|
||||
|
||||
# How to type hint this?
|
||||
def pre_init(func: Callable) -> Any:
|
||||
@ -180,10 +46,8 @@ def pre_init(func: Callable) -> Any:
|
||||
Returns:
|
||||
Any: The decorated function.
|
||||
"""
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings(action="ignore", category=PydanticDeprecationWarning)
|
||||
|
||||
@root_validator(pre=True)
|
||||
@model_validator(mode="before")
|
||||
@wraps(func)
|
||||
def wrapper(cls: type[BaseModel], values: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Decorator to run a function before model initialization.
|
||||
@ -198,20 +62,7 @@ def pre_init(func: Callable) -> Any:
|
||||
# Insert default values
|
||||
fields = cls.model_fields
|
||||
for name, field_info in fields.items():
|
||||
# Check if allow_population_by_field_name is enabled
|
||||
# If yes, then set the field name to the alias
|
||||
if (
|
||||
hasattr(cls, "Config")
|
||||
and hasattr(cls.Config, "allow_population_by_field_name")
|
||||
and cls.Config.allow_population_by_field_name
|
||||
and field_info.alias in values
|
||||
):
|
||||
values[name] = values.pop(field_info.alias)
|
||||
if (
|
||||
hasattr(cls, "model_config")
|
||||
and cls.model_config.get("populate_by_name")
|
||||
and field_info.alias in values
|
||||
):
|
||||
if cls.model_config.get("populate_by_name") and field_info.alias in values:
|
||||
values[name] = values.pop(field_info.alias)
|
||||
|
||||
if (
|
||||
@ -241,55 +92,15 @@ class _IgnoreUnserializable(GenerateJsonSchema):
|
||||
return {}
|
||||
|
||||
|
||||
def _create_subset_model_v1(
|
||||
name: str,
|
||||
model: type[BaseModel],
|
||||
field_names: list,
|
||||
*,
|
||||
descriptions: Optional[dict] = None,
|
||||
fn_description: Optional[str] = None,
|
||||
) -> type[BaseModel]:
|
||||
"""Create a pydantic model with only a subset of model's fields."""
|
||||
if IS_PYDANTIC_V1:
|
||||
from pydantic import create_model
|
||||
elif IS_PYDANTIC_V2:
|
||||
from pydantic.v1 import create_model # type: ignore[no-redef]
|
||||
else:
|
||||
msg = f"Unsupported pydantic version: {PYDANTIC_VERSION.major}"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
fields = {}
|
||||
|
||||
for field_name in field_names:
|
||||
# Using pydantic v1 so can access __fields__ as a dict.
|
||||
field = model.__fields__[field_name] # type: ignore[index]
|
||||
t = (
|
||||
# this isn't perfect but should work for most functions
|
||||
field.outer_type_
|
||||
if field.required and not field.allow_none
|
||||
else Optional[field.outer_type_]
|
||||
)
|
||||
if descriptions and field_name in descriptions:
|
||||
field.field_info.description = descriptions[field_name]
|
||||
fields[field_name] = (t, field.field_info)
|
||||
|
||||
rtn = create_model(name, **fields) # type: ignore[call-overload]
|
||||
rtn.__doc__ = textwrap.dedent(fn_description or model.__doc__ or "")
|
||||
return rtn
|
||||
|
||||
|
||||
def _create_subset_model_v2(
|
||||
name: str,
|
||||
model: type[pydantic.BaseModel],
|
||||
model: type[BaseModel],
|
||||
field_names: list[str],
|
||||
*,
|
||||
descriptions: Optional[dict] = None,
|
||||
fn_description: Optional[str] = None,
|
||||
) -> type[pydantic.BaseModel]:
|
||||
) -> type[BaseModel]:
|
||||
"""Create a pydantic model with a subset of the model fields."""
|
||||
from pydantic import create_model
|
||||
from pydantic.fields import FieldInfo
|
||||
|
||||
descriptions_ = descriptions or {}
|
||||
fields = {}
|
||||
for field_name in field_names:
|
||||
@ -300,7 +111,7 @@ def _create_subset_model_v2(
|
||||
field_info.metadata = field.metadata
|
||||
fields[field_name] = (field.annotation, field_info)
|
||||
|
||||
rtn = create_model( # type: ignore[call-overload]
|
||||
rtn = _create_model_base(
|
||||
name, **fields, __config__=ConfigDict(arbitrary_types_allowed=True)
|
||||
)
|
||||
|
||||
@ -319,38 +130,15 @@ def _create_subset_model_v2(
|
||||
return rtn
|
||||
|
||||
|
||||
# Private functionality to create a subset model that's compatible across
|
||||
# different versions of pydantic.
|
||||
# Handles pydantic versions 1.x and 2.x. including v1 of pydantic in 2.x.
|
||||
# However, can't find a way to type hint this.
|
||||
def _create_subset_model(
|
||||
name: str,
|
||||
model: TypeBaseModel,
|
||||
model: type[BaseModel],
|
||||
field_names: list[str],
|
||||
*,
|
||||
descriptions: Optional[dict] = None,
|
||||
fn_description: Optional[str] = None,
|
||||
) -> type[BaseModel]:
|
||||
"""Create subset model using the same pydantic version as the input model."""
|
||||
if IS_PYDANTIC_V1:
|
||||
return _create_subset_model_v1(
|
||||
name,
|
||||
model,
|
||||
field_names,
|
||||
descriptions=descriptions,
|
||||
fn_description=fn_description,
|
||||
)
|
||||
if IS_PYDANTIC_V2:
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
|
||||
if issubclass(model, BaseModelV1):
|
||||
return _create_subset_model_v1(
|
||||
name,
|
||||
model,
|
||||
field_names,
|
||||
descriptions=descriptions,
|
||||
fn_description=fn_description,
|
||||
)
|
||||
return _create_subset_model_v2(
|
||||
name,
|
||||
model,
|
||||
@ -358,50 +146,18 @@ def _create_subset_model(
|
||||
descriptions=descriptions,
|
||||
fn_description=fn_description,
|
||||
)
|
||||
msg = f"Unsupported pydantic version: {PYDANTIC_VERSION.major}"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
if IS_PYDANTIC_V2:
|
||||
from pydantic import BaseModel as BaseModelV2
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
|
||||
@overload
|
||||
def get_fields(model: type[BaseModelV2]) -> dict[str, FieldInfoV2]: ...
|
||||
|
||||
@overload
|
||||
def get_fields(model: BaseModelV2) -> dict[str, FieldInfoV2]: ...
|
||||
|
||||
@overload
|
||||
def get_fields(model: type[BaseModelV1]) -> dict[str, FieldInfoV1]: ...
|
||||
|
||||
@overload
|
||||
def get_fields(model: BaseModelV1) -> dict[str, FieldInfoV1]: ...
|
||||
|
||||
def get_fields(
|
||||
model: Union[type[Union[BaseModelV2, BaseModelV1]], BaseModelV2, BaseModelV1],
|
||||
) -> Union[dict[str, FieldInfoV2], dict[str, FieldInfoV1]]:
|
||||
def get_fields(
|
||||
model: type[BaseModel],
|
||||
) -> dict[str, FieldInfo]:
|
||||
"""Get the field names of a Pydantic model."""
|
||||
if hasattr(model, "model_fields"):
|
||||
try:
|
||||
return model.model_fields
|
||||
|
||||
if hasattr(model, "__fields__"):
|
||||
return model.__fields__ # type: ignore[return-value]
|
||||
except AttributeError as exc:
|
||||
msg = f"Expected a Pydantic model. Got {type(model)}"
|
||||
raise TypeError(msg)
|
||||
raise TypeError(msg) from exc
|
||||
|
||||
elif IS_PYDANTIC_V1:
|
||||
from pydantic import BaseModel as BaseModelV1_
|
||||
|
||||
def get_fields( # type: ignore[no-redef]
|
||||
model: Union[type[BaseModelV1_], BaseModelV1_],
|
||||
) -> dict[str, FieldInfoV1]:
|
||||
"""Get the field names of a Pydantic model."""
|
||||
return model.__fields__ # type: ignore[return-value]
|
||||
|
||||
else:
|
||||
msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}"
|
||||
raise ValueError(msg)
|
||||
|
||||
_SchemaConfig = ConfigDict(
|
||||
arbitrary_types_allowed=True, frozen=True, protected_namespaces=()
|
||||
@ -458,17 +214,6 @@ def _create_root_model(
|
||||
if default_ is not NO_DEFAULT:
|
||||
base_class_attributes["root"] = default_
|
||||
with warnings.catch_warnings():
|
||||
try:
|
||||
if (
|
||||
isinstance(type_, type)
|
||||
and not isinstance(type_, GenericAlias)
|
||||
and issubclass(type_, BaseModelV1)
|
||||
):
|
||||
warnings.filterwarnings(
|
||||
action="ignore", category=PydanticDeprecationWarning
|
||||
)
|
||||
except TypeError:
|
||||
pass
|
||||
custom_root_type = type(name, (RootModel,), base_class_attributes)
|
||||
return cast("type[BaseModel]", custom_root_type)
|
||||
|
||||
@ -545,9 +290,6 @@ _RESERVED_NAMES = {key for key in dir(BaseModel) if not key.startswith("_")}
|
||||
|
||||
def _remap_field_definitions(field_definitions: dict[str, Any]) -> dict[str, Any]:
|
||||
"""This remaps fields to avoid colliding with internal pydantic fields."""
|
||||
from pydantic import Field
|
||||
from pydantic.fields import FieldInfo
|
||||
|
||||
remapped = {}
|
||||
for key, value in field_definitions.items():
|
||||
if key.startswith("_") or key in _RESERVED_NAMES:
|
||||
@ -595,7 +337,7 @@ def create_model_v2(
|
||||
root: Type for a root model (RootModel)
|
||||
|
||||
Returns:
|
||||
Type[BaseModel]: The created model.
|
||||
type[BaseModel]: The created model.
|
||||
"""
|
||||
field_definitions = field_definitions or {}
|
||||
|
||||
|
@ -15,10 +15,6 @@ from pydantic import SecretStr
|
||||
from requests import HTTPError, Response
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_core.utils.pydantic import (
|
||||
is_pydantic_v1_subclass,
|
||||
)
|
||||
|
||||
|
||||
def xor_args(*arg_groups: tuple[str, ...]) -> Callable:
|
||||
"""Validate specified keyword args are mutually exclusive.".
|
||||
@ -206,12 +202,6 @@ def get_pydantic_field_names(pydantic_cls: Any) -> set[str]:
|
||||
set[str]: Field names.
|
||||
"""
|
||||
all_required_field_names = set()
|
||||
if is_pydantic_v1_subclass(pydantic_cls):
|
||||
for field in pydantic_cls.__fields__.values():
|
||||
all_required_field_names.add(field.name)
|
||||
if field.has_alias:
|
||||
all_required_field_names.add(field.alias)
|
||||
else: # Assuming pydantic 2 for now
|
||||
for name, field in pydantic_cls.model_fields.items():
|
||||
all_required_field_names.add(name)
|
||||
if field.alias:
|
||||
|
@ -125,7 +125,7 @@ filterwarnings = [ "ignore::langchain_core._api.beta_decorator.LangChainBetaWarn
|
||||
asyncio_default_fixture_loop_scope = "function"
|
||||
|
||||
[tool.ruff.lint.pep8-naming]
|
||||
classmethod-decorators = [ "classmethod", "langchain_core.utils.pydantic.pre_init", "pydantic.field_validator", "pydantic.v1.root_validator",]
|
||||
classmethod-decorators = [ "classmethod", "langchain_core.utils.pydantic.pre_init", "pydantic.field_validator"]
|
||||
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
"langchain_core/utils/mustache.py" = [ "PLW0603",]
|
||||
|
@ -16,10 +16,6 @@ from langchain_core.output_parsers.openai_tools import (
|
||||
PydanticToolsParser,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration
|
||||
from langchain_core.utils.pydantic import (
|
||||
IS_PYDANTIC_V1,
|
||||
IS_PYDANTIC_V2,
|
||||
)
|
||||
|
||||
STREAMED_MESSAGES: list = [
|
||||
AIMessageChunk(content=""),
|
||||
@ -532,87 +528,13 @@ async def test_partial_pydantic_output_parser_async() -> None:
|
||||
assert actual == EXPECTED_STREAMED_PYDANTIC
|
||||
|
||||
|
||||
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="This test is for pydantic 2")
|
||||
def test_parse_with_different_pydantic_2_v1() -> None:
|
||||
"""Test with pydantic.v1.BaseModel from pydantic 2."""
|
||||
import pydantic
|
||||
def test_parse_with_different_pydantic() -> None:
|
||||
"""Test with BaseModel"""
|
||||
|
||||
class Forecast(pydantic.v1.BaseModel):
|
||||
class Forecast(BaseModel):
|
||||
temperature: int
|
||||
forecast: str
|
||||
|
||||
# Can't get pydantic to work here due to the odd typing of tryig to support
|
||||
# both v1 and v2 in the same codebase.
|
||||
parser = PydanticToolsParser(tools=[Forecast]) # type: ignore[list-item]
|
||||
message = AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "call_OwL7f5PE",
|
||||
"name": "Forecast",
|
||||
"args": {"temperature": 20, "forecast": "Sunny"},
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
generation = ChatGeneration(
|
||||
message=message,
|
||||
)
|
||||
|
||||
assert parser.parse_result([generation]) == [
|
||||
Forecast(
|
||||
temperature=20,
|
||||
forecast="Sunny",
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="This test is for pydantic 2")
|
||||
def test_parse_with_different_pydantic_2_proper() -> None:
|
||||
"""Test with pydantic.BaseModel from pydantic 2."""
|
||||
import pydantic
|
||||
|
||||
class Forecast(pydantic.BaseModel):
|
||||
temperature: int
|
||||
forecast: str
|
||||
|
||||
# Can't get pydantic to work here due to the odd typing of tryig to support
|
||||
# both v1 and v2 in the same codebase.
|
||||
parser = PydanticToolsParser(tools=[Forecast])
|
||||
message = AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "call_OwL7f5PE",
|
||||
"name": "Forecast",
|
||||
"args": {"temperature": 20, "forecast": "Sunny"},
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
generation = ChatGeneration(
|
||||
message=message,
|
||||
)
|
||||
|
||||
assert parser.parse_result([generation]) == [
|
||||
Forecast(
|
||||
temperature=20,
|
||||
forecast="Sunny",
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.skipif(not IS_PYDANTIC_V1, reason="This test is for pydantic 1")
|
||||
def test_parse_with_different_pydantic_1_proper() -> None:
|
||||
"""Test with pydantic.BaseModel from pydantic 1."""
|
||||
import pydantic
|
||||
|
||||
class Forecast(pydantic.BaseModel):
|
||||
temperature: int
|
||||
forecast: str
|
||||
|
||||
# Can't get pydantic to work here due to the odd typing of tryig to support
|
||||
# both v1 and v2 in the same codebase.
|
||||
parser = PydanticToolsParser(tools=[Forecast])
|
||||
message = AIMessage(
|
||||
content="",
|
||||
|
@ -3,35 +3,23 @@
|
||||
from enum import Enum
|
||||
from typing import Literal, Optional
|
||||
|
||||
import pydantic
|
||||
import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic.v1 import BaseModel as V1BaseModel
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.language_models import ParrotFakeChatModel
|
||||
from langchain_core.output_parsers import PydanticOutputParser
|
||||
from langchain_core.output_parsers.json import JsonOutputParser
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.utils.pydantic import TBaseModel
|
||||
|
||||
|
||||
class ForecastV2(pydantic.BaseModel):
|
||||
class Forecast(BaseModel):
|
||||
temperature: int
|
||||
f_or_c: Literal["F", "C"]
|
||||
forecast: str
|
||||
|
||||
|
||||
class ForecastV1(V1BaseModel):
|
||||
temperature: int
|
||||
f_or_c: Literal["F", "C"]
|
||||
forecast: str
|
||||
|
||||
|
||||
@pytest.mark.parametrize("pydantic_object", [ForecastV2, ForecastV1])
|
||||
def test_pydantic_parser_chaining(
|
||||
pydantic_object: TBaseModel,
|
||||
) -> None:
|
||||
def test_pydantic_parser_chaining() -> None:
|
||||
prompt = PromptTemplate(
|
||||
template="""{{
|
||||
"temperature": 20,
|
||||
@ -43,18 +31,17 @@ def test_pydantic_parser_chaining(
|
||||
|
||||
model = ParrotFakeChatModel()
|
||||
|
||||
parser = PydanticOutputParser(pydantic_object=pydantic_object) # type: ignore[arg-type,var-annotated]
|
||||
parser = PydanticOutputParser(pydantic_object=Forecast)
|
||||
chain = prompt | model | parser
|
||||
|
||||
res = chain.invoke({})
|
||||
assert type(res) is pydantic_object
|
||||
assert type(res) is Forecast
|
||||
assert res.f_or_c == "C"
|
||||
assert res.temperature == 20
|
||||
assert res.forecast == "Sunny"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("pydantic_object", [ForecastV2, ForecastV1])
|
||||
def test_pydantic_parser_validation(pydantic_object: TBaseModel) -> None:
|
||||
def test_pydantic_parser_validation() -> None:
|
||||
bad_prompt = PromptTemplate(
|
||||
template="""{{
|
||||
"temperature": "oof",
|
||||
@ -66,17 +53,14 @@ def test_pydantic_parser_validation(pydantic_object: TBaseModel) -> None:
|
||||
|
||||
model = ParrotFakeChatModel()
|
||||
|
||||
parser = PydanticOutputParser(pydantic_object=pydantic_object) # type: ignore[arg-type,var-annotated]
|
||||
parser = PydanticOutputParser(pydantic_object=Forecast)
|
||||
chain = bad_prompt | model | parser
|
||||
with pytest.raises(OutputParserException):
|
||||
chain.invoke({})
|
||||
|
||||
|
||||
# JSON output parser tests
|
||||
@pytest.mark.parametrize("pydantic_object", [ForecastV2, ForecastV1])
|
||||
def test_json_parser_chaining(
|
||||
pydantic_object: TBaseModel,
|
||||
) -> None:
|
||||
def test_json_parser_chaining() -> None:
|
||||
prompt = PromptTemplate(
|
||||
template="""{{
|
||||
"temperature": 20,
|
||||
@ -88,7 +72,7 @@ def test_json_parser_chaining(
|
||||
|
||||
model = ParrotFakeChatModel()
|
||||
|
||||
parser = JsonOutputParser(pydantic_object=pydantic_object) # type: ignore[arg-type]
|
||||
parser = JsonOutputParser(pydantic_object=Forecast)
|
||||
chain = prompt | model | parser
|
||||
|
||||
res = chain.invoke({})
|
||||
|
@ -1,6 +1,5 @@
|
||||
from functools import partial
|
||||
from inspect import isclass
|
||||
from typing import Any, Union, cast
|
||||
from typing import Any, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -10,15 +9,14 @@ from langchain_core.load.load import loads
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.prompts.structured import StructuredPrompt
|
||||
from langchain_core.runnables.base import Runnable, RunnableLambda
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
|
||||
|
||||
def _fake_runnable(
|
||||
_: Any, *, schema: Union[dict, type[BaseModel]], value: Any = 42, **_kwargs: Any
|
||||
) -> Union[BaseModel, dict]:
|
||||
if isclass(schema) and is_basemodel_subclass(schema):
|
||||
if isinstance(schema, type) and issubclass(schema, BaseModel):
|
||||
return schema(name="yo", value=value)
|
||||
params = cast("dict", schema)["parameters"]
|
||||
params = schema["parameters"]
|
||||
return {k: 1 if k != "value" else value for k, v in params.items()}
|
||||
|
||||
|
||||
|
@ -2,8 +2,6 @@ from typing import Any, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
|
||||
|
||||
# Function to replace allOf with $ref
|
||||
def replace_all_of_with_ref(schema: Any) -> None:
|
||||
@ -31,8 +29,6 @@ def replace_all_of_with_ref(schema: Any) -> None:
|
||||
def remove_all_none_default(schema: Any) -> None:
|
||||
"""Removing all none defaults.
|
||||
|
||||
Pydantic v1 did not generate these, but Pydantic v2 does.
|
||||
|
||||
The None defaults usually represent **NotRequired** fields, and the None value
|
||||
is actually **incorrect** as a value since the fields do not allow a None value.
|
||||
|
||||
@ -75,13 +71,9 @@ def _remove_enum(obj: Any) -> None:
|
||||
|
||||
def _schema(obj: Any) -> dict:
|
||||
"""Return the schema of the object."""
|
||||
if not is_basemodel_subclass(obj):
|
||||
if not (isinstance(obj, type) and issubclass(obj, BaseModel)):
|
||||
msg = f"Object must be a Pydantic BaseModel subclass. Got {type(obj)}"
|
||||
raise TypeError(msg)
|
||||
# Remap to old style schema
|
||||
if not hasattr(obj, "model_json_schema"): # V1 model
|
||||
return obj.schema()
|
||||
|
||||
schema_ = obj.model_json_schema(ref_template="#/definitions/{model}")
|
||||
if "$defs" in schema_:
|
||||
schema_["definitions"] = schema_["$defs"]
|
||||
|
@ -21,9 +21,7 @@ from typing import (
|
||||
)
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
from pydantic.v1 import ValidationError as ValidationErrorV1
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationError
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from langchain_core import tools
|
||||
@ -52,7 +50,6 @@ from langchain_core.tools import (
|
||||
tool,
|
||||
)
|
||||
from langchain_core.tools.base import (
|
||||
ArgsSchema,
|
||||
InjectedToolArg,
|
||||
InjectedToolCallId,
|
||||
SchemaAnnotationError,
|
||||
@ -65,8 +62,6 @@ from langchain_core.utils.function_calling import (
|
||||
convert_to_openai_tool,
|
||||
)
|
||||
from langchain_core.utils.pydantic import (
|
||||
IS_PYDANTIC_V1,
|
||||
IS_PYDANTIC_V2,
|
||||
_create_subset_model,
|
||||
create_model_v2,
|
||||
)
|
||||
@ -79,9 +74,7 @@ def _get_tool_call_json_schema(tool: BaseTool) -> dict:
|
||||
if isinstance(tool_schema, dict):
|
||||
return tool_schema
|
||||
|
||||
if hasattr(tool_schema, "model_json_schema"):
|
||||
return tool_schema.model_json_schema()
|
||||
return tool_schema.schema()
|
||||
|
||||
|
||||
def test_unnamed_decorator() -> None:
|
||||
@ -106,14 +99,6 @@ class _MockSchema(BaseModel):
|
||||
arg3: Optional[dict] = None
|
||||
|
||||
|
||||
class _MockSchemaV1(BaseModelV1):
|
||||
"""Return the arguments directly."""
|
||||
|
||||
arg1: int
|
||||
arg2: bool
|
||||
arg3: Optional[dict] = None
|
||||
|
||||
|
||||
class _MockStructuredTool(BaseTool):
|
||||
name: str = "structured_api"
|
||||
args_schema: type[BaseModel] = _MockSchema
|
||||
@ -205,13 +190,6 @@ def test_decorator_with_specified_schema() -> None:
|
||||
assert isinstance(tool_func, BaseTool)
|
||||
assert tool_func.args_schema == _MockSchema
|
||||
|
||||
@tool(args_schema=cast("ArgsSchema", _MockSchemaV1))
|
||||
def tool_func_v1(*, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
||||
return f"{arg1} {arg2} {arg3}"
|
||||
|
||||
assert isinstance(tool_func_v1, BaseTool)
|
||||
assert tool_func_v1.args_schema == _MockSchemaV1
|
||||
|
||||
|
||||
def test_decorated_function_schema_equivalent() -> None:
|
||||
"""Test that a BaseTool without a schema meets expectations."""
|
||||
@ -345,50 +323,6 @@ def test_structured_tool_types_parsed() -> None:
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_structured_tool_types_parsed_pydantic_v1() -> None:
|
||||
"""Test the non-primitive types are correctly passed to structured tools."""
|
||||
|
||||
class SomeBaseModel(BaseModelV1):
|
||||
foo: str
|
||||
|
||||
class AnotherBaseModel(BaseModelV1):
|
||||
bar: str
|
||||
|
||||
@tool
|
||||
def structured_tool(some_base_model: SomeBaseModel) -> AnotherBaseModel:
|
||||
"""Return the arguments directly."""
|
||||
return AnotherBaseModel(bar=some_base_model.foo)
|
||||
|
||||
assert isinstance(structured_tool, StructuredTool)
|
||||
|
||||
expected = AnotherBaseModel(bar="baz")
|
||||
for arg in [
|
||||
SomeBaseModel(foo="baz"),
|
||||
SomeBaseModel(foo="baz").dict(),
|
||||
]:
|
||||
args = {"some_base_model": arg}
|
||||
result = structured_tool.run(args)
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_structured_tool_types_parsed_pydantic_mixed() -> None:
|
||||
"""Test handling of tool with mixed Pydantic version arguments."""
|
||||
|
||||
class SomeBaseModel(BaseModelV1):
|
||||
foo: str
|
||||
|
||||
class AnotherBaseModel(BaseModel):
|
||||
bar: str
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
|
||||
@tool
|
||||
def structured_tool(
|
||||
some_base_model: SomeBaseModel, another_base_model: AnotherBaseModel
|
||||
) -> None:
|
||||
"""Return the arguments directly."""
|
||||
|
||||
|
||||
def test_base_tool_inheritance_base_schema() -> None:
|
||||
"""Test schema is correctly inferred when inheriting from BaseTool."""
|
||||
|
||||
@ -867,7 +801,7 @@ def test_validation_error_handling_callable() -> None:
|
||||
"""Test that validation errors are handled correctly."""
|
||||
expected = "foo bar"
|
||||
|
||||
def handling(e: Union[ValidationError, ValidationErrorV1]) -> str:
|
||||
def handling(e: ValidationError) -> str:
|
||||
return expected
|
||||
|
||||
_tool = _MockStructuredTool(handle_validation_error=handling)
|
||||
@ -884,9 +818,7 @@ def test_validation_error_handling_callable() -> None:
|
||||
],
|
||||
)
|
||||
def test_validation_error_handling_non_validation_error(
|
||||
handler: Union[
|
||||
bool, str, Callable[[Union[ValidationError, ValidationErrorV1]], str]
|
||||
],
|
||||
handler: Union[bool, str, Callable[[ValidationError], str]],
|
||||
) -> None:
|
||||
"""Test that validation errors are handled correctly."""
|
||||
|
||||
@ -932,7 +864,7 @@ async def test_async_validation_error_handling_callable() -> None:
|
||||
"""Test that validation errors are handled correctly."""
|
||||
expected = "foo bar"
|
||||
|
||||
def handling(e: Union[ValidationError, ValidationErrorV1]) -> str:
|
||||
def handling(e: ValidationError) -> str:
|
||||
return expected
|
||||
|
||||
_tool = _MockStructuredTool(handle_validation_error=handling)
|
||||
@ -949,9 +881,7 @@ async def test_async_validation_error_handling_callable() -> None:
|
||||
],
|
||||
)
|
||||
async def test_async_validation_error_handling_non_validation_error(
|
||||
handler: Union[
|
||||
bool, str, Callable[[Union[ValidationError, ValidationErrorV1]], str]
|
||||
],
|
||||
handler: Union[bool, str, Callable[[ValidationError], str]],
|
||||
) -> None:
|
||||
"""Test that validation errors are handled correctly."""
|
||||
|
||||
@ -1812,38 +1742,18 @@ def test_fn_injected_arg_with_schema(tool_: Callable) -> None:
|
||||
}
|
||||
|
||||
|
||||
def generate_models() -> list[Any]:
|
||||
"""Generate a list of base models depending on the pydantic version."""
|
||||
|
||||
class FooProper(BaseModel):
|
||||
class FooProper(BaseModel):
|
||||
a: int
|
||||
b: str
|
||||
|
||||
return [FooProper]
|
||||
|
||||
|
||||
def generate_backwards_compatible_v1() -> list[Any]:
|
||||
"""Generate a model with pydantic 2 from the v1 namespace."""
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
|
||||
class FooV1Namespace(BaseModelV1):
|
||||
a: int
|
||||
b: str
|
||||
|
||||
return [FooV1Namespace]
|
||||
|
||||
|
||||
# This generates a list of models that can be used for testing that our APIs
|
||||
# behave well with either pydantic 1 proper,
|
||||
# pydantic v1 from pydantic 2,
|
||||
# or pydantic 2 proper.
|
||||
TEST_MODELS = generate_models() + generate_backwards_compatible_v1()
|
||||
TEST_MODELS = [FooProper]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("pydantic_model", TEST_MODELS)
|
||||
def test_args_schema_as_pydantic(pydantic_model: Any) -> None:
|
||||
class SomeTool(BaseTool):
|
||||
args_schema: type[pydantic_model] = pydantic_model
|
||||
args_schema: type[BaseModel] = pydantic_model
|
||||
|
||||
def _run(self, *args: Any, **kwargs: Any) -> str:
|
||||
return "foo"
|
||||
@ -1853,11 +1763,7 @@ def test_args_schema_as_pydantic(pydantic_model: Any) -> None:
|
||||
)
|
||||
|
||||
input_schema = tool.get_input_schema()
|
||||
input_json_schema = (
|
||||
input_schema.model_json_schema()
|
||||
if hasattr(input_schema, "model_json_schema")
|
||||
else input_schema.schema()
|
||||
)
|
||||
input_json_schema = input_schema.model_json_schema()
|
||||
assert input_json_schema == {
|
||||
"properties": {
|
||||
"a": {"title": "A", "type": "integer"},
|
||||
@ -1882,22 +1788,13 @@ def test_args_schema_as_pydantic(pydantic_model: Any) -> None:
|
||||
|
||||
|
||||
def test_args_schema_explicitly_typed() -> None:
|
||||
"""This should test that one can type the args schema as a pydantic model.
|
||||
|
||||
Please note that this will test using pydantic 2 even though BaseTool
|
||||
is a pydantic 1 model!
|
||||
"""
|
||||
# Check with whatever pydantic model is passed in and not via v1 namespace
|
||||
from pydantic import BaseModel
|
||||
"""This should test that one can type the args schema as a pydantic model."""
|
||||
|
||||
class Foo(BaseModel):
|
||||
a: int
|
||||
b: str
|
||||
|
||||
class SomeTool(BaseTool):
|
||||
# type ignoring here since we're allowing overriding a type
|
||||
# signature of pydantic.v1.BaseModel with pydantic.BaseModel
|
||||
# for pydantic 2!
|
||||
args_schema: type[BaseModel] = Foo
|
||||
|
||||
def _run(self, *args: Any, **kwargs: Any) -> str:
|
||||
@ -1944,11 +1841,7 @@ def test_structured_tool_with_different_pydantic_versions(pydantic_model: Any) -
|
||||
assert foo_tool.invoke({"a": 5, "b": "hello"}) == "foo"
|
||||
|
||||
args_schema = cast("BaseModel", foo_tool.args_schema)
|
||||
args_json_schema = (
|
||||
args_schema.model_json_schema()
|
||||
if hasattr(args_schema, "model_json_schema")
|
||||
else args_schema.schema()
|
||||
)
|
||||
args_json_schema = args_schema.model_json_schema()
|
||||
assert args_json_schema == {
|
||||
"properties": {
|
||||
"a": {"title": "A", "type": "integer"},
|
||||
@ -1960,11 +1853,7 @@ def test_structured_tool_with_different_pydantic_versions(pydantic_model: Any) -
|
||||
}
|
||||
|
||||
input_schema = foo_tool.get_input_schema()
|
||||
input_json_schema = (
|
||||
input_schema.model_json_schema()
|
||||
if hasattr(input_schema, "model_json_schema")
|
||||
else input_schema.schema()
|
||||
)
|
||||
input_json_schema = input_schema.model_json_schema()
|
||||
assert input_json_schema == {
|
||||
"properties": {
|
||||
"a": {"title": "A", "type": "integer"},
|
||||
@ -2020,22 +1909,10 @@ def test__is_message_content_type(obj: Any, *, expected: bool) -> None:
|
||||
assert _is_message_content_type(obj) is expected
|
||||
|
||||
|
||||
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Testing pydantic v2.")
|
||||
@pytest.mark.parametrize("use_v1_namespace", [True, False])
|
||||
def test__get_all_basemodel_annotations_v2(*, use_v1_namespace: bool) -> None:
|
||||
def test__get_all_basemodel_annotations_v2() -> None:
|
||||
A = TypeVar("A")
|
||||
|
||||
if use_v1_namespace:
|
||||
from pydantic.v1 import BaseModel as BaseModel1
|
||||
|
||||
class ModelA(BaseModel1, Generic[A], extra="allow"):
|
||||
a: A
|
||||
|
||||
else:
|
||||
from pydantic import BaseModel as BaseModel2
|
||||
from pydantic import ConfigDict
|
||||
|
||||
class ModelA(BaseModel2, Generic[A]): # type: ignore[no-redef]
|
||||
class ModelA(BaseModel, Generic[A]):
|
||||
a: A
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
|
||||
|
||||
@ -2089,63 +1966,6 @@ def test__get_all_basemodel_annotations_v2(*, use_v1_namespace: bool) -> None:
|
||||
assert actual == expected
|
||||
|
||||
|
||||
@pytest.mark.skipif(not IS_PYDANTIC_V1, reason="Testing pydantic v1.")
|
||||
def test__get_all_basemodel_annotations_v1() -> None:
|
||||
A = TypeVar("A")
|
||||
|
||||
class ModelA(BaseModel, Generic[A], extra="allow"):
|
||||
a: A
|
||||
|
||||
class ModelB(ModelA[str]):
|
||||
b: Annotated[ModelA[dict[str, Any]], "foo"]
|
||||
|
||||
class Mixin:
|
||||
def foo(self) -> str:
|
||||
return "foo"
|
||||
|
||||
class ModelC(Mixin, ModelB):
|
||||
c: dict
|
||||
|
||||
expected = {"a": str, "b": Annotated[ModelA[dict[str, Any]], "foo"], "c": dict}
|
||||
actual = get_all_basemodel_annotations(ModelC)
|
||||
assert actual == expected
|
||||
|
||||
expected = {"a": str, "b": Annotated[ModelA[dict[str, Any]], "foo"]}
|
||||
actual = get_all_basemodel_annotations(ModelB)
|
||||
assert actual == expected
|
||||
|
||||
expected = {"a": Any}
|
||||
actual = get_all_basemodel_annotations(ModelA)
|
||||
assert actual == expected
|
||||
|
||||
expected = {"a": int}
|
||||
actual = get_all_basemodel_annotations(ModelA[int])
|
||||
assert actual == expected
|
||||
|
||||
D = TypeVar("D", bound=Union[str, int])
|
||||
|
||||
class ModelD(ModelC, Generic[D]):
|
||||
d: Optional[D]
|
||||
|
||||
expected = {
|
||||
"a": str,
|
||||
"b": Annotated[ModelA[dict[str, Any]], "foo"],
|
||||
"c": dict,
|
||||
"d": Union[str, int, None],
|
||||
}
|
||||
actual = get_all_basemodel_annotations(ModelD)
|
||||
assert actual == expected
|
||||
|
||||
expected = {
|
||||
"a": str,
|
||||
"b": Annotated[ModelA[dict[str, Any]], "foo"],
|
||||
"c": dict,
|
||||
"d": Union[int, None],
|
||||
}
|
||||
actual = get_all_basemodel_annotations(ModelD[int])
|
||||
assert actual == expected
|
||||
|
||||
|
||||
def test_get_all_basemodel_annotations_aliases() -> None:
|
||||
class CalculatorInput(BaseModel):
|
||||
a: int = Field(description="first number", alias="A")
|
||||
@ -2226,14 +2046,9 @@ def test_create_retriever_tool() -> None:
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Testing pydantic v2.")
|
||||
def test_tool_args_schema_pydantic_v2_with_metadata() -> None:
|
||||
from pydantic import BaseModel as BaseModelV2
|
||||
from pydantic import Field as FieldV2
|
||||
from pydantic import ValidationError as ValidationErrorV2
|
||||
|
||||
class Foo(BaseModelV2):
|
||||
x: list[int] = FieldV2(
|
||||
def test_tool_args_schema_pydantic_with_metadata() -> None:
|
||||
class Foo(BaseModel):
|
||||
x: list[int] = Field(
|
||||
description="List of integers", min_length=10, max_length=15
|
||||
)
|
||||
|
||||
@ -2260,7 +2075,7 @@ def test_tool_args_schema_pydantic_v2_with_metadata() -> None:
|
||||
}
|
||||
|
||||
assert foo.invoke({"x": [0] * 10})
|
||||
with pytest.raises(ValidationErrorV2):
|
||||
with pytest.raises(ValidationError):
|
||||
foo.invoke({"x": [0] * 9})
|
||||
|
||||
|
||||
|
@ -746,7 +746,7 @@ def test_tool_outputs() -> None:
|
||||
[ExtensionsAnnotated, TypingAnnotated],
|
||||
ids=["typing_extensions.Annotated", "typing.Annotated"],
|
||||
)
|
||||
def test__convert_typed_dict_to_openai_function(
|
||||
def test_convert_typed_dict_to_openai_function(
|
||||
typed_dict: TypeAlias, annotated: TypeAlias
|
||||
) -> None:
|
||||
class SubTool(typed_dict): # type: ignore[misc]
|
||||
@ -985,7 +985,7 @@ def test__convert_typed_dict_to_openai_function(
|
||||
@pytest.mark.parametrize("typed_dict", [ExtensionsTypedDict, TypingTypedDict])
|
||||
def test__convert_typed_dict_to_openai_function_fail(typed_dict: type) -> None:
|
||||
class Tool(typed_dict): # type: ignore[misc]
|
||||
arg1: typing.MutableSet # Pydantic 2 supports this, but pydantic v1 does not.
|
||||
arg1: typing.MutableSet
|
||||
|
||||
# Error should be raised since we're using v1 code path here
|
||||
with pytest.raises(TypeError):
|
||||
|
@ -3,18 +3,12 @@
|
||||
import warnings
|
||||
from typing import Any, Optional
|
||||
|
||||
import pytest
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from langchain_core.utils.pydantic import (
|
||||
IS_PYDANTIC_V1,
|
||||
IS_PYDANTIC_V2,
|
||||
PYDANTIC_VERSION,
|
||||
_create_subset_model_v2,
|
||||
create_model_v2,
|
||||
get_fields,
|
||||
is_basemodel_instance,
|
||||
is_basemodel_subclass,
|
||||
pre_init,
|
||||
)
|
||||
|
||||
@ -94,52 +88,6 @@ def test_with_aliases() -> None:
|
||||
assert foo.z == 2
|
||||
|
||||
|
||||
def test_is_basemodel_subclass() -> None:
|
||||
"""Test pydantic."""
|
||||
if IS_PYDANTIC_V1:
|
||||
from pydantic import BaseModel as BaseModelV1Proper
|
||||
|
||||
assert is_basemodel_subclass(BaseModelV1Proper)
|
||||
elif IS_PYDANTIC_V2:
|
||||
from pydantic import BaseModel as BaseModelV2
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
|
||||
assert is_basemodel_subclass(BaseModelV2)
|
||||
|
||||
assert is_basemodel_subclass(BaseModelV1)
|
||||
else:
|
||||
msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}"
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
def test_is_basemodel_instance() -> None:
|
||||
"""Test pydantic."""
|
||||
if IS_PYDANTIC_V1:
|
||||
from pydantic import BaseModel as BaseModelV1Proper
|
||||
|
||||
class FooV1(BaseModelV1Proper):
|
||||
x: int
|
||||
|
||||
assert is_basemodel_instance(FooV1(x=5))
|
||||
elif IS_PYDANTIC_V2:
|
||||
from pydantic import BaseModel as BaseModelV2
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
|
||||
class Foo(BaseModelV2):
|
||||
x: int
|
||||
|
||||
assert is_basemodel_instance(Foo(x=5))
|
||||
|
||||
class Bar(BaseModelV1):
|
||||
x: int
|
||||
|
||||
assert is_basemodel_instance(Bar(x=5))
|
||||
else:
|
||||
msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}"
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Only tests Pydantic v2")
|
||||
def test_with_field_metadata() -> None:
|
||||
"""Test pydantic with field metadata."""
|
||||
from pydantic import BaseModel as BaseModelV2
|
||||
@ -168,21 +116,7 @@ def test_with_field_metadata() -> None:
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.skipif(not IS_PYDANTIC_V1, reason="Only tests Pydantic v1")
|
||||
def test_fields_pydantic_v1() -> None:
|
||||
from pydantic import BaseModel
|
||||
|
||||
class Foo(BaseModel):
|
||||
x: int
|
||||
|
||||
fields = get_fields(Foo)
|
||||
assert fields == {"x": Foo.model_fields["x"]}
|
||||
|
||||
|
||||
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Only tests Pydantic v2")
|
||||
def test_fields_pydantic_v2_proper() -> None:
|
||||
from pydantic import BaseModel
|
||||
|
||||
class Foo(BaseModel):
|
||||
x: int
|
||||
|
||||
@ -190,17 +124,6 @@ def test_fields_pydantic_v2_proper() -> None:
|
||||
assert fields == {"x": Foo.model_fields["x"]}
|
||||
|
||||
|
||||
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Only tests Pydantic v2")
|
||||
def test_fields_pydantic_v1_from_2() -> None:
|
||||
from pydantic.v1 import BaseModel
|
||||
|
||||
class Foo(BaseModel):
|
||||
x: int
|
||||
|
||||
fields = get_fields(Foo)
|
||||
assert fields == {"x": Foo.__fields__["x"]}
|
||||
|
||||
|
||||
def test_create_model_v2() -> None:
|
||||
"""Test that create model v2 works as expected."""
|
||||
with warnings.catch_warnings(record=True) as record:
|
||||
|
@ -6,7 +6,7 @@ from typing import Any, Callable, Optional, Union
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from langchain_core import utils
|
||||
from langchain_core.utils import (
|
||||
@ -16,10 +16,6 @@ from langchain_core.utils import (
|
||||
guard_import,
|
||||
)
|
||||
from langchain_core.utils._merge import merge_dicts
|
||||
from langchain_core.utils.pydantic import (
|
||||
IS_PYDANTIC_V1,
|
||||
IS_PYDANTIC_V2,
|
||||
)
|
||||
from langchain_core.utils.utils import secret_from_env
|
||||
|
||||
|
||||
@ -214,39 +210,7 @@ def test_guard_import_failure(
|
||||
guard_import(module_name, pip_name=pip_name, package=package)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Requires pydantic 2")
|
||||
def test_get_pydantic_field_names_v1_in_2() -> None:
|
||||
from pydantic.v1 import BaseModel as PydanticV1BaseModel
|
||||
from pydantic.v1 import Field
|
||||
|
||||
class PydanticV1Model(PydanticV1BaseModel):
|
||||
field1: str
|
||||
field2: int
|
||||
alias_field: int = Field(alias="aliased_field")
|
||||
|
||||
result = get_pydantic_field_names(PydanticV1Model)
|
||||
expected = {"field1", "field2", "aliased_field", "alias_field"}
|
||||
assert result == expected
|
||||
|
||||
|
||||
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Requires pydantic 2")
|
||||
def test_get_pydantic_field_names_v2_in_2() -> None:
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class PydanticModel(BaseModel):
|
||||
field1: str
|
||||
field2: int
|
||||
alias_field: int = Field(alias="aliased_field")
|
||||
|
||||
result = get_pydantic_field_names(PydanticModel)
|
||||
expected = {"field1", "field2", "aliased_field", "alias_field"}
|
||||
assert result == expected
|
||||
|
||||
|
||||
@pytest.mark.skipif(not IS_PYDANTIC_V1, reason="Requires pydantic 1")
|
||||
def test_get_pydantic_field_names_v1() -> None:
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
def test_get_pydantic_field_names() -> None:
|
||||
class PydanticModel(BaseModel):
|
||||
field1: str
|
||||
field2: int
|
||||
|
Loading…
Reference in New Issue
Block a user