mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-09 21:08:59 +00:00
pydantic initial changes
This commit is contained in:
parent
275e3b6710
commit
4e98c8aa47
@ -15,7 +15,7 @@ test tests:
|
|||||||
-u LANGCHAIN_API_KEY \
|
-u LANGCHAIN_API_KEY \
|
||||||
-u LANGSMITH_TRACING \
|
-u LANGSMITH_TRACING \
|
||||||
-u LANGCHAIN_PROJECT \
|
-u LANGCHAIN_PROJECT \
|
||||||
uv run --group test pytest -n auto --disable-socket --allow-unix-socket $(TEST_FILE)
|
uv run --group test pytest -vv -s -n auto --disable-socket --allow-unix-socket $(TEST_FILE)
|
||||||
|
|
||||||
test_watch:
|
test_watch:
|
||||||
env \
|
env \
|
||||||
|
@ -23,6 +23,7 @@ from typing import (
|
|||||||
cast,
|
cast,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from pydantic.fields import FieldInfo
|
||||||
from typing_extensions import ParamSpec
|
from typing_extensions import ParamSpec
|
||||||
|
|
||||||
from langchain_core._api.internal import is_caller_internal
|
from langchain_core._api.internal import is_caller_internal
|
||||||
@ -39,8 +40,7 @@ class LangChainPendingDeprecationWarning(PendingDeprecationWarning):
|
|||||||
# PUBLIC API
|
# PUBLIC API
|
||||||
|
|
||||||
|
|
||||||
# Last Any should be FieldInfoV1 but this leads to circular imports
|
T = TypeVar("T", bound=Union[type, Callable[..., Any], FieldInfo])
|
||||||
T = TypeVar("T", bound=Union[type, Callable[..., Any], Any])
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_deprecation_params(
|
def _validate_deprecation_params(
|
||||||
@ -152,10 +152,6 @@ def deprecated(
|
|||||||
_package: str = package,
|
_package: str = package,
|
||||||
) -> T:
|
) -> T:
|
||||||
"""Implementation of the decorator returned by `deprecated`."""
|
"""Implementation of the decorator returned by `deprecated`."""
|
||||||
from langchain_core.utils.pydantic import ( # type: ignore[attr-defined]
|
|
||||||
FieldInfoV1,
|
|
||||||
FieldInfoV2,
|
|
||||||
)
|
|
||||||
|
|
||||||
def emit_warning() -> None:
|
def emit_warning() -> None:
|
||||||
"""Emit the warning."""
|
"""Emit the warning."""
|
||||||
@ -228,7 +224,7 @@ def deprecated(
|
|||||||
)
|
)
|
||||||
return cast("T", obj)
|
return cast("T", obj)
|
||||||
|
|
||||||
elif isinstance(obj, FieldInfoV1):
|
elif isinstance(obj, FieldInfo):
|
||||||
wrapped = None
|
wrapped = None
|
||||||
if not _obj_type:
|
if not _obj_type:
|
||||||
_obj_type = "attribute"
|
_obj_type = "attribute"
|
||||||
@ -240,28 +236,7 @@ def deprecated(
|
|||||||
def finalize(wrapper: Callable[..., Any], new_doc: str) -> T: # noqa: ARG001
|
def finalize(wrapper: Callable[..., Any], new_doc: str) -> T: # noqa: ARG001
|
||||||
return cast(
|
return cast(
|
||||||
"T",
|
"T",
|
||||||
FieldInfoV1(
|
FieldInfo(
|
||||||
default=obj.default,
|
|
||||||
default_factory=obj.default_factory,
|
|
||||||
description=new_doc,
|
|
||||||
alias=obj.alias,
|
|
||||||
exclude=obj.exclude,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
elif isinstance(obj, FieldInfoV2):
|
|
||||||
wrapped = None
|
|
||||||
if not _obj_type:
|
|
||||||
_obj_type = "attribute"
|
|
||||||
if not _name:
|
|
||||||
msg = f"Field {obj} must have a name to be deprecated."
|
|
||||||
raise ValueError(msg)
|
|
||||||
old_doc = obj.description
|
|
||||||
|
|
||||||
def finalize(wrapper: Callable[..., Any], new_doc: str) -> T: # noqa: ARG001
|
|
||||||
return cast(
|
|
||||||
"T",
|
|
||||||
FieldInfoV2(
|
|
||||||
default=obj.default,
|
default=obj.default,
|
||||||
default_factory=obj.default_factory,
|
default_factory=obj.default_factory,
|
||||||
description=new_doc,
|
description=new_doc,
|
||||||
|
@ -75,7 +75,6 @@ from langchain_core.utils.function_calling import (
|
|||||||
convert_to_json_schema,
|
convert_to_json_schema,
|
||||||
convert_to_openai_tool,
|
convert_to_openai_tool,
|
||||||
)
|
)
|
||||||
from langchain_core.utils.pydantic import TypeBaseModel, is_basemodel_subclass
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import uuid
|
import uuid
|
||||||
@ -625,7 +624,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
|||||||
stop: Optional[list[str]] = None,
|
stop: Optional[list[str]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
params = self.dict()
|
params = self.model_dump()
|
||||||
params["stop"] = stop
|
params["stop"] = stop
|
||||||
return {**params, **kwargs}
|
return {**params, **kwargs}
|
||||||
|
|
||||||
@ -1298,7 +1297,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
|||||||
"""Return type of chat model."""
|
"""Return type of chat model."""
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def dict(self, **kwargs: Any) -> dict:
|
def model_dump(self, **kwargs: Any) -> dict:
|
||||||
"""Return a dictionary of the LLM."""
|
"""Return a dictionary of the LLM."""
|
||||||
starter_dict = dict(self._identifying_params)
|
starter_dict = dict(self._identifying_params)
|
||||||
starter_dict["_type"] = self._llm_type
|
starter_dict["_type"] = self._llm_type
|
||||||
@ -1456,9 +1455,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
|||||||
"schema": schema,
|
"schema": schema,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
if isinstance(schema, type) and is_basemodel_subclass(schema):
|
if isinstance(schema, type) and issubclass(schema, BaseModel):
|
||||||
output_parser: OutputParserLike = PydanticToolsParser(
|
output_parser: OutputParserLike = PydanticToolsParser(
|
||||||
tools=[cast("TypeBaseModel", schema)], first_tool_only=True
|
tools=[schema], first_tool_only=True
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
key_name = convert_to_openai_tool(schema)["function"]["name"]
|
key_name = convert_to_openai_tool(schema)["function"]["name"]
|
||||||
|
@ -528,7 +528,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
|||||||
else:
|
else:
|
||||||
prompt = self._convert_input(input).to_string()
|
prompt = self._convert_input(input).to_string()
|
||||||
config = ensure_config(config)
|
config = ensure_config(config)
|
||||||
params = self.dict()
|
params = self.model_dump()
|
||||||
params["stop"] = stop
|
params["stop"] = stop
|
||||||
params = {**params, **kwargs}
|
params = {**params, **kwargs}
|
||||||
options = {"stop": stop}
|
options = {"stop": stop}
|
||||||
@ -598,7 +598,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
|||||||
|
|
||||||
prompt = self._convert_input(input).to_string()
|
prompt = self._convert_input(input).to_string()
|
||||||
config = ensure_config(config)
|
config = ensure_config(config)
|
||||||
params = self.dict()
|
params = self.model_dump()
|
||||||
params["stop"] = stop
|
params["stop"] = stop
|
||||||
params = {**params, **kwargs}
|
params = {**params, **kwargs}
|
||||||
options = {"stop": stop}
|
options = {"stop": stop}
|
||||||
@ -941,7 +941,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
|||||||
] * len(prompts)
|
] * len(prompts)
|
||||||
run_name_list = [cast("Optional[str]", run_name)] * len(prompts)
|
run_name_list = [cast("Optional[str]", run_name)] * len(prompts)
|
||||||
run_ids_list = self._get_run_ids_list(run_id, prompts)
|
run_ids_list = self._get_run_ids_list(run_id, prompts)
|
||||||
params = self.dict()
|
params = self.model_dump()
|
||||||
params["stop"] = stop
|
params["stop"] = stop
|
||||||
options = {"stop": stop}
|
options = {"stop": stop}
|
||||||
(
|
(
|
||||||
@ -1193,7 +1193,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
|||||||
] * len(prompts)
|
] * len(prompts)
|
||||||
run_name_list = [cast("Optional[str]", run_name)] * len(prompts)
|
run_name_list = [cast("Optional[str]", run_name)] * len(prompts)
|
||||||
run_ids_list = self._get_run_ids_list(run_id, prompts)
|
run_ids_list = self._get_run_ids_list(run_id, prompts)
|
||||||
params = self.dict()
|
params = self.model_dump()
|
||||||
params["stop"] = stop
|
params["stop"] = stop
|
||||||
options = {"stop": stop}
|
options = {"stop": stop}
|
||||||
(
|
(
|
||||||
@ -1400,7 +1400,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
|||||||
"""Return type of llm."""
|
"""Return type of llm."""
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def dict(self, **kwargs: Any) -> dict:
|
def model_dump(self, **kwargs: Any) -> dict:
|
||||||
"""Return a dictionary of the LLM."""
|
"""Return a dictionary of the LLM."""
|
||||||
starter_dict = dict(self._identifying_params)
|
starter_dict = dict(self._identifying_params)
|
||||||
starter_dict["_type"] = self._llm_type
|
starter_dict["_type"] = self._llm_type
|
||||||
@ -1427,7 +1427,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
|||||||
directory_path.mkdir(parents=True, exist_ok=True)
|
directory_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
# Fetch dictionary to save
|
# Fetch dictionary to save
|
||||||
prompt_dict = self.dict()
|
prompt_dict = self.model_dump()
|
||||||
|
|
||||||
if save_path.suffix == ".json":
|
if save_path.suffix == ".json":
|
||||||
with save_path.open("w") as f:
|
with save_path.open("w") as f:
|
||||||
|
@ -324,9 +324,9 @@ class BaseOutputParser(
|
|||||||
)
|
)
|
||||||
raise NotImplementedError(msg)
|
raise NotImplementedError(msg)
|
||||||
|
|
||||||
def dict(self, **kwargs: Any) -> dict:
|
def model_dump(self, **kwargs: Any) -> dict:
|
||||||
"""Return dictionary representation of output parser."""
|
"""Return dictionary representation of output parser."""
|
||||||
output_parser_dict = super().dict(**kwargs)
|
output_parser_dict = super().model_dump(**kwargs)
|
||||||
with contextlib.suppress(NotImplementedError):
|
with contextlib.suppress(NotImplementedError):
|
||||||
output_parser_dict["_type"] = self._type
|
output_parser_dict["_type"] = self._type
|
||||||
return output_parser_dict
|
return output_parser_dict
|
||||||
|
@ -4,11 +4,10 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
from json import JSONDecodeError
|
from json import JSONDecodeError
|
||||||
from typing import Annotated, Any, Optional, TypeVar, Union
|
from typing import Any, Optional
|
||||||
|
|
||||||
import jsonpatch # type: ignore[import-untyped]
|
import jsonpatch # type: ignore[import-untyped]
|
||||||
import pydantic
|
from pydantic import BaseModel, SkipValidation
|
||||||
from pydantic import SkipValidation
|
|
||||||
|
|
||||||
from langchain_core.exceptions import OutputParserException
|
from langchain_core.exceptions import OutputParserException
|
||||||
from langchain_core.output_parsers.format_instructions import JSON_FORMAT_INSTRUCTIONS
|
from langchain_core.output_parsers.format_instructions import JSON_FORMAT_INSTRUCTIONS
|
||||||
@ -19,18 +18,6 @@ from langchain_core.utils.json import (
|
|||||||
parse_json_markdown,
|
parse_json_markdown,
|
||||||
parse_partial_json,
|
parse_partial_json,
|
||||||
)
|
)
|
||||||
from langchain_core.utils.pydantic import IS_PYDANTIC_V1
|
|
||||||
|
|
||||||
if IS_PYDANTIC_V1:
|
|
||||||
PydanticBaseModel = pydantic.BaseModel
|
|
||||||
|
|
||||||
else:
|
|
||||||
from pydantic.v1 import BaseModel
|
|
||||||
|
|
||||||
# Union type needs to be last assignment to PydanticBaseModel to make mypy happy.
|
|
||||||
PydanticBaseModel = Union[BaseModel, pydantic.BaseModel] # type: ignore[assignment,misc]
|
|
||||||
|
|
||||||
TBaseModel = TypeVar("TBaseModel", bound=PydanticBaseModel)
|
|
||||||
|
|
||||||
|
|
||||||
class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
|
class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
|
||||||
@ -43,18 +30,16 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
|
|||||||
describing the difference between the previous and the current object.
|
describing the difference between the previous and the current object.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
pydantic_object: Annotated[Optional[type[TBaseModel]], SkipValidation()] = None # type: ignore[valid-type]
|
pydantic_object: SkipValidation[Optional[type[BaseModel]]] = None
|
||||||
"""The Pydantic object to use for validation.
|
"""The Pydantic object to use for validation.
|
||||||
If None, no validation is performed."""
|
If None, no validation is performed."""
|
||||||
|
|
||||||
def _diff(self, prev: Optional[Any], next: Any) -> Any:
|
def _diff(self, prev: Optional[Any], next: Any) -> Any:
|
||||||
return jsonpatch.make_patch(prev, next).patch
|
return jsonpatch.make_patch(prev, next).patch
|
||||||
|
|
||||||
def _get_schema(self, pydantic_object: type[TBaseModel]) -> dict[str, Any]:
|
def _get_schema(self, pydantic_object: type[BaseModel]) -> dict[str, Any]:
|
||||||
if issubclass(pydantic_object, pydantic.BaseModel):
|
if issubclass(pydantic_object, BaseModel):
|
||||||
return pydantic_object.model_json_schema()
|
return pydantic_object.model_json_schema()
|
||||||
if issubclass(pydantic_object, pydantic.v1.BaseModel):
|
|
||||||
return pydantic_object.schema()
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
||||||
|
@ -274,10 +274,7 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
|
|||||||
pydantic_schema = self.pydantic_schema[fn_name]
|
pydantic_schema = self.pydantic_schema[fn_name]
|
||||||
else:
|
else:
|
||||||
pydantic_schema = self.pydantic_schema
|
pydantic_schema = self.pydantic_schema
|
||||||
if hasattr(pydantic_schema, "model_validate_json"):
|
pydantic_args = pydantic_schema.model_validate_json(_args)
|
||||||
pydantic_args = pydantic_schema.model_validate_json(_args)
|
|
||||||
else:
|
|
||||||
pydantic_args = pydantic_schema.parse_raw(_args)
|
|
||||||
return pydantic_args
|
return pydantic_args
|
||||||
|
|
||||||
|
|
||||||
|
@ -4,9 +4,9 @@ import copy
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from json import JSONDecodeError
|
from json import JSONDecodeError
|
||||||
from typing import Annotated, Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from pydantic import SkipValidation, ValidationError
|
from pydantic import BaseModel, SkipValidation, ValidationError
|
||||||
|
|
||||||
from langchain_core.exceptions import OutputParserException
|
from langchain_core.exceptions import OutputParserException
|
||||||
from langchain_core.messages import AIMessage, InvalidToolCall
|
from langchain_core.messages import AIMessage, InvalidToolCall
|
||||||
@ -15,7 +15,6 @@ from langchain_core.messages.tool import tool_call as create_tool_call
|
|||||||
from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser
|
from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser
|
||||||
from langchain_core.outputs import ChatGeneration, Generation
|
from langchain_core.outputs import ChatGeneration, Generation
|
||||||
from langchain_core.utils.json import parse_partial_json
|
from langchain_core.utils.json import parse_partial_json
|
||||||
from langchain_core.utils.pydantic import TypeBaseModel
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -264,7 +263,7 @@ _MAX_TOKENS_ERROR = (
|
|||||||
class PydanticToolsParser(JsonOutputToolsParser):
|
class PydanticToolsParser(JsonOutputToolsParser):
|
||||||
"""Parse tools from OpenAI response."""
|
"""Parse tools from OpenAI response."""
|
||||||
|
|
||||||
tools: Annotated[list[TypeBaseModel], SkipValidation()]
|
tools: SkipValidation[list[type[BaseModel]]]
|
||||||
"""The tools to parse."""
|
"""The tools to parse."""
|
||||||
|
|
||||||
# TODO: Support more granular streaming of objects. Currently only streams once all
|
# TODO: Support more granular streaming of objects. Currently only streams once all
|
||||||
|
@ -1,45 +1,33 @@
|
|||||||
"""Output parsers using Pydantic."""
|
"""Output parsers using Pydantic."""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from typing import Annotated, Generic, Optional
|
from typing import Generic, Optional, TypeVar
|
||||||
|
|
||||||
import pydantic
|
from pydantic import BaseModel, SkipValidation, ValidationError
|
||||||
from pydantic import SkipValidation
|
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_core.exceptions import OutputParserException
|
from langchain_core.exceptions import OutputParserException
|
||||||
from langchain_core.output_parsers import JsonOutputParser
|
from langchain_core.output_parsers import JsonOutputParser
|
||||||
from langchain_core.outputs import Generation
|
from langchain_core.outputs import Generation
|
||||||
from langchain_core.utils.pydantic import (
|
|
||||||
IS_PYDANTIC_V2,
|
BaseModelT = TypeVar("BaseModelT", bound=BaseModel)
|
||||||
PydanticBaseModel,
|
|
||||||
TBaseModel,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
|
class PydanticOutputParser(JsonOutputParser, Generic[BaseModelT]):
|
||||||
"""Parse an output using a pydantic model."""
|
"""Parse an output using a pydantic model."""
|
||||||
|
|
||||||
pydantic_object: Annotated[type[TBaseModel], SkipValidation()]
|
pydantic_object: SkipValidation[type[BaseModelT]]
|
||||||
"""The pydantic model to parse."""
|
"""The pydantic model to parse."""
|
||||||
|
|
||||||
def _parse_obj(self, obj: dict) -> TBaseModel:
|
def _parse_obj(self, obj: dict) -> BaseModelT:
|
||||||
if IS_PYDANTIC_V2:
|
try:
|
||||||
try:
|
if issubclass(self.pydantic_object, BaseModel):
|
||||||
if issubclass(self.pydantic_object, pydantic.BaseModel):
|
return self.pydantic_object.model_validate(obj)
|
||||||
return self.pydantic_object.model_validate(obj)
|
msg = f"Unsupported model version for PydanticOutputParser: \
|
||||||
if issubclass(self.pydantic_object, pydantic.v1.BaseModel):
|
{self.pydantic_object.__class__}"
|
||||||
return self.pydantic_object.parse_obj(obj)
|
raise OutputParserException(msg)
|
||||||
msg = f"Unsupported model version for PydanticOutputParser: \
|
except ValidationError as e:
|
||||||
{self.pydantic_object.__class__}"
|
raise self._parser_exception(e, obj) from e
|
||||||
raise OutputParserException(msg)
|
|
||||||
except (pydantic.ValidationError, pydantic.v1.ValidationError) as e:
|
|
||||||
raise self._parser_exception(e, obj) from e
|
|
||||||
else: # pydantic v1
|
|
||||||
try:
|
|
||||||
return self.pydantic_object.parse_obj(obj)
|
|
||||||
except pydantic.ValidationError as e:
|
|
||||||
raise self._parser_exception(e, obj) from e
|
|
||||||
|
|
||||||
def _parser_exception(
|
def _parser_exception(
|
||||||
self, e: Exception, json_object: dict
|
self, e: Exception, json_object: dict
|
||||||
@ -51,7 +39,7 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
|
|||||||
|
|
||||||
def parse_result(
|
def parse_result(
|
||||||
self, result: list[Generation], *, partial: bool = False
|
self, result: list[Generation], *, partial: bool = False
|
||||||
) -> Optional[TBaseModel]:
|
) -> Optional[BaseModelT]:
|
||||||
"""Parse the result of an LLM call to a pydantic object.
|
"""Parse the result of an LLM call to a pydantic object.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -72,7 +60,7 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
|
|||||||
return None
|
return None
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def parse(self, text: str) -> TBaseModel:
|
def parse(self, text: str) -> BaseModelT:
|
||||||
"""Parse the output of an LLM call to a pydantic object.
|
"""Parse the output of an LLM call to a pydantic object.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -109,7 +97,7 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
@override
|
@override
|
||||||
def OutputType(self) -> type[TBaseModel]:
|
def OutputType(self) -> type[BaseModelT]:
|
||||||
"""Return the pydantic model."""
|
"""Return the pydantic model."""
|
||||||
return self.pydantic_object
|
return self.pydantic_object
|
||||||
|
|
||||||
@ -126,7 +114,5 @@ Here is the output schema:
|
|||||||
|
|
||||||
# Re-exporting types for backwards compatibility
|
# Re-exporting types for backwards compatibility
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"PydanticBaseModel",
|
|
||||||
"PydanticOutputParser",
|
"PydanticOutputParser",
|
||||||
"TBaseModel",
|
|
||||||
]
|
]
|
||||||
|
@ -125,7 +125,7 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
|
|||||||
chunk_gen = ChatGenerationChunk(message=chunk)
|
chunk_gen = ChatGenerationChunk(message=chunk)
|
||||||
elif isinstance(chunk, BaseMessage):
|
elif isinstance(chunk, BaseMessage):
|
||||||
chunk_gen = ChatGenerationChunk(
|
chunk_gen = ChatGenerationChunk(
|
||||||
message=BaseMessageChunk(**chunk.dict())
|
message=BaseMessageChunk(**chunk.model_dump())
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
chunk_gen = GenerationChunk(text=chunk)
|
chunk_gen = GenerationChunk(text=chunk)
|
||||||
@ -151,7 +151,7 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
|
|||||||
chunk_gen = ChatGenerationChunk(message=chunk)
|
chunk_gen = ChatGenerationChunk(message=chunk)
|
||||||
elif isinstance(chunk, BaseMessage):
|
elif isinstance(chunk, BaseMessage):
|
||||||
chunk_gen = ChatGenerationChunk(
|
chunk_gen = ChatGenerationChunk(
|
||||||
message=BaseMessageChunk(**chunk.dict())
|
message=BaseMessageChunk(**chunk.model_dump())
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
chunk_gen = GenerationChunk(text=chunk)
|
chunk_gen = GenerationChunk(text=chunk)
|
||||||
|
@ -331,7 +331,7 @@ class BasePromptTemplate(
|
|||||||
"""Return the prompt type key."""
|
"""Return the prompt type key."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def dict(self, **kwargs: Any) -> dict:
|
def model_dump(self, **kwargs: Any) -> dict:
|
||||||
"""Return dictionary representation of prompt.
|
"""Return dictionary representation of prompt.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -369,7 +369,7 @@ class BasePromptTemplate(
|
|||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
# Fetch dictionary to save
|
# Fetch dictionary to save
|
||||||
prompt_dict = self.dict()
|
prompt_dict = self.model_dump()
|
||||||
if "_type" not in prompt_dict:
|
if "_type" not in prompt_dict:
|
||||||
msg = f"Prompt {self} does not support saving."
|
msg = f"Prompt {self} does not support saving."
|
||||||
raise NotImplementedError(msg)
|
raise NotImplementedError(msg)
|
||||||
|
@ -6,7 +6,6 @@ from abc import ABC, abstractmethod
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Annotated,
|
|
||||||
Any,
|
Any,
|
||||||
Optional,
|
Optional,
|
||||||
TypedDict,
|
TypedDict,
|
||||||
@ -886,7 +885,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
|
|||||||
|
|
||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
|
|
||||||
messages: Annotated[list[MessageLike], SkipValidation()]
|
messages: SkipValidation[list[MessageLike]]
|
||||||
"""List of messages consisting of either message prompt templates or messages."""
|
"""List of messages consisting of either message prompt templates or messages."""
|
||||||
validate_template: bool = False
|
validate_template: bool = False
|
||||||
"""Whether or not to try validating the template."""
|
"""Whether or not to try validating the template."""
|
||||||
|
@ -98,7 +98,7 @@ def _get_jinja2_variables_from_template(template: str) -> set[str]:
|
|||||||
raise ImportError(msg) from e
|
raise ImportError(msg) from e
|
||||||
env = Environment() # noqa: S701
|
env = Environment() # noqa: S701
|
||||||
ast = env.parse(template)
|
ast = env.parse(template)
|
||||||
return meta.find_undeclared_variables(ast)
|
return meta.find_undeclared_variables(ast) # type: ignore[no-untyped-call]
|
||||||
|
|
||||||
|
|
||||||
def mustache_formatter(template: str, /, **kwargs: Any) -> str:
|
def mustache_formatter(template: str, /, **kwargs: Any) -> str:
|
||||||
|
@ -1,45 +0,0 @@
|
|||||||
"""Pydantic v1 compatibility shim."""
|
|
||||||
|
|
||||||
from importlib import metadata
|
|
||||||
|
|
||||||
from langchain_core._api.deprecation import warn_deprecated
|
|
||||||
|
|
||||||
# Create namespaces for pydantic v1 and v2.
|
|
||||||
# This code must stay at the top of the file before other modules may
|
|
||||||
# attempt to import pydantic since it adds pydantic_v1 and pydantic_v2 to sys.modules.
|
|
||||||
#
|
|
||||||
# This hack is done for the following reasons:
|
|
||||||
# * Langchain will attempt to remain compatible with both pydantic v1 and v2 since
|
|
||||||
# both dependencies and dependents may be stuck on either version of v1 or v2.
|
|
||||||
# * Creating namespaces for pydantic v1 and v2 should allow us to write code that
|
|
||||||
# unambiguously uses either v1 or v2 API.
|
|
||||||
# * This change is easier to roll out and roll back.
|
|
||||||
|
|
||||||
try:
|
|
||||||
from pydantic.v1 import * # noqa: F403
|
|
||||||
except ImportError:
|
|
||||||
from pydantic import * # type: ignore[assignment,no-redef] # noqa: F403
|
|
||||||
|
|
||||||
|
|
||||||
try:
|
|
||||||
_PYDANTIC_MAJOR_VERSION: int = int(metadata.version("pydantic").split(".")[0])
|
|
||||||
except metadata.PackageNotFoundError:
|
|
||||||
_PYDANTIC_MAJOR_VERSION = 0
|
|
||||||
|
|
||||||
warn_deprecated(
|
|
||||||
"0.3.0",
|
|
||||||
removal="1.0.0",
|
|
||||||
alternative="pydantic.v1 or pydantic",
|
|
||||||
message=(
|
|
||||||
"As of langchain-core 0.3.0, LangChain uses pydantic v2 internally. "
|
|
||||||
"The langchain_core.pydantic_v1 module was a "
|
|
||||||
"compatibility shim for pydantic v1, and should no longer be used. "
|
|
||||||
"Please update the code to import from Pydantic directly.\n\n"
|
|
||||||
"For example, replace imports like: "
|
|
||||||
"`from langchain_core.pydantic_v1 import BaseModel`\n"
|
|
||||||
"with: `from pydantic import BaseModel`\n"
|
|
||||||
"or the v1 compatibility namespace if you are working in a code base "
|
|
||||||
"that has not been fully upgraded to pydantic 2 yet. "
|
|
||||||
"\tfrom pydantic.v1 import BaseModel\n"
|
|
||||||
),
|
|
||||||
)
|
|
@ -1,26 +0,0 @@
|
|||||||
"""Pydantic v1 compatibility shim."""
|
|
||||||
|
|
||||||
from langchain_core._api import warn_deprecated
|
|
||||||
|
|
||||||
try:
|
|
||||||
from pydantic.v1.dataclasses import * # noqa: F403
|
|
||||||
except ImportError:
|
|
||||||
from pydantic.dataclasses import * # type: ignore[no-redef] # noqa: F403
|
|
||||||
|
|
||||||
warn_deprecated(
|
|
||||||
"0.3.0",
|
|
||||||
removal="1.0.0",
|
|
||||||
alternative="pydantic.v1 or pydantic",
|
|
||||||
message=(
|
|
||||||
"As of langchain-core 0.3.0, LangChain uses pydantic v2 internally. "
|
|
||||||
"The langchain_core.pydantic_v1 module was a "
|
|
||||||
"compatibility shim for pydantic v1, and should no longer be used. "
|
|
||||||
"Please update the code to import from Pydantic directly.\n\n"
|
|
||||||
"For example, replace imports like: "
|
|
||||||
"`from langchain_core.pydantic_v1 import BaseModel`\n"
|
|
||||||
"with: `from pydantic import BaseModel`\n"
|
|
||||||
"or the v1 compatibility namespace if you are working in a code base "
|
|
||||||
"that has not been fully upgraded to pydantic 2 yet. "
|
|
||||||
"\tfrom pydantic.v1 import BaseModel\n"
|
|
||||||
),
|
|
||||||
)
|
|
@ -1,26 +0,0 @@
|
|||||||
"""Pydantic v1 compatibility shim."""
|
|
||||||
|
|
||||||
from langchain_core._api import warn_deprecated
|
|
||||||
|
|
||||||
try:
|
|
||||||
from pydantic.v1.main import * # noqa: F403
|
|
||||||
except ImportError:
|
|
||||||
from pydantic.main import * # type: ignore[assignment,no-redef] # noqa: F403
|
|
||||||
|
|
||||||
warn_deprecated(
|
|
||||||
"0.3.0",
|
|
||||||
removal="1.0.0",
|
|
||||||
alternative="pydantic.v1 or pydantic",
|
|
||||||
message=(
|
|
||||||
"As of langchain-core 0.3.0, LangChain uses pydantic v2 internally. "
|
|
||||||
"The langchain_core.pydantic_v1 module was a "
|
|
||||||
"compatibility shim for pydantic v1, and should no longer be used. "
|
|
||||||
"Please update the code to import from Pydantic directly.\n\n"
|
|
||||||
"For example, replace imports like: "
|
|
||||||
"`from langchain_core.pydantic_v1 import BaseModel`\n"
|
|
||||||
"with: `from pydantic import BaseModel`\n"
|
|
||||||
"or the v1 compatibility namespace if you are working in a code base "
|
|
||||||
"that has not been fully upgraded to pydantic 2 yet. "
|
|
||||||
"\tfrom pydantic.v1 import BaseModel\n"
|
|
||||||
),
|
|
||||||
)
|
|
@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import inspect
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
@ -19,13 +18,13 @@ from typing import (
|
|||||||
)
|
)
|
||||||
from uuid import UUID, uuid4
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
from langchain_core.utils.pydantic import _IgnoreUnserializable, is_basemodel_subclass
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from langchain_core.utils.pydantic import _IgnoreUnserializable
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from langchain_core.runnables.base import Runnable as RunnableType
|
from langchain_core.runnables.base import Runnable as RunnableType
|
||||||
|
|
||||||
|
|
||||||
@ -233,7 +232,7 @@ def node_data_json(
|
|||||||
"name": node_data_str(node.id, node.data),
|
"name": node_data_str(node.id, node.data),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
elif inspect.isclass(node.data) and is_basemodel_subclass(node.data):
|
elif isinstance(node.data, type) and issubclass(node.data, BaseModel):
|
||||||
json = (
|
json = (
|
||||||
{
|
{
|
||||||
"type": "schema",
|
"type": "schema",
|
||||||
|
@ -33,9 +33,6 @@ from pydantic import (
|
|||||||
model_validator,
|
model_validator,
|
||||||
validate_arguments,
|
validate_arguments,
|
||||||
)
|
)
|
||||||
from pydantic.v1 import BaseModel as BaseModelV1
|
|
||||||
from pydantic.v1 import ValidationError as ValidationErrorV1
|
|
||||||
from pydantic.v1 import validate_arguments as validate_arguments_v1
|
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_core._api import deprecated
|
from langchain_core._api import deprecated
|
||||||
@ -59,14 +56,7 @@ from langchain_core.utils.function_calling import (
|
|||||||
_parse_google_docstring,
|
_parse_google_docstring,
|
||||||
_py_38_safe_origin,
|
_py_38_safe_origin,
|
||||||
)
|
)
|
||||||
from langchain_core.utils.pydantic import (
|
from langchain_core.utils.pydantic import _create_subset_model, get_fields
|
||||||
TypeBaseModel,
|
|
||||||
_create_subset_model,
|
|
||||||
get_fields,
|
|
||||||
is_basemodel_subclass,
|
|
||||||
is_pydantic_v1_subclass,
|
|
||||||
is_pydantic_v2_subclass,
|
|
||||||
)
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import uuid
|
import uuid
|
||||||
@ -165,36 +155,6 @@ def _infer_arg_descriptions(
|
|||||||
return description, arg_descriptions
|
return description, arg_descriptions
|
||||||
|
|
||||||
|
|
||||||
def _is_pydantic_annotation(annotation: Any, pydantic_version: str = "v2") -> bool:
|
|
||||||
"""Determine if a type annotation is a Pydantic model."""
|
|
||||||
base_model_class = BaseModelV1 if pydantic_version == "v1" else BaseModel
|
|
||||||
try:
|
|
||||||
return issubclass(annotation, base_model_class)
|
|
||||||
except TypeError:
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def _function_annotations_are_pydantic_v1(
|
|
||||||
signature: inspect.Signature, func: Callable
|
|
||||||
) -> bool:
|
|
||||||
"""Determine if all Pydantic annotations in a function signature are from V1."""
|
|
||||||
any_v1_annotations = any(
|
|
||||||
_is_pydantic_annotation(parameter.annotation, pydantic_version="v1")
|
|
||||||
for parameter in signature.parameters.values()
|
|
||||||
)
|
|
||||||
any_v2_annotations = any(
|
|
||||||
_is_pydantic_annotation(parameter.annotation, pydantic_version="v2")
|
|
||||||
for parameter in signature.parameters.values()
|
|
||||||
)
|
|
||||||
if any_v1_annotations and any_v2_annotations:
|
|
||||||
msg = (
|
|
||||||
f"Function {func} contains a mix of Pydantic v1 and v2 annotations. "
|
|
||||||
"Only one version of Pydantic annotations per function is supported."
|
|
||||||
)
|
|
||||||
raise NotImplementedError(msg)
|
|
||||||
return any_v1_annotations and not any_v2_annotations
|
|
||||||
|
|
||||||
|
|
||||||
class _SchemaConfig:
|
class _SchemaConfig:
|
||||||
"""Configuration for the pydantic model.
|
"""Configuration for the pydantic model.
|
||||||
|
|
||||||
@ -241,16 +201,13 @@ def create_schema_from_function(
|
|||||||
"""
|
"""
|
||||||
sig = inspect.signature(func)
|
sig = inspect.signature(func)
|
||||||
|
|
||||||
if _function_annotations_are_pydantic_v1(sig, func):
|
# https://docs.pydantic.dev/latest/usage/validation_decorator/
|
||||||
validated = validate_arguments_v1(func, config=_SchemaConfig) # type: ignore[call-overload]
|
with warnings.catch_warnings():
|
||||||
else:
|
# We are using deprecated functionality here.
|
||||||
# https://docs.pydantic.dev/latest/usage/validation_decorator/
|
# This code should be re-written to simply construct a pydantic model
|
||||||
with warnings.catch_warnings():
|
# using inspect.signature and create_model.
|
||||||
# We are using deprecated functionality here.
|
warnings.simplefilter("ignore", category=PydanticDeprecationWarning)
|
||||||
# This code should be re-written to simply construct a pydantic model
|
validated = validate_arguments(func, config=_SchemaConfig) # type: ignore[operator]
|
||||||
# using inspect.signature and create_model.
|
|
||||||
warnings.simplefilter("ignore", category=PydanticDeprecationWarning)
|
|
||||||
validated = validate_arguments(func, config=_SchemaConfig) # type: ignore[operator]
|
|
||||||
|
|
||||||
# Let's ignore `self` and `cls` arguments for class and instance methods
|
# Let's ignore `self` and `cls` arguments for class and instance methods
|
||||||
# If qualified name has a ".", then it likely belongs in a class namespace
|
# If qualified name has a ".", then it likely belongs in a class namespace
|
||||||
@ -321,7 +278,7 @@ class ToolException(Exception): # noqa: N818
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
ArgsSchema = Union[TypeBaseModel, dict[str, Any]]
|
ArgsSchema = Union[type[BaseModel], dict[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
class BaseTool(RunnableSerializable[Union[str, dict, ToolCall], Any]):
|
class BaseTool(RunnableSerializable[Union[str, dict, ToolCall], Any]):
|
||||||
@ -361,7 +318,7 @@ class ChildTool(BaseTool):
|
|||||||
You can provide few-shot examples as a part of the description.
|
You can provide few-shot examples as a part of the description.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
args_schema: Annotated[Optional[ArgsSchema], SkipValidation()] = Field(
|
args_schema: SkipValidation[Optional[ArgsSchema]] = Field(
|
||||||
default=None, description="The tool schema."
|
default=None, description="The tool schema."
|
||||||
)
|
)
|
||||||
"""Pydantic model class to validate and parse the tool's input arguments.
|
"""Pydantic model class to validate and parse the tool's input arguments.
|
||||||
@ -370,8 +327,6 @@ class ChildTool(BaseTool):
|
|||||||
|
|
||||||
- A subclass of pydantic.BaseModel.
|
- A subclass of pydantic.BaseModel.
|
||||||
or
|
or
|
||||||
- A subclass of pydantic.v1.BaseModel if accessing v1 namespace in pydantic 2
|
|
||||||
or
|
|
||||||
- a JSON schema dict
|
- a JSON schema dict
|
||||||
"""
|
"""
|
||||||
return_direct: bool = False
|
return_direct: bool = False
|
||||||
@ -414,7 +369,7 @@ class ChildTool(BaseTool):
|
|||||||
"""Handle the content of the ToolException thrown."""
|
"""Handle the content of the ToolException thrown."""
|
||||||
|
|
||||||
handle_validation_error: Optional[
|
handle_validation_error: Optional[
|
||||||
Union[bool, str, Callable[[Union[ValidationError, ValidationErrorV1]], str]]
|
Union[bool, str, Callable[[ValidationError], str]]
|
||||||
] = False
|
] = False
|
||||||
"""Handle the content of the ValidationError thrown."""
|
"""Handle the content of the ValidationError thrown."""
|
||||||
|
|
||||||
@ -431,7 +386,7 @@ class ChildTool(BaseTool):
|
|||||||
if (
|
if (
|
||||||
"args_schema" in kwargs
|
"args_schema" in kwargs
|
||||||
and kwargs["args_schema"] is not None
|
and kwargs["args_schema"] is not None
|
||||||
and not is_basemodel_subclass(kwargs["args_schema"])
|
and not issubclass(kwargs["args_schema"], BaseModel)
|
||||||
and not isinstance(kwargs["args_schema"], dict)
|
and not isinstance(kwargs["args_schema"], dict)
|
||||||
):
|
):
|
||||||
msg = (
|
msg = (
|
||||||
@ -543,10 +498,7 @@ class ChildTool(BaseTool):
|
|||||||
)
|
)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
key_ = next(iter(get_fields(input_args).keys()))
|
key_ = next(iter(get_fields(input_args).keys()))
|
||||||
if hasattr(input_args, "model_validate"):
|
input_args.model_validate({key_: tool_input})
|
||||||
input_args.model_validate({key_: tool_input})
|
|
||||||
else:
|
|
||||||
input_args.parse_obj({key_: tool_input})
|
|
||||||
return tool_input
|
return tool_input
|
||||||
if input_args is not None:
|
if input_args is not None:
|
||||||
if isinstance(input_args, dict):
|
if isinstance(input_args, dict):
|
||||||
@ -569,24 +521,6 @@ class ChildTool(BaseTool):
|
|||||||
tool_input[k] = tool_call_id
|
tool_input[k] = tool_call_id
|
||||||
result = input_args.model_validate(tool_input)
|
result = input_args.model_validate(tool_input)
|
||||||
result_dict = result.model_dump()
|
result_dict = result.model_dump()
|
||||||
elif issubclass(input_args, BaseModelV1):
|
|
||||||
for k, v in get_all_basemodel_annotations(input_args).items():
|
|
||||||
if (
|
|
||||||
_is_injected_arg_type(v, injected_type=InjectedToolCallId)
|
|
||||||
and k not in tool_input
|
|
||||||
):
|
|
||||||
if tool_call_id is None:
|
|
||||||
msg = (
|
|
||||||
"When tool includes an InjectedToolCallId "
|
|
||||||
"argument, tool must always be invoked with a full "
|
|
||||||
"model ToolCall of the form: {'args': {...}, "
|
|
||||||
"'name': '...', 'type': 'tool_call', "
|
|
||||||
"'tool_call_id': '...'}"
|
|
||||||
)
|
|
||||||
raise ValueError(msg)
|
|
||||||
tool_input[k] = tool_call_id
|
|
||||||
result = input_args.parse_obj(tool_input)
|
|
||||||
result_dict = result.dict()
|
|
||||||
else:
|
else:
|
||||||
msg = (
|
msg = (
|
||||||
f"args_schema must be a Pydantic BaseModel, got {self.args_schema}"
|
f"args_schema must be a Pydantic BaseModel, got {self.args_schema}"
|
||||||
@ -643,7 +577,7 @@ class ChildTool(BaseTool):
|
|||||||
if (
|
if (
|
||||||
self.args_schema is not None
|
self.args_schema is not None
|
||||||
and isinstance(self.args_schema, type)
|
and isinstance(self.args_schema, type)
|
||||||
and is_basemodel_subclass(self.args_schema)
|
and issubclass(self.args_schema, BaseModel)
|
||||||
and not get_fields(self.args_schema)
|
and not get_fields(self.args_schema)
|
||||||
):
|
):
|
||||||
# StructuredTool with no args
|
# StructuredTool with no args
|
||||||
@ -754,7 +688,7 @@ class ChildTool(BaseTool):
|
|||||||
content, artifact = response
|
content, artifact = response
|
||||||
else:
|
else:
|
||||||
content = response
|
content = response
|
||||||
except (ValidationError, ValidationErrorV1) as e:
|
except ValidationError as e:
|
||||||
if not self.handle_validation_error:
|
if not self.handle_validation_error:
|
||||||
error_to_raise = e
|
error_to_raise = e
|
||||||
else:
|
else:
|
||||||
@ -901,11 +835,9 @@ def _is_tool_call(x: Any) -> bool:
|
|||||||
|
|
||||||
|
|
||||||
def _handle_validation_error(
|
def _handle_validation_error(
|
||||||
e: Union[ValidationError, ValidationErrorV1],
|
e: ValidationError,
|
||||||
*,
|
*,
|
||||||
flag: Union[
|
flag: Union[Literal[True], str, Callable[[ValidationError], str]],
|
||||||
Literal[True], str, Callable[[Union[ValidationError, ValidationErrorV1]], str]
|
|
||||||
],
|
|
||||||
) -> str:
|
) -> str:
|
||||||
if isinstance(flag, bool):
|
if isinstance(flag, bool):
|
||||||
content = "Tool input validation error"
|
content = "Tool input validation error"
|
||||||
@ -1067,7 +999,7 @@ def _is_injected_arg_type(
|
|||||||
|
|
||||||
|
|
||||||
def get_all_basemodel_annotations(
|
def get_all_basemodel_annotations(
|
||||||
cls: Union[TypeBaseModel, Any], *, default_to_bound: bool = True
|
cls: type[BaseModel], *, default_to_bound: bool = True
|
||||||
) -> dict[str, type]:
|
) -> dict[str, type]:
|
||||||
"""Get all annotations from a Pydantic BaseModel and its parents.
|
"""Get all annotations from a Pydantic BaseModel and its parents.
|
||||||
|
|
||||||
@ -1075,58 +1007,17 @@ def get_all_basemodel_annotations(
|
|||||||
cls: The Pydantic BaseModel class.
|
cls: The Pydantic BaseModel class.
|
||||||
default_to_bound: Whether to default to the bound of a TypeVar if it exists.
|
default_to_bound: Whether to default to the bound of a TypeVar if it exists.
|
||||||
"""
|
"""
|
||||||
# cls has no subscript: cls = FooBar
|
fields = cls.model_fields
|
||||||
if isinstance(cls, type):
|
alias_map = {field.alias: name for name, field in fields.items() if field.alias}
|
||||||
# Gather pydantic field objects (v2: model_fields / v1: __fields__)
|
|
||||||
fields = getattr(cls, "model_fields", {}) or getattr(cls, "__fields__", {})
|
|
||||||
alias_map = {field.alias: name for name, field in fields.items() if field.alias}
|
|
||||||
|
|
||||||
annotations: dict[str, type] = {}
|
annotations: dict[str, type] = {}
|
||||||
for name, param in inspect.signature(cls).parameters.items():
|
for name, param in inspect.signature(cls).parameters.items():
|
||||||
# Exclude hidden init args added by pydantic Config. For example if
|
# Exclude hidden init args added by pydantic's ConfigDict. For example if
|
||||||
# BaseModel(extra="allow") then "extra_data" will part of init sig.
|
# BaseModel(extra="allow") then "extra_data" will part of init sig.
|
||||||
if fields and name not in fields and name not in alias_map:
|
if fields and name not in fields and name not in alias_map:
|
||||||
continue
|
continue
|
||||||
field_name = alias_map.get(name, name)
|
field_name = alias_map.get(name, name)
|
||||||
annotations[field_name] = param.annotation
|
annotations[field_name] = param.annotation
|
||||||
orig_bases: tuple = getattr(cls, "__orig_bases__", ())
|
|
||||||
# cls has subscript: cls = FooBar[int]
|
|
||||||
else:
|
|
||||||
annotations = get_all_basemodel_annotations(
|
|
||||||
get_origin(cls), default_to_bound=False
|
|
||||||
)
|
|
||||||
orig_bases = (cls,)
|
|
||||||
|
|
||||||
# Pydantic v2 automatically resolves inherited generics, Pydantic v1 does not.
|
|
||||||
if not (isinstance(cls, type) and is_pydantic_v2_subclass(cls)):
|
|
||||||
# if cls = FooBar inherits from Baz[str], orig_bases will contain Baz[str]
|
|
||||||
# if cls = FooBar inherits from Baz, orig_bases will contain Baz
|
|
||||||
# if cls = FooBar[int], orig_bases will contain FooBar[int]
|
|
||||||
for parent in orig_bases:
|
|
||||||
# if class = FooBar inherits from Baz, parent = Baz
|
|
||||||
if isinstance(parent, type) and is_pydantic_v1_subclass(parent):
|
|
||||||
annotations.update(
|
|
||||||
get_all_basemodel_annotations(parent, default_to_bound=False)
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
parent_origin = get_origin(parent)
|
|
||||||
|
|
||||||
# if class = FooBar inherits from non-pydantic class
|
|
||||||
if not parent_origin:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# if class = FooBar inherits from Baz[str]:
|
|
||||||
# parent = Baz[str],
|
|
||||||
# parent_origin = Baz,
|
|
||||||
# generic_type_vars = (type vars in Baz)
|
|
||||||
# generic_map = {type var in Baz: str}
|
|
||||||
generic_type_vars: tuple = getattr(parent_origin, "__parameters__", ())
|
|
||||||
generic_map = dict(zip(generic_type_vars, get_args(parent)))
|
|
||||||
for field in getattr(parent_origin, "__annotations__", {}):
|
|
||||||
annotations[field] = _replace_type_vars(
|
|
||||||
annotations[field], generic_map, default_to_bound=default_to_bound
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
k: _replace_type_vars(v, default_to_bound=default_to_bound)
|
k: _replace_type_vars(v, default_to_bound=default_to_bound)
|
||||||
|
@ -7,7 +7,6 @@ from collections.abc import Awaitable
|
|||||||
from inspect import signature
|
from inspect import signature
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Annotated,
|
|
||||||
Any,
|
Any,
|
||||||
Callable,
|
Callable,
|
||||||
Literal,
|
Literal,
|
||||||
@ -15,7 +14,7 @@ from typing import (
|
|||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
|
|
||||||
from pydantic import Field, SkipValidation
|
from pydantic import BaseModel, Field, SkipValidation
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_core.callbacks import (
|
from langchain_core.callbacks import (
|
||||||
@ -30,7 +29,6 @@ from langchain_core.tools.base import (
|
|||||||
_get_runnable_config_param,
|
_get_runnable_config_param,
|
||||||
create_schema_from_function,
|
create_schema_from_function,
|
||||||
)
|
)
|
||||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from langchain_core.messages import ToolCall
|
from langchain_core.messages import ToolCall
|
||||||
@ -40,9 +38,7 @@ class StructuredTool(BaseTool):
|
|||||||
"""Tool that can operate on any number of inputs."""
|
"""Tool that can operate on any number of inputs."""
|
||||||
|
|
||||||
description: str = ""
|
description: str = ""
|
||||||
args_schema: Annotated[ArgsSchema, SkipValidation()] = Field(
|
args_schema: SkipValidation[ArgsSchema] = Field(..., description="The tool schema.")
|
||||||
..., description="The tool schema."
|
|
||||||
)
|
|
||||||
"""The input arguments' schema."""
|
"""The input arguments' schema."""
|
||||||
func: Optional[Callable[..., Any]] = None
|
func: Optional[Callable[..., Any]] = None
|
||||||
"""The function to run when the tool is called."""
|
"""The function to run when the tool is called."""
|
||||||
@ -196,7 +192,7 @@ class StructuredTool(BaseTool):
|
|||||||
if description is None and not parse_docstring:
|
if description is None and not parse_docstring:
|
||||||
description_ = source_function.__doc__ or None
|
description_ = source_function.__doc__ or None
|
||||||
if description_ is None and args_schema:
|
if description_ is None and args_schema:
|
||||||
if isinstance(args_schema, type) and is_basemodel_subclass(args_schema):
|
if isinstance(args_schema, type) and issubclass(args_schema, BaseModel):
|
||||||
description_ = args_schema.__doc__ or None
|
description_ = args_schema.__doc__ or None
|
||||||
elif isinstance(args_schema, dict):
|
elif isinstance(args_schema, dict):
|
||||||
description_ = args_schema.get("description")
|
description_ = args_schema.get("description")
|
||||||
|
@ -2,8 +2,8 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import datetime
|
|
||||||
import warnings
|
import warnings
|
||||||
|
from datetime import datetime, timezone
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
@ -32,7 +32,7 @@ def RunTypeEnum() -> type[RunTypeEnumDep]: # noqa: N802
|
|||||||
class TracerSessionV1Base(BaseModelV1):
|
class TracerSessionV1Base(BaseModelV1):
|
||||||
"""Base class for TracerSessionV1."""
|
"""Base class for TracerSessionV1."""
|
||||||
|
|
||||||
start_time: datetime.datetime = FieldV1(default_factory=datetime.datetime.utcnow)
|
start_time: datetime = FieldV1(default_factory=lambda: datetime.now(timezone.utc))
|
||||||
name: Optional[str] = None
|
name: Optional[str] = None
|
||||||
extra: Optional[dict[str, Any]] = None
|
extra: Optional[dict[str, Any]] = None
|
||||||
|
|
||||||
@ -69,8 +69,8 @@ class BaseRun(BaseModelV1):
|
|||||||
|
|
||||||
uuid: str
|
uuid: str
|
||||||
parent_uuid: Optional[str] = None
|
parent_uuid: Optional[str] = None
|
||||||
start_time: datetime.datetime = FieldV1(default_factory=datetime.datetime.utcnow)
|
start_time: datetime = FieldV1(default_factory=lambda: datetime.now(timezone.utc))
|
||||||
end_time: datetime.datetime = FieldV1(default_factory=datetime.datetime.utcnow)
|
end_time: datetime = FieldV1(default_factory=lambda: datetime.now(timezone.utc))
|
||||||
extra: Optional[dict[str, Any]] = None
|
extra: Optional[dict[str, Any]] = None
|
||||||
execution_order: int
|
execution_order: int
|
||||||
child_execution_order: int
|
child_execution_order: int
|
||||||
|
@ -20,13 +20,11 @@ from typing import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from pydantic.v1 import BaseModel as BaseModelV1
|
|
||||||
from typing_extensions import TypedDict, get_args, get_origin, is_typeddict
|
from typing_extensions import TypedDict, get_args, get_origin, is_typeddict
|
||||||
|
|
||||||
from langchain_core._api import beta, deprecated
|
from langchain_core._api import beta, deprecated
|
||||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
|
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
|
||||||
from langchain_core.utils.json_schema import dereference_refs
|
from langchain_core.utils.json_schema import dereference_refs
|
||||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
@ -150,9 +148,7 @@ def _convert_pydantic_to_openai_function(
|
|||||||
The function description.
|
The function description.
|
||||||
"""
|
"""
|
||||||
if hasattr(model, "model_json_schema"):
|
if hasattr(model, "model_json_schema"):
|
||||||
schema = model.model_json_schema() # Pydantic 2
|
schema = model.model_json_schema()
|
||||||
elif hasattr(model, "schema"):
|
|
||||||
schema = model.schema() # Pydantic 1
|
|
||||||
else:
|
else:
|
||||||
msg = "Model must be a Pydantic model."
|
msg = "Model must be a Pydantic model."
|
||||||
raise TypeError(msg)
|
raise TypeError(msg)
|
||||||
@ -249,6 +245,7 @@ def _convert_typed_dict_to_openai_function(typed_dict: type) -> FunctionDescript
|
|||||||
"type[BaseModel]",
|
"type[BaseModel]",
|
||||||
_convert_any_typed_dicts_to_pydantic(typed_dict, visited=visited),
|
_convert_any_typed_dicts_to_pydantic(typed_dict, visited=visited),
|
||||||
)
|
)
|
||||||
|
print(model)
|
||||||
return _convert_pydantic_to_openai_function(model)
|
return _convert_pydantic_to_openai_function(model)
|
||||||
|
|
||||||
|
|
||||||
@ -336,7 +333,7 @@ def _format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
|
|||||||
return _convert_json_schema_to_openai_function(
|
return _convert_json_schema_to_openai_function(
|
||||||
tool.tool_call_schema, name=tool.name, description=tool.description
|
tool.tool_call_schema, name=tool.name, description=tool.description
|
||||||
)
|
)
|
||||||
if issubclass(tool.tool_call_schema, (BaseModel, BaseModelV1)):
|
if issubclass(tool.tool_call_schema, BaseModel):
|
||||||
return _convert_pydantic_to_openai_function(
|
return _convert_pydantic_to_openai_function(
|
||||||
tool.tool_call_schema, name=tool.name, description=tool.description
|
tool.tool_call_schema, name=tool.name, description=tool.description
|
||||||
)
|
)
|
||||||
@ -466,7 +463,7 @@ def convert_to_openai_function(
|
|||||||
oai_function["description"] = function_copy.pop("description")
|
oai_function["description"] = function_copy.pop("description")
|
||||||
if function_copy and "properties" in function_copy:
|
if function_copy and "properties" in function_copy:
|
||||||
oai_function["parameters"] = function_copy
|
oai_function["parameters"] = function_copy
|
||||||
elif isinstance(function, type) and is_basemodel_subclass(function):
|
elif isinstance(function, type) and issubclass(function, BaseModel):
|
||||||
oai_function = cast("dict", _convert_pydantic_to_openai_function(function))
|
oai_function = cast("dict", _convert_pydantic_to_openai_function(function))
|
||||||
elif is_typeddict(function):
|
elif is_typeddict(function):
|
||||||
oai_function = cast(
|
oai_function = cast(
|
||||||
|
@ -2,173 +2,39 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import inspect
|
|
||||||
import textwrap
|
import textwrap
|
||||||
import warnings
|
import warnings
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
from functools import lru_cache, wraps
|
from functools import lru_cache, wraps
|
||||||
from types import GenericAlias
|
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
Callable,
|
Callable,
|
||||||
Optional,
|
Optional,
|
||||||
TypeVar,
|
|
||||||
Union,
|
|
||||||
cast,
|
cast,
|
||||||
overload,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
import pydantic
|
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from pydantic import (
|
from pydantic import BaseModel, ConfigDict, Field, RootModel, model_validator
|
||||||
BaseModel,
|
|
||||||
ConfigDict,
|
|
||||||
PydanticDeprecationWarning,
|
|
||||||
RootModel,
|
|
||||||
root_validator,
|
|
||||||
)
|
|
||||||
from pydantic import (
|
from pydantic import (
|
||||||
create_model as _create_model_base,
|
create_model as _create_model_base,
|
||||||
)
|
)
|
||||||
from pydantic.fields import FieldInfo as FieldInfoV2
|
from pydantic.fields import FieldInfo
|
||||||
from pydantic.json_schema import (
|
from pydantic.json_schema import (
|
||||||
DEFAULT_REF_TEMPLATE,
|
DEFAULT_REF_TEMPLATE,
|
||||||
GenerateJsonSchema,
|
GenerateJsonSchema,
|
||||||
JsonSchemaMode,
|
JsonSchemaMode,
|
||||||
JsonSchemaValue,
|
JsonSchemaValue,
|
||||||
)
|
)
|
||||||
|
from pydantic.version import VERSION as PYDANTIC_VERSION_STRING
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from pydantic_core import core_schema
|
from pydantic_core import core_schema
|
||||||
|
|
||||||
try:
|
PYDANTIC_VERSION = version.parse(PYDANTIC_VERSION_STRING)
|
||||||
import pydantic
|
|
||||||
|
|
||||||
PYDANTIC_VERSION = version.parse(pydantic.__version__)
|
|
||||||
except ImportError:
|
|
||||||
PYDANTIC_VERSION = version.parse("0.0.0")
|
|
||||||
|
|
||||||
|
|
||||||
def get_pydantic_major_version() -> int:
|
|
||||||
"""DEPRECATED - Get the major version of Pydantic.
|
|
||||||
|
|
||||||
Use PYDANTIC_VERSION.major instead.
|
|
||||||
"""
|
|
||||||
warnings.warn(
|
|
||||||
"get_pydantic_major_version is deprecated. Use PYDANTIC_VERSION.major instead.",
|
|
||||||
DeprecationWarning,
|
|
||||||
stacklevel=2,
|
|
||||||
)
|
|
||||||
return PYDANTIC_VERSION.major
|
|
||||||
|
|
||||||
|
|
||||||
PYDANTIC_MAJOR_VERSION = PYDANTIC_VERSION.major
|
|
||||||
PYDANTIC_MINOR_VERSION = PYDANTIC_VERSION.minor
|
PYDANTIC_MINOR_VERSION = PYDANTIC_VERSION.minor
|
||||||
|
|
||||||
IS_PYDANTIC_V1 = PYDANTIC_VERSION.major == 1
|
|
||||||
IS_PYDANTIC_V2 = PYDANTIC_VERSION.major == 2
|
|
||||||
|
|
||||||
if IS_PYDANTIC_V1:
|
|
||||||
from pydantic.fields import FieldInfo as FieldInfoV1
|
|
||||||
|
|
||||||
PydanticBaseModel = pydantic.BaseModel
|
|
||||||
TypeBaseModel = type[BaseModel]
|
|
||||||
elif IS_PYDANTIC_V2:
|
|
||||||
from pydantic.v1.fields import FieldInfo as FieldInfoV1 # type: ignore[assignment]
|
|
||||||
|
|
||||||
# Union type needs to be last assignment to PydanticBaseModel to make mypy happy.
|
|
||||||
PydanticBaseModel = Union[BaseModel, pydantic.BaseModel] # type: ignore[assignment,misc]
|
|
||||||
TypeBaseModel = Union[type[BaseModel], type[pydantic.BaseModel]] # type: ignore[misc]
|
|
||||||
else:
|
|
||||||
msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}"
|
|
||||||
raise ValueError(msg)
|
|
||||||
|
|
||||||
|
|
||||||
TBaseModel = TypeVar("TBaseModel", bound=PydanticBaseModel)
|
|
||||||
|
|
||||||
|
|
||||||
def is_pydantic_v1_subclass(cls: type) -> bool:
|
|
||||||
"""Check if the installed Pydantic version is 1.x-like."""
|
|
||||||
if IS_PYDANTIC_V1:
|
|
||||||
return True
|
|
||||||
if IS_PYDANTIC_V2:
|
|
||||||
from pydantic.v1 import BaseModel as BaseModelV1
|
|
||||||
|
|
||||||
if issubclass(cls, BaseModelV1):
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def is_pydantic_v2_subclass(cls: type) -> bool:
|
|
||||||
"""Check if the installed Pydantic version is 1.x-like."""
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
return IS_PYDANTIC_V2 and issubclass(cls, BaseModel)
|
|
||||||
|
|
||||||
|
|
||||||
def is_basemodel_subclass(cls: type) -> bool:
|
|
||||||
"""Check if the given class is a subclass of Pydantic BaseModel.
|
|
||||||
|
|
||||||
Check if the given class is a subclass of any of the following:
|
|
||||||
|
|
||||||
* pydantic.BaseModel in Pydantic 1.x
|
|
||||||
* pydantic.BaseModel in Pydantic 2.x
|
|
||||||
* pydantic.v1.BaseModel in Pydantic 2.x
|
|
||||||
"""
|
|
||||||
# Before we can use issubclass on the cls we need to check if it is a class
|
|
||||||
if not inspect.isclass(cls) or isinstance(cls, GenericAlias):
|
|
||||||
return False
|
|
||||||
|
|
||||||
if IS_PYDANTIC_V1:
|
|
||||||
from pydantic import BaseModel as BaseModelV1Proper
|
|
||||||
|
|
||||||
if issubclass(cls, BaseModelV1Proper):
|
|
||||||
return True
|
|
||||||
elif IS_PYDANTIC_V2:
|
|
||||||
from pydantic import BaseModel as BaseModelV2
|
|
||||||
from pydantic.v1 import BaseModel as BaseModelV1
|
|
||||||
|
|
||||||
if issubclass(cls, BaseModelV2):
|
|
||||||
return True
|
|
||||||
|
|
||||||
if issubclass(cls, BaseModelV1):
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}"
|
|
||||||
raise ValueError(msg)
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def is_basemodel_instance(obj: Any) -> bool:
|
|
||||||
"""Check if the given class is an instance of Pydantic BaseModel.
|
|
||||||
|
|
||||||
Check if the given class is an instance of any of the following:
|
|
||||||
|
|
||||||
* pydantic.BaseModel in Pydantic 1.x
|
|
||||||
* pydantic.BaseModel in Pydantic 2.x
|
|
||||||
* pydantic.v1.BaseModel in Pydantic 2.x
|
|
||||||
"""
|
|
||||||
if IS_PYDANTIC_V1:
|
|
||||||
from pydantic import BaseModel as BaseModelV1Proper
|
|
||||||
|
|
||||||
if isinstance(obj, BaseModelV1Proper):
|
|
||||||
return True
|
|
||||||
elif IS_PYDANTIC_V2:
|
|
||||||
from pydantic import BaseModel as BaseModelV2
|
|
||||||
from pydantic.v1 import BaseModel as BaseModelV1
|
|
||||||
|
|
||||||
if isinstance(obj, BaseModelV2):
|
|
||||||
return True
|
|
||||||
|
|
||||||
if isinstance(obj, BaseModelV1):
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}"
|
|
||||||
raise ValueError(msg)
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
# How to type hint this?
|
# How to type hint this?
|
||||||
def pre_init(func: Callable) -> Any:
|
def pre_init(func: Callable) -> Any:
|
||||||
@ -180,50 +46,35 @@ def pre_init(func: Callable) -> Any:
|
|||||||
Returns:
|
Returns:
|
||||||
Any: The decorated function.
|
Any: The decorated function.
|
||||||
"""
|
"""
|
||||||
with warnings.catch_warnings():
|
|
||||||
warnings.filterwarnings(action="ignore", category=PydanticDeprecationWarning)
|
|
||||||
|
|
||||||
@root_validator(pre=True)
|
@model_validator(mode="before")
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def wrapper(cls: type[BaseModel], values: dict[str, Any]) -> dict[str, Any]:
|
def wrapper(cls: type[BaseModel], values: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""Decorator to run a function before model initialization.
|
"""Decorator to run a function before model initialization.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cls (Type[BaseModel]): The model class.
|
cls (Type[BaseModel]): The model class.
|
||||||
values (dict[str, Any]): The values to initialize the model with.
|
values (dict[str, Any]): The values to initialize the model with.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict[str, Any]: The values to initialize the model with.
|
dict[str, Any]: The values to initialize the model with.
|
||||||
"""
|
"""
|
||||||
# Insert default values
|
# Insert default values
|
||||||
fields = cls.model_fields
|
fields = cls.model_fields
|
||||||
for name, field_info in fields.items():
|
for name, field_info in fields.items():
|
||||||
# Check if allow_population_by_field_name is enabled
|
if cls.model_config.get("populate_by_name") and field_info.alias in values:
|
||||||
# If yes, then set the field name to the alias
|
values[name] = values.pop(field_info.alias)
|
||||||
if (
|
|
||||||
hasattr(cls, "Config")
|
|
||||||
and hasattr(cls.Config, "allow_population_by_field_name")
|
|
||||||
and cls.Config.allow_population_by_field_name
|
|
||||||
and field_info.alias in values
|
|
||||||
):
|
|
||||||
values[name] = values.pop(field_info.alias)
|
|
||||||
if (
|
|
||||||
hasattr(cls, "model_config")
|
|
||||||
and cls.model_config.get("populate_by_name")
|
|
||||||
and field_info.alias in values
|
|
||||||
):
|
|
||||||
values[name] = values.pop(field_info.alias)
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
name not in values or values[name] is None
|
name not in values or values[name] is None
|
||||||
) and not field_info.is_required():
|
) and not field_info.is_required():
|
||||||
if field_info.default_factory is not None:
|
if field_info.default_factory is not None:
|
||||||
values[name] = field_info.default_factory() # type: ignore[call-arg]
|
values[name] = field_info.default_factory() # type: ignore[call-arg]
|
||||||
else:
|
else:
|
||||||
values[name] = field_info.default
|
values[name] = field_info.default
|
||||||
|
|
||||||
# Call the decorated function
|
# Call the decorated function
|
||||||
return func(cls, values)
|
return func(cls, values)
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
@ -241,55 +92,15 @@ class _IgnoreUnserializable(GenerateJsonSchema):
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
def _create_subset_model_v1(
|
|
||||||
name: str,
|
|
||||||
model: type[BaseModel],
|
|
||||||
field_names: list,
|
|
||||||
*,
|
|
||||||
descriptions: Optional[dict] = None,
|
|
||||||
fn_description: Optional[str] = None,
|
|
||||||
) -> type[BaseModel]:
|
|
||||||
"""Create a pydantic model with only a subset of model's fields."""
|
|
||||||
if IS_PYDANTIC_V1:
|
|
||||||
from pydantic import create_model
|
|
||||||
elif IS_PYDANTIC_V2:
|
|
||||||
from pydantic.v1 import create_model # type: ignore[no-redef]
|
|
||||||
else:
|
|
||||||
msg = f"Unsupported pydantic version: {PYDANTIC_VERSION.major}"
|
|
||||||
raise NotImplementedError(msg)
|
|
||||||
|
|
||||||
fields = {}
|
|
||||||
|
|
||||||
for field_name in field_names:
|
|
||||||
# Using pydantic v1 so can access __fields__ as a dict.
|
|
||||||
field = model.__fields__[field_name] # type: ignore[index]
|
|
||||||
t = (
|
|
||||||
# this isn't perfect but should work for most functions
|
|
||||||
field.outer_type_
|
|
||||||
if field.required and not field.allow_none
|
|
||||||
else Optional[field.outer_type_]
|
|
||||||
)
|
|
||||||
if descriptions and field_name in descriptions:
|
|
||||||
field.field_info.description = descriptions[field_name]
|
|
||||||
fields[field_name] = (t, field.field_info)
|
|
||||||
|
|
||||||
rtn = create_model(name, **fields) # type: ignore[call-overload]
|
|
||||||
rtn.__doc__ = textwrap.dedent(fn_description or model.__doc__ or "")
|
|
||||||
return rtn
|
|
||||||
|
|
||||||
|
|
||||||
def _create_subset_model_v2(
|
def _create_subset_model_v2(
|
||||||
name: str,
|
name: str,
|
||||||
model: type[pydantic.BaseModel],
|
model: type[BaseModel],
|
||||||
field_names: list[str],
|
field_names: list[str],
|
||||||
*,
|
*,
|
||||||
descriptions: Optional[dict] = None,
|
descriptions: Optional[dict] = None,
|
||||||
fn_description: Optional[str] = None,
|
fn_description: Optional[str] = None,
|
||||||
) -> type[pydantic.BaseModel]:
|
) -> type[BaseModel]:
|
||||||
"""Create a pydantic model with a subset of the model fields."""
|
"""Create a pydantic model with a subset of the model fields."""
|
||||||
from pydantic import create_model
|
|
||||||
from pydantic.fields import FieldInfo
|
|
||||||
|
|
||||||
descriptions_ = descriptions or {}
|
descriptions_ = descriptions or {}
|
||||||
fields = {}
|
fields = {}
|
||||||
for field_name in field_names:
|
for field_name in field_names:
|
||||||
@ -300,7 +111,7 @@ def _create_subset_model_v2(
|
|||||||
field_info.metadata = field.metadata
|
field_info.metadata = field.metadata
|
||||||
fields[field_name] = (field.annotation, field_info)
|
fields[field_name] = (field.annotation, field_info)
|
||||||
|
|
||||||
rtn = create_model( # type: ignore[call-overload]
|
rtn = _create_model_base(
|
||||||
name, **fields, __config__=ConfigDict(arbitrary_types_allowed=True)
|
name, **fields, __config__=ConfigDict(arbitrary_types_allowed=True)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -319,89 +130,34 @@ def _create_subset_model_v2(
|
|||||||
return rtn
|
return rtn
|
||||||
|
|
||||||
|
|
||||||
# Private functionality to create a subset model that's compatible across
|
|
||||||
# different versions of pydantic.
|
|
||||||
# Handles pydantic versions 1.x and 2.x. including v1 of pydantic in 2.x.
|
|
||||||
# However, can't find a way to type hint this.
|
|
||||||
def _create_subset_model(
|
def _create_subset_model(
|
||||||
name: str,
|
name: str,
|
||||||
model: TypeBaseModel,
|
model: type[BaseModel],
|
||||||
field_names: list[str],
|
field_names: list[str],
|
||||||
*,
|
*,
|
||||||
descriptions: Optional[dict] = None,
|
descriptions: Optional[dict] = None,
|
||||||
fn_description: Optional[str] = None,
|
fn_description: Optional[str] = None,
|
||||||
) -> type[BaseModel]:
|
) -> type[BaseModel]:
|
||||||
"""Create subset model using the same pydantic version as the input model."""
|
"""Create subset model using the same pydantic version as the input model."""
|
||||||
if IS_PYDANTIC_V1:
|
return _create_subset_model_v2(
|
||||||
return _create_subset_model_v1(
|
name,
|
||||||
name,
|
model,
|
||||||
model,
|
field_names,
|
||||||
field_names,
|
descriptions=descriptions,
|
||||||
descriptions=descriptions,
|
fn_description=fn_description,
|
||||||
fn_description=fn_description,
|
)
|
||||||
)
|
|
||||||
if IS_PYDANTIC_V2:
|
|
||||||
from pydantic.v1 import BaseModel as BaseModelV1
|
|
||||||
|
|
||||||
if issubclass(model, BaseModelV1):
|
|
||||||
return _create_subset_model_v1(
|
|
||||||
name,
|
|
||||||
model,
|
|
||||||
field_names,
|
|
||||||
descriptions=descriptions,
|
|
||||||
fn_description=fn_description,
|
|
||||||
)
|
|
||||||
return _create_subset_model_v2(
|
|
||||||
name,
|
|
||||||
model,
|
|
||||||
field_names,
|
|
||||||
descriptions=descriptions,
|
|
||||||
fn_description=fn_description,
|
|
||||||
)
|
|
||||||
msg = f"Unsupported pydantic version: {PYDANTIC_VERSION.major}"
|
|
||||||
raise NotImplementedError(msg)
|
|
||||||
|
|
||||||
|
|
||||||
if IS_PYDANTIC_V2:
|
def get_fields(
|
||||||
from pydantic import BaseModel as BaseModelV2
|
model: type[BaseModel],
|
||||||
from pydantic.v1 import BaseModel as BaseModelV1
|
) -> dict[str, FieldInfo]:
|
||||||
|
"""Get the field names of a Pydantic model."""
|
||||||
@overload
|
try:
|
||||||
def get_fields(model: type[BaseModelV2]) -> dict[str, FieldInfoV2]: ...
|
return model.model_fields
|
||||||
|
except AttributeError as exc:
|
||||||
@overload
|
|
||||||
def get_fields(model: BaseModelV2) -> dict[str, FieldInfoV2]: ...
|
|
||||||
|
|
||||||
@overload
|
|
||||||
def get_fields(model: type[BaseModelV1]) -> dict[str, FieldInfoV1]: ...
|
|
||||||
|
|
||||||
@overload
|
|
||||||
def get_fields(model: BaseModelV1) -> dict[str, FieldInfoV1]: ...
|
|
||||||
|
|
||||||
def get_fields(
|
|
||||||
model: Union[type[Union[BaseModelV2, BaseModelV1]], BaseModelV2, BaseModelV1],
|
|
||||||
) -> Union[dict[str, FieldInfoV2], dict[str, FieldInfoV1]]:
|
|
||||||
"""Get the field names of a Pydantic model."""
|
|
||||||
if hasattr(model, "model_fields"):
|
|
||||||
return model.model_fields
|
|
||||||
|
|
||||||
if hasattr(model, "__fields__"):
|
|
||||||
return model.__fields__ # type: ignore[return-value]
|
|
||||||
msg = f"Expected a Pydantic model. Got {type(model)}"
|
msg = f"Expected a Pydantic model. Got {type(model)}"
|
||||||
raise TypeError(msg)
|
raise TypeError(msg) from exc
|
||||||
|
|
||||||
elif IS_PYDANTIC_V1:
|
|
||||||
from pydantic import BaseModel as BaseModelV1_
|
|
||||||
|
|
||||||
def get_fields( # type: ignore[no-redef]
|
|
||||||
model: Union[type[BaseModelV1_], BaseModelV1_],
|
|
||||||
) -> dict[str, FieldInfoV1]:
|
|
||||||
"""Get the field names of a Pydantic model."""
|
|
||||||
return model.__fields__ # type: ignore[return-value]
|
|
||||||
|
|
||||||
else:
|
|
||||||
msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}"
|
|
||||||
raise ValueError(msg)
|
|
||||||
|
|
||||||
_SchemaConfig = ConfigDict(
|
_SchemaConfig = ConfigDict(
|
||||||
arbitrary_types_allowed=True, frozen=True, protected_namespaces=()
|
arbitrary_types_allowed=True, frozen=True, protected_namespaces=()
|
||||||
@ -458,17 +214,6 @@ def _create_root_model(
|
|||||||
if default_ is not NO_DEFAULT:
|
if default_ is not NO_DEFAULT:
|
||||||
base_class_attributes["root"] = default_
|
base_class_attributes["root"] = default_
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
try:
|
|
||||||
if (
|
|
||||||
isinstance(type_, type)
|
|
||||||
and not isinstance(type_, GenericAlias)
|
|
||||||
and issubclass(type_, BaseModelV1)
|
|
||||||
):
|
|
||||||
warnings.filterwarnings(
|
|
||||||
action="ignore", category=PydanticDeprecationWarning
|
|
||||||
)
|
|
||||||
except TypeError:
|
|
||||||
pass
|
|
||||||
custom_root_type = type(name, (RootModel,), base_class_attributes)
|
custom_root_type = type(name, (RootModel,), base_class_attributes)
|
||||||
return cast("type[BaseModel]", custom_root_type)
|
return cast("type[BaseModel]", custom_root_type)
|
||||||
|
|
||||||
@ -545,9 +290,6 @@ _RESERVED_NAMES = {key for key in dir(BaseModel) if not key.startswith("_")}
|
|||||||
|
|
||||||
def _remap_field_definitions(field_definitions: dict[str, Any]) -> dict[str, Any]:
|
def _remap_field_definitions(field_definitions: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""This remaps fields to avoid colliding with internal pydantic fields."""
|
"""This remaps fields to avoid colliding with internal pydantic fields."""
|
||||||
from pydantic import Field
|
|
||||||
from pydantic.fields import FieldInfo
|
|
||||||
|
|
||||||
remapped = {}
|
remapped = {}
|
||||||
for key, value in field_definitions.items():
|
for key, value in field_definitions.items():
|
||||||
if key.startswith("_") or key in _RESERVED_NAMES:
|
if key.startswith("_") or key in _RESERVED_NAMES:
|
||||||
@ -595,7 +337,7 @@ def create_model_v2(
|
|||||||
root: Type for a root model (RootModel)
|
root: Type for a root model (RootModel)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Type[BaseModel]: The created model.
|
type[BaseModel]: The created model.
|
||||||
"""
|
"""
|
||||||
field_definitions = field_definitions or {}
|
field_definitions = field_definitions or {}
|
||||||
|
|
||||||
|
@ -15,10 +15,6 @@ from pydantic import SecretStr
|
|||||||
from requests import HTTPError, Response
|
from requests import HTTPError, Response
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_core.utils.pydantic import (
|
|
||||||
is_pydantic_v1_subclass,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def xor_args(*arg_groups: tuple[str, ...]) -> Callable:
|
def xor_args(*arg_groups: tuple[str, ...]) -> Callable:
|
||||||
"""Validate specified keyword args are mutually exclusive.".
|
"""Validate specified keyword args are mutually exclusive.".
|
||||||
@ -206,16 +202,10 @@ def get_pydantic_field_names(pydantic_cls: Any) -> set[str]:
|
|||||||
set[str]: Field names.
|
set[str]: Field names.
|
||||||
"""
|
"""
|
||||||
all_required_field_names = set()
|
all_required_field_names = set()
|
||||||
if is_pydantic_v1_subclass(pydantic_cls):
|
for name, field in pydantic_cls.model_fields.items():
|
||||||
for field in pydantic_cls.__fields__.values():
|
all_required_field_names.add(name)
|
||||||
all_required_field_names.add(field.name)
|
if field.alias:
|
||||||
if field.has_alias:
|
all_required_field_names.add(field.alias)
|
||||||
all_required_field_names.add(field.alias)
|
|
||||||
else: # Assuming pydantic 2 for now
|
|
||||||
for name, field in pydantic_cls.model_fields.items():
|
|
||||||
all_required_field_names.add(name)
|
|
||||||
if field.alias:
|
|
||||||
all_required_field_names.add(field.alias)
|
|
||||||
return all_required_field_names
|
return all_required_field_names
|
||||||
|
|
||||||
|
|
||||||
|
@ -125,7 +125,7 @@ filterwarnings = [ "ignore::langchain_core._api.beta_decorator.LangChainBetaWarn
|
|||||||
asyncio_default_fixture_loop_scope = "function"
|
asyncio_default_fixture_loop_scope = "function"
|
||||||
|
|
||||||
[tool.ruff.lint.pep8-naming]
|
[tool.ruff.lint.pep8-naming]
|
||||||
classmethod-decorators = [ "classmethod", "langchain_core.utils.pydantic.pre_init", "pydantic.field_validator", "pydantic.v1.root_validator",]
|
classmethod-decorators = [ "classmethod", "langchain_core.utils.pydantic.pre_init", "pydantic.field_validator"]
|
||||||
|
|
||||||
[tool.ruff.lint.per-file-ignores]
|
[tool.ruff.lint.per-file-ignores]
|
||||||
"langchain_core/utils/mustache.py" = [ "PLW0603",]
|
"langchain_core/utils/mustache.py" = [ "PLW0603",]
|
||||||
|
@ -16,10 +16,6 @@ from langchain_core.output_parsers.openai_tools import (
|
|||||||
PydanticToolsParser,
|
PydanticToolsParser,
|
||||||
)
|
)
|
||||||
from langchain_core.outputs import ChatGeneration
|
from langchain_core.outputs import ChatGeneration
|
||||||
from langchain_core.utils.pydantic import (
|
|
||||||
IS_PYDANTIC_V1,
|
|
||||||
IS_PYDANTIC_V2,
|
|
||||||
)
|
|
||||||
|
|
||||||
STREAMED_MESSAGES: list = [
|
STREAMED_MESSAGES: list = [
|
||||||
AIMessageChunk(content=""),
|
AIMessageChunk(content=""),
|
||||||
@ -532,87 +528,13 @@ async def test_partial_pydantic_output_parser_async() -> None:
|
|||||||
assert actual == EXPECTED_STREAMED_PYDANTIC
|
assert actual == EXPECTED_STREAMED_PYDANTIC
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="This test is for pydantic 2")
|
def test_parse_with_different_pydantic() -> None:
|
||||||
def test_parse_with_different_pydantic_2_v1() -> None:
|
"""Test with BaseModel"""
|
||||||
"""Test with pydantic.v1.BaseModel from pydantic 2."""
|
|
||||||
import pydantic
|
|
||||||
|
|
||||||
class Forecast(pydantic.v1.BaseModel):
|
class Forecast(BaseModel):
|
||||||
temperature: int
|
temperature: int
|
||||||
forecast: str
|
forecast: str
|
||||||
|
|
||||||
# Can't get pydantic to work here due to the odd typing of tryig to support
|
|
||||||
# both v1 and v2 in the same codebase.
|
|
||||||
parser = PydanticToolsParser(tools=[Forecast]) # type: ignore[list-item]
|
|
||||||
message = AIMessage(
|
|
||||||
content="",
|
|
||||||
tool_calls=[
|
|
||||||
{
|
|
||||||
"id": "call_OwL7f5PE",
|
|
||||||
"name": "Forecast",
|
|
||||||
"args": {"temperature": 20, "forecast": "Sunny"},
|
|
||||||
}
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
generation = ChatGeneration(
|
|
||||||
message=message,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert parser.parse_result([generation]) == [
|
|
||||||
Forecast(
|
|
||||||
temperature=20,
|
|
||||||
forecast="Sunny",
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="This test is for pydantic 2")
|
|
||||||
def test_parse_with_different_pydantic_2_proper() -> None:
|
|
||||||
"""Test with pydantic.BaseModel from pydantic 2."""
|
|
||||||
import pydantic
|
|
||||||
|
|
||||||
class Forecast(pydantic.BaseModel):
|
|
||||||
temperature: int
|
|
||||||
forecast: str
|
|
||||||
|
|
||||||
# Can't get pydantic to work here due to the odd typing of tryig to support
|
|
||||||
# both v1 and v2 in the same codebase.
|
|
||||||
parser = PydanticToolsParser(tools=[Forecast])
|
|
||||||
message = AIMessage(
|
|
||||||
content="",
|
|
||||||
tool_calls=[
|
|
||||||
{
|
|
||||||
"id": "call_OwL7f5PE",
|
|
||||||
"name": "Forecast",
|
|
||||||
"args": {"temperature": 20, "forecast": "Sunny"},
|
|
||||||
}
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
generation = ChatGeneration(
|
|
||||||
message=message,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert parser.parse_result([generation]) == [
|
|
||||||
Forecast(
|
|
||||||
temperature=20,
|
|
||||||
forecast="Sunny",
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not IS_PYDANTIC_V1, reason="This test is for pydantic 1")
|
|
||||||
def test_parse_with_different_pydantic_1_proper() -> None:
|
|
||||||
"""Test with pydantic.BaseModel from pydantic 1."""
|
|
||||||
import pydantic
|
|
||||||
|
|
||||||
class Forecast(pydantic.BaseModel):
|
|
||||||
temperature: int
|
|
||||||
forecast: str
|
|
||||||
|
|
||||||
# Can't get pydantic to work here due to the odd typing of tryig to support
|
|
||||||
# both v1 and v2 in the same codebase.
|
|
||||||
parser = PydanticToolsParser(tools=[Forecast])
|
parser = PydanticToolsParser(tools=[Forecast])
|
||||||
message = AIMessage(
|
message = AIMessage(
|
||||||
content="",
|
content="",
|
||||||
|
@ -3,35 +3,23 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Literal, Optional
|
from typing import Literal, Optional
|
||||||
|
|
||||||
import pydantic
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from pydantic.v1 import BaseModel as V1BaseModel
|
|
||||||
|
|
||||||
from langchain_core.exceptions import OutputParserException
|
from langchain_core.exceptions import OutputParserException
|
||||||
from langchain_core.language_models import ParrotFakeChatModel
|
from langchain_core.language_models import ParrotFakeChatModel
|
||||||
from langchain_core.output_parsers import PydanticOutputParser
|
from langchain_core.output_parsers import PydanticOutputParser
|
||||||
from langchain_core.output_parsers.json import JsonOutputParser
|
from langchain_core.output_parsers.json import JsonOutputParser
|
||||||
from langchain_core.prompts.prompt import PromptTemplate
|
from langchain_core.prompts.prompt import PromptTemplate
|
||||||
from langchain_core.utils.pydantic import TBaseModel
|
|
||||||
|
|
||||||
|
|
||||||
class ForecastV2(pydantic.BaseModel):
|
class Forecast(BaseModel):
|
||||||
temperature: int
|
temperature: int
|
||||||
f_or_c: Literal["F", "C"]
|
f_or_c: Literal["F", "C"]
|
||||||
forecast: str
|
forecast: str
|
||||||
|
|
||||||
|
|
||||||
class ForecastV1(V1BaseModel):
|
def test_pydantic_parser_chaining() -> None:
|
||||||
temperature: int
|
|
||||||
f_or_c: Literal["F", "C"]
|
|
||||||
forecast: str
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("pydantic_object", [ForecastV2, ForecastV1])
|
|
||||||
def test_pydantic_parser_chaining(
|
|
||||||
pydantic_object: TBaseModel,
|
|
||||||
) -> None:
|
|
||||||
prompt = PromptTemplate(
|
prompt = PromptTemplate(
|
||||||
template="""{{
|
template="""{{
|
||||||
"temperature": 20,
|
"temperature": 20,
|
||||||
@ -43,18 +31,17 @@ def test_pydantic_parser_chaining(
|
|||||||
|
|
||||||
model = ParrotFakeChatModel()
|
model = ParrotFakeChatModel()
|
||||||
|
|
||||||
parser = PydanticOutputParser(pydantic_object=pydantic_object) # type: ignore[arg-type,var-annotated]
|
parser = PydanticOutputParser(pydantic_object=Forecast)
|
||||||
chain = prompt | model | parser
|
chain = prompt | model | parser
|
||||||
|
|
||||||
res = chain.invoke({})
|
res = chain.invoke({})
|
||||||
assert type(res) is pydantic_object
|
assert type(res) is Forecast
|
||||||
assert res.f_or_c == "C"
|
assert res.f_or_c == "C"
|
||||||
assert res.temperature == 20
|
assert res.temperature == 20
|
||||||
assert res.forecast == "Sunny"
|
assert res.forecast == "Sunny"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("pydantic_object", [ForecastV2, ForecastV1])
|
def test_pydantic_parser_validation() -> None:
|
||||||
def test_pydantic_parser_validation(pydantic_object: TBaseModel) -> None:
|
|
||||||
bad_prompt = PromptTemplate(
|
bad_prompt = PromptTemplate(
|
||||||
template="""{{
|
template="""{{
|
||||||
"temperature": "oof",
|
"temperature": "oof",
|
||||||
@ -66,17 +53,14 @@ def test_pydantic_parser_validation(pydantic_object: TBaseModel) -> None:
|
|||||||
|
|
||||||
model = ParrotFakeChatModel()
|
model = ParrotFakeChatModel()
|
||||||
|
|
||||||
parser = PydanticOutputParser(pydantic_object=pydantic_object) # type: ignore[arg-type,var-annotated]
|
parser = PydanticOutputParser(pydantic_object=Forecast)
|
||||||
chain = bad_prompt | model | parser
|
chain = bad_prompt | model | parser
|
||||||
with pytest.raises(OutputParserException):
|
with pytest.raises(OutputParserException):
|
||||||
chain.invoke({})
|
chain.invoke({})
|
||||||
|
|
||||||
|
|
||||||
# JSON output parser tests
|
# JSON output parser tests
|
||||||
@pytest.mark.parametrize("pydantic_object", [ForecastV2, ForecastV1])
|
def test_json_parser_chaining() -> None:
|
||||||
def test_json_parser_chaining(
|
|
||||||
pydantic_object: TBaseModel,
|
|
||||||
) -> None:
|
|
||||||
prompt = PromptTemplate(
|
prompt = PromptTemplate(
|
||||||
template="""{{
|
template="""{{
|
||||||
"temperature": 20,
|
"temperature": 20,
|
||||||
@ -88,7 +72,7 @@ def test_json_parser_chaining(
|
|||||||
|
|
||||||
model = ParrotFakeChatModel()
|
model = ParrotFakeChatModel()
|
||||||
|
|
||||||
parser = JsonOutputParser(pydantic_object=pydantic_object) # type: ignore[arg-type]
|
parser = JsonOutputParser(pydantic_object=Forecast)
|
||||||
chain = prompt | model | parser
|
chain = prompt | model | parser
|
||||||
|
|
||||||
res = chain.invoke({})
|
res = chain.invoke({})
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
from inspect import isclass
|
from typing import Any, Union
|
||||||
from typing import Any, Union, cast
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
@ -10,15 +9,14 @@ from langchain_core.load.load import loads
|
|||||||
from langchain_core.messages import HumanMessage
|
from langchain_core.messages import HumanMessage
|
||||||
from langchain_core.prompts.structured import StructuredPrompt
|
from langchain_core.prompts.structured import StructuredPrompt
|
||||||
from langchain_core.runnables.base import Runnable, RunnableLambda
|
from langchain_core.runnables.base import Runnable, RunnableLambda
|
||||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
|
||||||
|
|
||||||
|
|
||||||
def _fake_runnable(
|
def _fake_runnable(
|
||||||
_: Any, *, schema: Union[dict, type[BaseModel]], value: Any = 42, **_kwargs: Any
|
_: Any, *, schema: Union[dict, type[BaseModel]], value: Any = 42, **_kwargs: Any
|
||||||
) -> Union[BaseModel, dict]:
|
) -> Union[BaseModel, dict]:
|
||||||
if isclass(schema) and is_basemodel_subclass(schema):
|
if isinstance(schema, type) and issubclass(schema, BaseModel):
|
||||||
return schema(name="yo", value=value)
|
return schema(name="yo", value=value)
|
||||||
params = cast("dict", schema)["parameters"]
|
params = schema["parameters"]
|
||||||
return {k: 1 if k != "value" else value for k, v in params.items()}
|
return {k: 1 if k != "value" else value for k, v in params.items()}
|
||||||
|
|
||||||
|
|
||||||
|
@ -2,8 +2,6 @@ from typing import Any, Union
|
|||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
|
||||||
|
|
||||||
|
|
||||||
# Function to replace allOf with $ref
|
# Function to replace allOf with $ref
|
||||||
def replace_all_of_with_ref(schema: Any) -> None:
|
def replace_all_of_with_ref(schema: Any) -> None:
|
||||||
@ -31,8 +29,6 @@ def replace_all_of_with_ref(schema: Any) -> None:
|
|||||||
def remove_all_none_default(schema: Any) -> None:
|
def remove_all_none_default(schema: Any) -> None:
|
||||||
"""Removing all none defaults.
|
"""Removing all none defaults.
|
||||||
|
|
||||||
Pydantic v1 did not generate these, but Pydantic v2 does.
|
|
||||||
|
|
||||||
The None defaults usually represent **NotRequired** fields, and the None value
|
The None defaults usually represent **NotRequired** fields, and the None value
|
||||||
is actually **incorrect** as a value since the fields do not allow a None value.
|
is actually **incorrect** as a value since the fields do not allow a None value.
|
||||||
|
|
||||||
@ -75,13 +71,9 @@ def _remove_enum(obj: Any) -> None:
|
|||||||
|
|
||||||
def _schema(obj: Any) -> dict:
|
def _schema(obj: Any) -> dict:
|
||||||
"""Return the schema of the object."""
|
"""Return the schema of the object."""
|
||||||
if not is_basemodel_subclass(obj):
|
if not (isinstance(obj, type) and issubclass(obj, BaseModel)):
|
||||||
msg = f"Object must be a Pydantic BaseModel subclass. Got {type(obj)}"
|
msg = f"Object must be a Pydantic BaseModel subclass. Got {type(obj)}"
|
||||||
raise TypeError(msg)
|
raise TypeError(msg)
|
||||||
# Remap to old style schema
|
|
||||||
if not hasattr(obj, "model_json_schema"): # V1 model
|
|
||||||
return obj.schema()
|
|
||||||
|
|
||||||
schema_ = obj.model_json_schema(ref_template="#/definitions/{model}")
|
schema_ = obj.model_json_schema(ref_template="#/definitions/{model}")
|
||||||
if "$defs" in schema_:
|
if "$defs" in schema_:
|
||||||
schema_["definitions"] = schema_["$defs"]
|
schema_["definitions"] = schema_["$defs"]
|
||||||
|
@ -21,9 +21,7 @@ from typing import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import BaseModel, Field, ValidationError
|
from pydantic import BaseModel, ConfigDict, Field, ValidationError
|
||||||
from pydantic.v1 import BaseModel as BaseModelV1
|
|
||||||
from pydantic.v1 import ValidationError as ValidationErrorV1
|
|
||||||
from typing_extensions import TypedDict
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
from langchain_core import tools
|
from langchain_core import tools
|
||||||
@ -52,7 +50,6 @@ from langchain_core.tools import (
|
|||||||
tool,
|
tool,
|
||||||
)
|
)
|
||||||
from langchain_core.tools.base import (
|
from langchain_core.tools.base import (
|
||||||
ArgsSchema,
|
|
||||||
InjectedToolArg,
|
InjectedToolArg,
|
||||||
InjectedToolCallId,
|
InjectedToolCallId,
|
||||||
SchemaAnnotationError,
|
SchemaAnnotationError,
|
||||||
@ -65,8 +62,6 @@ from langchain_core.utils.function_calling import (
|
|||||||
convert_to_openai_tool,
|
convert_to_openai_tool,
|
||||||
)
|
)
|
||||||
from langchain_core.utils.pydantic import (
|
from langchain_core.utils.pydantic import (
|
||||||
IS_PYDANTIC_V1,
|
|
||||||
IS_PYDANTIC_V2,
|
|
||||||
_create_subset_model,
|
_create_subset_model,
|
||||||
create_model_v2,
|
create_model_v2,
|
||||||
)
|
)
|
||||||
@ -79,9 +74,7 @@ def _get_tool_call_json_schema(tool: BaseTool) -> dict:
|
|||||||
if isinstance(tool_schema, dict):
|
if isinstance(tool_schema, dict):
|
||||||
return tool_schema
|
return tool_schema
|
||||||
|
|
||||||
if hasattr(tool_schema, "model_json_schema"):
|
return tool_schema.model_json_schema()
|
||||||
return tool_schema.model_json_schema()
|
|
||||||
return tool_schema.schema()
|
|
||||||
|
|
||||||
|
|
||||||
def test_unnamed_decorator() -> None:
|
def test_unnamed_decorator() -> None:
|
||||||
@ -106,14 +99,6 @@ class _MockSchema(BaseModel):
|
|||||||
arg3: Optional[dict] = None
|
arg3: Optional[dict] = None
|
||||||
|
|
||||||
|
|
||||||
class _MockSchemaV1(BaseModelV1):
|
|
||||||
"""Return the arguments directly."""
|
|
||||||
|
|
||||||
arg1: int
|
|
||||||
arg2: bool
|
|
||||||
arg3: Optional[dict] = None
|
|
||||||
|
|
||||||
|
|
||||||
class _MockStructuredTool(BaseTool):
|
class _MockStructuredTool(BaseTool):
|
||||||
name: str = "structured_api"
|
name: str = "structured_api"
|
||||||
args_schema: type[BaseModel] = _MockSchema
|
args_schema: type[BaseModel] = _MockSchema
|
||||||
@ -205,13 +190,6 @@ def test_decorator_with_specified_schema() -> None:
|
|||||||
assert isinstance(tool_func, BaseTool)
|
assert isinstance(tool_func, BaseTool)
|
||||||
assert tool_func.args_schema == _MockSchema
|
assert tool_func.args_schema == _MockSchema
|
||||||
|
|
||||||
@tool(args_schema=cast("ArgsSchema", _MockSchemaV1))
|
|
||||||
def tool_func_v1(*, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
|
||||||
return f"{arg1} {arg2} {arg3}"
|
|
||||||
|
|
||||||
assert isinstance(tool_func_v1, BaseTool)
|
|
||||||
assert tool_func_v1.args_schema == _MockSchemaV1
|
|
||||||
|
|
||||||
|
|
||||||
def test_decorated_function_schema_equivalent() -> None:
|
def test_decorated_function_schema_equivalent() -> None:
|
||||||
"""Test that a BaseTool without a schema meets expectations."""
|
"""Test that a BaseTool without a schema meets expectations."""
|
||||||
@ -345,50 +323,6 @@ def test_structured_tool_types_parsed() -> None:
|
|||||||
assert result == expected
|
assert result == expected
|
||||||
|
|
||||||
|
|
||||||
def test_structured_tool_types_parsed_pydantic_v1() -> None:
|
|
||||||
"""Test the non-primitive types are correctly passed to structured tools."""
|
|
||||||
|
|
||||||
class SomeBaseModel(BaseModelV1):
|
|
||||||
foo: str
|
|
||||||
|
|
||||||
class AnotherBaseModel(BaseModelV1):
|
|
||||||
bar: str
|
|
||||||
|
|
||||||
@tool
|
|
||||||
def structured_tool(some_base_model: SomeBaseModel) -> AnotherBaseModel:
|
|
||||||
"""Return the arguments directly."""
|
|
||||||
return AnotherBaseModel(bar=some_base_model.foo)
|
|
||||||
|
|
||||||
assert isinstance(structured_tool, StructuredTool)
|
|
||||||
|
|
||||||
expected = AnotherBaseModel(bar="baz")
|
|
||||||
for arg in [
|
|
||||||
SomeBaseModel(foo="baz"),
|
|
||||||
SomeBaseModel(foo="baz").dict(),
|
|
||||||
]:
|
|
||||||
args = {"some_base_model": arg}
|
|
||||||
result = structured_tool.run(args)
|
|
||||||
assert result == expected
|
|
||||||
|
|
||||||
|
|
||||||
def test_structured_tool_types_parsed_pydantic_mixed() -> None:
|
|
||||||
"""Test handling of tool with mixed Pydantic version arguments."""
|
|
||||||
|
|
||||||
class SomeBaseModel(BaseModelV1):
|
|
||||||
foo: str
|
|
||||||
|
|
||||||
class AnotherBaseModel(BaseModel):
|
|
||||||
bar: str
|
|
||||||
|
|
||||||
with pytest.raises(NotImplementedError):
|
|
||||||
|
|
||||||
@tool
|
|
||||||
def structured_tool(
|
|
||||||
some_base_model: SomeBaseModel, another_base_model: AnotherBaseModel
|
|
||||||
) -> None:
|
|
||||||
"""Return the arguments directly."""
|
|
||||||
|
|
||||||
|
|
||||||
def test_base_tool_inheritance_base_schema() -> None:
|
def test_base_tool_inheritance_base_schema() -> None:
|
||||||
"""Test schema is correctly inferred when inheriting from BaseTool."""
|
"""Test schema is correctly inferred when inheriting from BaseTool."""
|
||||||
|
|
||||||
@ -867,7 +801,7 @@ def test_validation_error_handling_callable() -> None:
|
|||||||
"""Test that validation errors are handled correctly."""
|
"""Test that validation errors are handled correctly."""
|
||||||
expected = "foo bar"
|
expected = "foo bar"
|
||||||
|
|
||||||
def handling(e: Union[ValidationError, ValidationErrorV1]) -> str:
|
def handling(e: ValidationError) -> str:
|
||||||
return expected
|
return expected
|
||||||
|
|
||||||
_tool = _MockStructuredTool(handle_validation_error=handling)
|
_tool = _MockStructuredTool(handle_validation_error=handling)
|
||||||
@ -884,9 +818,7 @@ def test_validation_error_handling_callable() -> None:
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_validation_error_handling_non_validation_error(
|
def test_validation_error_handling_non_validation_error(
|
||||||
handler: Union[
|
handler: Union[bool, str, Callable[[ValidationError], str]],
|
||||||
bool, str, Callable[[Union[ValidationError, ValidationErrorV1]], str]
|
|
||||||
],
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test that validation errors are handled correctly."""
|
"""Test that validation errors are handled correctly."""
|
||||||
|
|
||||||
@ -932,7 +864,7 @@ async def test_async_validation_error_handling_callable() -> None:
|
|||||||
"""Test that validation errors are handled correctly."""
|
"""Test that validation errors are handled correctly."""
|
||||||
expected = "foo bar"
|
expected = "foo bar"
|
||||||
|
|
||||||
def handling(e: Union[ValidationError, ValidationErrorV1]) -> str:
|
def handling(e: ValidationError) -> str:
|
||||||
return expected
|
return expected
|
||||||
|
|
||||||
_tool = _MockStructuredTool(handle_validation_error=handling)
|
_tool = _MockStructuredTool(handle_validation_error=handling)
|
||||||
@ -949,9 +881,7 @@ async def test_async_validation_error_handling_callable() -> None:
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
async def test_async_validation_error_handling_non_validation_error(
|
async def test_async_validation_error_handling_non_validation_error(
|
||||||
handler: Union[
|
handler: Union[bool, str, Callable[[ValidationError], str]],
|
||||||
bool, str, Callable[[Union[ValidationError, ValidationErrorV1]], str]
|
|
||||||
],
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test that validation errors are handled correctly."""
|
"""Test that validation errors are handled correctly."""
|
||||||
|
|
||||||
@ -1812,38 +1742,18 @@ def test_fn_injected_arg_with_schema(tool_: Callable) -> None:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def generate_models() -> list[Any]:
|
class FooProper(BaseModel):
|
||||||
"""Generate a list of base models depending on the pydantic version."""
|
a: int
|
||||||
|
b: str
|
||||||
class FooProper(BaseModel):
|
|
||||||
a: int
|
|
||||||
b: str
|
|
||||||
|
|
||||||
return [FooProper]
|
|
||||||
|
|
||||||
|
|
||||||
def generate_backwards_compatible_v1() -> list[Any]:
|
TEST_MODELS = [FooProper]
|
||||||
"""Generate a model with pydantic 2 from the v1 namespace."""
|
|
||||||
from pydantic.v1 import BaseModel as BaseModelV1
|
|
||||||
|
|
||||||
class FooV1Namespace(BaseModelV1):
|
|
||||||
a: int
|
|
||||||
b: str
|
|
||||||
|
|
||||||
return [FooV1Namespace]
|
|
||||||
|
|
||||||
|
|
||||||
# This generates a list of models that can be used for testing that our APIs
|
|
||||||
# behave well with either pydantic 1 proper,
|
|
||||||
# pydantic v1 from pydantic 2,
|
|
||||||
# or pydantic 2 proper.
|
|
||||||
TEST_MODELS = generate_models() + generate_backwards_compatible_v1()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("pydantic_model", TEST_MODELS)
|
@pytest.mark.parametrize("pydantic_model", TEST_MODELS)
|
||||||
def test_args_schema_as_pydantic(pydantic_model: Any) -> None:
|
def test_args_schema_as_pydantic(pydantic_model: Any) -> None:
|
||||||
class SomeTool(BaseTool):
|
class SomeTool(BaseTool):
|
||||||
args_schema: type[pydantic_model] = pydantic_model
|
args_schema: type[BaseModel] = pydantic_model
|
||||||
|
|
||||||
def _run(self, *args: Any, **kwargs: Any) -> str:
|
def _run(self, *args: Any, **kwargs: Any) -> str:
|
||||||
return "foo"
|
return "foo"
|
||||||
@ -1853,11 +1763,7 @@ def test_args_schema_as_pydantic(pydantic_model: Any) -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
input_schema = tool.get_input_schema()
|
input_schema = tool.get_input_schema()
|
||||||
input_json_schema = (
|
input_json_schema = input_schema.model_json_schema()
|
||||||
input_schema.model_json_schema()
|
|
||||||
if hasattr(input_schema, "model_json_schema")
|
|
||||||
else input_schema.schema()
|
|
||||||
)
|
|
||||||
assert input_json_schema == {
|
assert input_json_schema == {
|
||||||
"properties": {
|
"properties": {
|
||||||
"a": {"title": "A", "type": "integer"},
|
"a": {"title": "A", "type": "integer"},
|
||||||
@ -1882,22 +1788,13 @@ def test_args_schema_as_pydantic(pydantic_model: Any) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def test_args_schema_explicitly_typed() -> None:
|
def test_args_schema_explicitly_typed() -> None:
|
||||||
"""This should test that one can type the args schema as a pydantic model.
|
"""This should test that one can type the args schema as a pydantic model."""
|
||||||
|
|
||||||
Please note that this will test using pydantic 2 even though BaseTool
|
|
||||||
is a pydantic 1 model!
|
|
||||||
"""
|
|
||||||
# Check with whatever pydantic model is passed in and not via v1 namespace
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
class Foo(BaseModel):
|
class Foo(BaseModel):
|
||||||
a: int
|
a: int
|
||||||
b: str
|
b: str
|
||||||
|
|
||||||
class SomeTool(BaseTool):
|
class SomeTool(BaseTool):
|
||||||
# type ignoring here since we're allowing overriding a type
|
|
||||||
# signature of pydantic.v1.BaseModel with pydantic.BaseModel
|
|
||||||
# for pydantic 2!
|
|
||||||
args_schema: type[BaseModel] = Foo
|
args_schema: type[BaseModel] = Foo
|
||||||
|
|
||||||
def _run(self, *args: Any, **kwargs: Any) -> str:
|
def _run(self, *args: Any, **kwargs: Any) -> str:
|
||||||
@ -1944,11 +1841,7 @@ def test_structured_tool_with_different_pydantic_versions(pydantic_model: Any) -
|
|||||||
assert foo_tool.invoke({"a": 5, "b": "hello"}) == "foo"
|
assert foo_tool.invoke({"a": 5, "b": "hello"}) == "foo"
|
||||||
|
|
||||||
args_schema = cast("BaseModel", foo_tool.args_schema)
|
args_schema = cast("BaseModel", foo_tool.args_schema)
|
||||||
args_json_schema = (
|
args_json_schema = args_schema.model_json_schema()
|
||||||
args_schema.model_json_schema()
|
|
||||||
if hasattr(args_schema, "model_json_schema")
|
|
||||||
else args_schema.schema()
|
|
||||||
)
|
|
||||||
assert args_json_schema == {
|
assert args_json_schema == {
|
||||||
"properties": {
|
"properties": {
|
||||||
"a": {"title": "A", "type": "integer"},
|
"a": {"title": "A", "type": "integer"},
|
||||||
@ -1960,11 +1853,7 @@ def test_structured_tool_with_different_pydantic_versions(pydantic_model: Any) -
|
|||||||
}
|
}
|
||||||
|
|
||||||
input_schema = foo_tool.get_input_schema()
|
input_schema = foo_tool.get_input_schema()
|
||||||
input_json_schema = (
|
input_json_schema = input_schema.model_json_schema()
|
||||||
input_schema.model_json_schema()
|
|
||||||
if hasattr(input_schema, "model_json_schema")
|
|
||||||
else input_schema.schema()
|
|
||||||
)
|
|
||||||
assert input_json_schema == {
|
assert input_json_schema == {
|
||||||
"properties": {
|
"properties": {
|
||||||
"a": {"title": "A", "type": "integer"},
|
"a": {"title": "A", "type": "integer"},
|
||||||
@ -2020,81 +1909,12 @@ def test__is_message_content_type(obj: Any, *, expected: bool) -> None:
|
|||||||
assert _is_message_content_type(obj) is expected
|
assert _is_message_content_type(obj) is expected
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Testing pydantic v2.")
|
def test__get_all_basemodel_annotations_v2() -> None:
|
||||||
@pytest.mark.parametrize("use_v1_namespace", [True, False])
|
|
||||||
def test__get_all_basemodel_annotations_v2(*, use_v1_namespace: bool) -> None:
|
|
||||||
A = TypeVar("A")
|
A = TypeVar("A")
|
||||||
|
|
||||||
if use_v1_namespace:
|
class ModelA(BaseModel, Generic[A]):
|
||||||
from pydantic.v1 import BaseModel as BaseModel1
|
|
||||||
|
|
||||||
class ModelA(BaseModel1, Generic[A], extra="allow"):
|
|
||||||
a: A
|
|
||||||
|
|
||||||
else:
|
|
||||||
from pydantic import BaseModel as BaseModel2
|
|
||||||
from pydantic import ConfigDict
|
|
||||||
|
|
||||||
class ModelA(BaseModel2, Generic[A]): # type: ignore[no-redef]
|
|
||||||
a: A
|
|
||||||
model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
|
|
||||||
|
|
||||||
class ModelB(ModelA[str]):
|
|
||||||
b: Annotated[ModelA[dict[str, Any]], "foo"]
|
|
||||||
|
|
||||||
class Mixin:
|
|
||||||
def foo(self) -> str:
|
|
||||||
return "foo"
|
|
||||||
|
|
||||||
class ModelC(Mixin, ModelB):
|
|
||||||
c: dict
|
|
||||||
|
|
||||||
expected = {"a": str, "b": Annotated[ModelA[dict[str, Any]], "foo"], "c": dict}
|
|
||||||
actual = get_all_basemodel_annotations(ModelC)
|
|
||||||
assert actual == expected
|
|
||||||
|
|
||||||
expected = {"a": str, "b": Annotated[ModelA[dict[str, Any]], "foo"]}
|
|
||||||
actual = get_all_basemodel_annotations(ModelB)
|
|
||||||
assert actual == expected
|
|
||||||
|
|
||||||
expected = {"a": Any}
|
|
||||||
actual = get_all_basemodel_annotations(ModelA)
|
|
||||||
assert actual == expected
|
|
||||||
|
|
||||||
expected = {"a": int}
|
|
||||||
actual = get_all_basemodel_annotations(ModelA[int])
|
|
||||||
assert actual == expected
|
|
||||||
|
|
||||||
D = TypeVar("D", bound=Union[str, int])
|
|
||||||
|
|
||||||
class ModelD(ModelC, Generic[D]):
|
|
||||||
d: Optional[D]
|
|
||||||
|
|
||||||
expected = {
|
|
||||||
"a": str,
|
|
||||||
"b": Annotated[ModelA[dict[str, Any]], "foo"],
|
|
||||||
"c": dict,
|
|
||||||
"d": Union[str, int, None],
|
|
||||||
}
|
|
||||||
actual = get_all_basemodel_annotations(ModelD)
|
|
||||||
assert actual == expected
|
|
||||||
|
|
||||||
expected = {
|
|
||||||
"a": str,
|
|
||||||
"b": Annotated[ModelA[dict[str, Any]], "foo"],
|
|
||||||
"c": dict,
|
|
||||||
"d": Union[int, None],
|
|
||||||
}
|
|
||||||
actual = get_all_basemodel_annotations(ModelD[int])
|
|
||||||
assert actual == expected
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not IS_PYDANTIC_V1, reason="Testing pydantic v1.")
|
|
||||||
def test__get_all_basemodel_annotations_v1() -> None:
|
|
||||||
A = TypeVar("A")
|
|
||||||
|
|
||||||
class ModelA(BaseModel, Generic[A], extra="allow"):
|
|
||||||
a: A
|
a: A
|
||||||
|
model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
|
||||||
|
|
||||||
class ModelB(ModelA[str]):
|
class ModelB(ModelA[str]):
|
||||||
b: Annotated[ModelA[dict[str, Any]], "foo"]
|
b: Annotated[ModelA[dict[str, Any]], "foo"]
|
||||||
@ -2226,14 +2046,9 @@ def test_create_retriever_tool() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Testing pydantic v2.")
|
def test_tool_args_schema_pydantic_with_metadata() -> None:
|
||||||
def test_tool_args_schema_pydantic_v2_with_metadata() -> None:
|
class Foo(BaseModel):
|
||||||
from pydantic import BaseModel as BaseModelV2
|
x: list[int] = Field(
|
||||||
from pydantic import Field as FieldV2
|
|
||||||
from pydantic import ValidationError as ValidationErrorV2
|
|
||||||
|
|
||||||
class Foo(BaseModelV2):
|
|
||||||
x: list[int] = FieldV2(
|
|
||||||
description="List of integers", min_length=10, max_length=15
|
description="List of integers", min_length=10, max_length=15
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -2260,7 +2075,7 @@ def test_tool_args_schema_pydantic_v2_with_metadata() -> None:
|
|||||||
}
|
}
|
||||||
|
|
||||||
assert foo.invoke({"x": [0] * 10})
|
assert foo.invoke({"x": [0] * 10})
|
||||||
with pytest.raises(ValidationErrorV2):
|
with pytest.raises(ValidationError):
|
||||||
foo.invoke({"x": [0] * 9})
|
foo.invoke({"x": [0] * 9})
|
||||||
|
|
||||||
|
|
||||||
|
@ -746,7 +746,7 @@ def test_tool_outputs() -> None:
|
|||||||
[ExtensionsAnnotated, TypingAnnotated],
|
[ExtensionsAnnotated, TypingAnnotated],
|
||||||
ids=["typing_extensions.Annotated", "typing.Annotated"],
|
ids=["typing_extensions.Annotated", "typing.Annotated"],
|
||||||
)
|
)
|
||||||
def test__convert_typed_dict_to_openai_function(
|
def test_convert_typed_dict_to_openai_function(
|
||||||
typed_dict: TypeAlias, annotated: TypeAlias
|
typed_dict: TypeAlias, annotated: TypeAlias
|
||||||
) -> None:
|
) -> None:
|
||||||
class SubTool(typed_dict): # type: ignore[misc]
|
class SubTool(typed_dict): # type: ignore[misc]
|
||||||
@ -985,7 +985,7 @@ def test__convert_typed_dict_to_openai_function(
|
|||||||
@pytest.mark.parametrize("typed_dict", [ExtensionsTypedDict, TypingTypedDict])
|
@pytest.mark.parametrize("typed_dict", [ExtensionsTypedDict, TypingTypedDict])
|
||||||
def test__convert_typed_dict_to_openai_function_fail(typed_dict: type) -> None:
|
def test__convert_typed_dict_to_openai_function_fail(typed_dict: type) -> None:
|
||||||
class Tool(typed_dict): # type: ignore[misc]
|
class Tool(typed_dict): # type: ignore[misc]
|
||||||
arg1: typing.MutableSet # Pydantic 2 supports this, but pydantic v1 does not.
|
arg1: typing.MutableSet
|
||||||
|
|
||||||
# Error should be raised since we're using v1 code path here
|
# Error should be raised since we're using v1 code path here
|
||||||
with pytest.raises(TypeError):
|
with pytest.raises(TypeError):
|
||||||
|
@ -3,18 +3,12 @@
|
|||||||
import warnings
|
import warnings
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
import pytest
|
from pydantic import BaseModel, ConfigDict
|
||||||
from pydantic import ConfigDict
|
|
||||||
|
|
||||||
from langchain_core.utils.pydantic import (
|
from langchain_core.utils.pydantic import (
|
||||||
IS_PYDANTIC_V1,
|
|
||||||
IS_PYDANTIC_V2,
|
|
||||||
PYDANTIC_VERSION,
|
|
||||||
_create_subset_model_v2,
|
_create_subset_model_v2,
|
||||||
create_model_v2,
|
create_model_v2,
|
||||||
get_fields,
|
get_fields,
|
||||||
is_basemodel_instance,
|
|
||||||
is_basemodel_subclass,
|
|
||||||
pre_init,
|
pre_init,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -94,52 +88,6 @@ def test_with_aliases() -> None:
|
|||||||
assert foo.z == 2
|
assert foo.z == 2
|
||||||
|
|
||||||
|
|
||||||
def test_is_basemodel_subclass() -> None:
|
|
||||||
"""Test pydantic."""
|
|
||||||
if IS_PYDANTIC_V1:
|
|
||||||
from pydantic import BaseModel as BaseModelV1Proper
|
|
||||||
|
|
||||||
assert is_basemodel_subclass(BaseModelV1Proper)
|
|
||||||
elif IS_PYDANTIC_V2:
|
|
||||||
from pydantic import BaseModel as BaseModelV2
|
|
||||||
from pydantic.v1 import BaseModel as BaseModelV1
|
|
||||||
|
|
||||||
assert is_basemodel_subclass(BaseModelV2)
|
|
||||||
|
|
||||||
assert is_basemodel_subclass(BaseModelV1)
|
|
||||||
else:
|
|
||||||
msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}"
|
|
||||||
raise ValueError(msg)
|
|
||||||
|
|
||||||
|
|
||||||
def test_is_basemodel_instance() -> None:
|
|
||||||
"""Test pydantic."""
|
|
||||||
if IS_PYDANTIC_V1:
|
|
||||||
from pydantic import BaseModel as BaseModelV1Proper
|
|
||||||
|
|
||||||
class FooV1(BaseModelV1Proper):
|
|
||||||
x: int
|
|
||||||
|
|
||||||
assert is_basemodel_instance(FooV1(x=5))
|
|
||||||
elif IS_PYDANTIC_V2:
|
|
||||||
from pydantic import BaseModel as BaseModelV2
|
|
||||||
from pydantic.v1 import BaseModel as BaseModelV1
|
|
||||||
|
|
||||||
class Foo(BaseModelV2):
|
|
||||||
x: int
|
|
||||||
|
|
||||||
assert is_basemodel_instance(Foo(x=5))
|
|
||||||
|
|
||||||
class Bar(BaseModelV1):
|
|
||||||
x: int
|
|
||||||
|
|
||||||
assert is_basemodel_instance(Bar(x=5))
|
|
||||||
else:
|
|
||||||
msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}"
|
|
||||||
raise ValueError(msg)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Only tests Pydantic v2")
|
|
||||||
def test_with_field_metadata() -> None:
|
def test_with_field_metadata() -> None:
|
||||||
"""Test pydantic with field metadata."""
|
"""Test pydantic with field metadata."""
|
||||||
from pydantic import BaseModel as BaseModelV2
|
from pydantic import BaseModel as BaseModelV2
|
||||||
@ -168,21 +116,7 @@ def test_with_field_metadata() -> None:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not IS_PYDANTIC_V1, reason="Only tests Pydantic v1")
|
|
||||||
def test_fields_pydantic_v1() -> None:
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
class Foo(BaseModel):
|
|
||||||
x: int
|
|
||||||
|
|
||||||
fields = get_fields(Foo)
|
|
||||||
assert fields == {"x": Foo.model_fields["x"]}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Only tests Pydantic v2")
|
|
||||||
def test_fields_pydantic_v2_proper() -> None:
|
def test_fields_pydantic_v2_proper() -> None:
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
class Foo(BaseModel):
|
class Foo(BaseModel):
|
||||||
x: int
|
x: int
|
||||||
|
|
||||||
@ -190,17 +124,6 @@ def test_fields_pydantic_v2_proper() -> None:
|
|||||||
assert fields == {"x": Foo.model_fields["x"]}
|
assert fields == {"x": Foo.model_fields["x"]}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Only tests Pydantic v2")
|
|
||||||
def test_fields_pydantic_v1_from_2() -> None:
|
|
||||||
from pydantic.v1 import BaseModel
|
|
||||||
|
|
||||||
class Foo(BaseModel):
|
|
||||||
x: int
|
|
||||||
|
|
||||||
fields = get_fields(Foo)
|
|
||||||
assert fields == {"x": Foo.__fields__["x"]}
|
|
||||||
|
|
||||||
|
|
||||||
def test_create_model_v2() -> None:
|
def test_create_model_v2() -> None:
|
||||||
"""Test that create model v2 works as expected."""
|
"""Test that create model v2 works as expected."""
|
||||||
with warnings.catch_warnings(record=True) as record:
|
with warnings.catch_warnings(record=True) as record:
|
||||||
|
@ -6,7 +6,7 @@ from typing import Any, Callable, Optional, Union
|
|||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import SecretStr
|
from pydantic import BaseModel, Field, SecretStr
|
||||||
|
|
||||||
from langchain_core import utils
|
from langchain_core import utils
|
||||||
from langchain_core.utils import (
|
from langchain_core.utils import (
|
||||||
@ -16,10 +16,6 @@ from langchain_core.utils import (
|
|||||||
guard_import,
|
guard_import,
|
||||||
)
|
)
|
||||||
from langchain_core.utils._merge import merge_dicts
|
from langchain_core.utils._merge import merge_dicts
|
||||||
from langchain_core.utils.pydantic import (
|
|
||||||
IS_PYDANTIC_V1,
|
|
||||||
IS_PYDANTIC_V2,
|
|
||||||
)
|
|
||||||
from langchain_core.utils.utils import secret_from_env
|
from langchain_core.utils.utils import secret_from_env
|
||||||
|
|
||||||
|
|
||||||
@ -214,39 +210,7 @@ def test_guard_import_failure(
|
|||||||
guard_import(module_name, pip_name=pip_name, package=package)
|
guard_import(module_name, pip_name=pip_name, package=package)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Requires pydantic 2")
|
def test_get_pydantic_field_names() -> None:
|
||||||
def test_get_pydantic_field_names_v1_in_2() -> None:
|
|
||||||
from pydantic.v1 import BaseModel as PydanticV1BaseModel
|
|
||||||
from pydantic.v1 import Field
|
|
||||||
|
|
||||||
class PydanticV1Model(PydanticV1BaseModel):
|
|
||||||
field1: str
|
|
||||||
field2: int
|
|
||||||
alias_field: int = Field(alias="aliased_field")
|
|
||||||
|
|
||||||
result = get_pydantic_field_names(PydanticV1Model)
|
|
||||||
expected = {"field1", "field2", "aliased_field", "alias_field"}
|
|
||||||
assert result == expected
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Requires pydantic 2")
|
|
||||||
def test_get_pydantic_field_names_v2_in_2() -> None:
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
class PydanticModel(BaseModel):
|
|
||||||
field1: str
|
|
||||||
field2: int
|
|
||||||
alias_field: int = Field(alias="aliased_field")
|
|
||||||
|
|
||||||
result = get_pydantic_field_names(PydanticModel)
|
|
||||||
expected = {"field1", "field2", "aliased_field", "alias_field"}
|
|
||||||
assert result == expected
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not IS_PYDANTIC_V1, reason="Requires pydantic 1")
|
|
||||||
def test_get_pydantic_field_names_v1() -> None:
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
class PydanticModel(BaseModel):
|
class PydanticModel(BaseModel):
|
||||||
field1: str
|
field1: str
|
||||||
field2: int
|
field2: int
|
||||||
|
Loading…
Reference in New Issue
Block a user