Compare commits

...

99 Commits

Author SHA1 Message Date
Eugene Yurtsev
41b17ee262 REMOVE CALLBACKS IN BASE CHAT MODEL 2024-08-16 14:08:56 -04:00
Eugene Yurtsev
c240015491 Merge changes 2024-08-16 14:02:12 -04:00
Eugene Yurtsev
658e0f471e fix pydantic.py merge _get_fields -> get_fields 2024-08-16 13:58:44 -04:00
Eugene Yurtsev
33445d4bbd Merge branch 'master' into eugene/merge_pydantic_3_changes 2024-08-16 13:51:44 -04:00
Eugene Yurtsev
ac29cdfd81 Update callback manager 2024-08-09 11:23:24 -04:00
Eugene Yurtsev
9230ee402c Merge changes 2024-08-09 11:19:01 -04:00
Eugene Yurtsev
4b5d57cd1c Merge branch 'master' into eugene/merge_pydantic_3_changes 2024-08-09 11:17:06 -04:00
Eugene Yurtsev
3d561c3e6d model_validate_json 2024-08-08 11:34:23 -04:00
Eugene Yurtsev
09f9d3e972 get_fields 2024-08-08 11:31:57 -04:00
Eugene Yurtsev
c1e6e7d020 get_fields 2024-08-08 11:31:54 -04:00
Eugene Yurtsev
6f79443ab5 get_fields 2024-08-08 11:31:41 -04:00
Eugene Yurtsev
58c4e1ef86 Add name to Runnable Generator 2024-08-08 11:20:08 -04:00
Eugene Yurtsev
de30b04f37 Fix type issue 2024-08-08 11:12:31 -04:00
Eugene Yurtsev
f7a455299e Add more types 2024-08-07 22:12:29 -04:00
Eugene Yurtsev
76043abd47 add unit test 2024-08-07 17:23:54 -04:00
Eugene Yurtsev
61cdb9ccce Fix output schema 2024-08-07 17:23:44 -04:00
Eugene Yurtsev
1ef3fa54fc Fix types 2024-08-07 17:16:28 -04:00
Eugene Yurtsev
3856e3b02a Fix types 2024-08-07 17:14:59 -04:00
Eugene Yurtsev
035f09f20d Fix types 2024-08-07 17:14:47 -04:00
Eugene Yurtsev
63f7a5ab68 Replace __fields__ with model_fields 2024-08-07 17:11:14 -04:00
Eugene Yurtsev
8447d9f6f1 Replace __fields__ with model_fields 2024-08-07 17:10:07 -04:00
Eugene Yurtsev
95db9e9258 Replace __fields__ with model_fields 2024-08-07 17:07:15 -04:00
Eugene Yurtsev
0d1b93774b Resolve linting 2024-08-07 17:06:46 -04:00
Eugene Yurtsev
bcce3a2865 Resolve linting 2024-08-07 17:03:30 -04:00
Eugene Yurtsev
4a478d82bd Resolve linting 2024-08-07 17:02:15 -04:00
Eugene Yurtsev
a0c3657442 Resolve linting 2024-08-07 17:01:29 -04:00
Eugene Yurtsev
72c5c28b4d Resolve linting 2024-08-07 17:01:02 -04:00
Eugene Yurtsev
fe6f2f724b Use model_fields 2024-08-07 16:08:30 -04:00
Eugene Yurtsev
88d347e90c Remove pydantic lint in core 2024-08-07 16:06:40 -04:00
Eugene Yurtsev
741b50d4fd Fix serializer 2024-08-07 15:57:57 -04:00
Eugene Yurtsev
24c6825345 Fix serialization test 2024-08-07 15:57:41 -04:00
Eugene Yurtsev
32824aa55c Handle lint 2024-08-07 15:47:56 -04:00
Eugene Yurtsev
f6924653ea Handle lint 2024-08-07 15:47:51 -04:00
Eugene Yurtsev
66e8594b89 Handle lint 2024-08-07 15:46:37 -04:00
Eugene Yurtsev
3b9f061eac Handle lint 2024-08-07 15:45:21 -04:00
Eugene Yurtsev
76b6ee290d Replace __fields__ with model_fields 2024-08-07 15:44:31 -04:00
Eugene Yurtsev
22957311fe Add more tests for serializable 2024-08-07 15:40:59 -04:00
Eugene Yurtsev
f9df75c8cc Add more tests for serializable 2024-08-07 15:37:21 -04:00
Eugene Yurtsev
ece0ab8539 Add more tests for serializable 2024-08-07 15:18:16 -04:00
Eugene Yurtsev
4ddd9e5f23 lint 2024-08-07 15:05:52 -04:00
Eugene Yurtsev
f8e95e5735 lint 2024-08-07 15:04:02 -04:00
Eugene Yurtsev
6515b2f77b Linting fixes 2024-08-07 15:03:31 -04:00
Eugene Yurtsev
63fde4f095 Linting fixes 2024-08-07 13:59:22 -04:00
Eugene Yurtsev
d9bb9125c1 Linting fixes 2024-08-07 13:55:56 -04:00
Eugene Yurtsev
384d9f59a3 Linting fixes 2024-08-07 13:55:38 -04:00
Eugene Yurtsev
fc0fa7e8f0 Add missing import 2024-08-07 13:50:14 -04:00
Eugene Yurtsev
a1054d06ca Add missing import 2024-08-07 13:48:43 -04:00
Eugene Yurtsev
c2570a7a7c lint 2024-08-07 13:47:43 -04:00
Eugene Yurtsev
97f4128bfd Add missing imports 2024-08-07 13:47:26 -04:00
Eugene Yurtsev
2434dc8f92 update snapshots 2024-08-07 13:46:18 -04:00
Eugene Yurtsev
123d61a888 Add missing imports 2024-08-07 13:43:44 -04:00
Eugene Yurtsev
53f6f4a0c0 Mark explicitly with # pydantic: ignore 2024-08-07 13:41:17 -04:00
Eugene Yurtsev
550bef230a Merge branch 'master' into eugene/merge_pydantic_3_changes 2024-08-07 13:28:46 -04:00
Eugene Yurtsev
5a998d36b2 Convert to v1 model for now 2024-08-07 12:09:42 -04:00
Eugene Yurtsev
72cd199efc Fix create_subset_model_v1 2024-08-07 11:58:10 -04:00
Eugene Yurtsev
a1d993deb1 Remove deprecated comment 2024-08-07 11:54:21 -04:00
Eugene Yurtsev
e546e21d53 Update unit test for pydantic 2 2024-08-07 11:52:28 -04:00
Eugene Yurtsev
26d6426156 Fix extra space in repr 2024-08-07 11:48:11 -04:00
Eugene Yurtsev
8dffedebd6 Add Skip Validation() 2024-08-07 11:38:28 -04:00
Eugene Yurtsev
60adf8d6e4 Handle is_injected_arg_type 2024-08-07 11:36:56 -04:00
Eugene Yurtsev
d13a1ad5f5 Use _AnyIDDocument 2024-08-07 11:27:35 -04:00
Eugene Yurtsev
1e5f8a494a Add SkipValidation() 2024-08-07 11:25:21 -04:00
Eugene Yurtsev
5216131769 Fixed something? 2024-08-07 11:18:52 -04:00
Eugene Yurtsev
8bdaf858b8 Use is_basemodel_instance 2024-08-07 11:03:53 -04:00
Eugene Yurtsev
c37a0ca672 Use is_basemodel_subclass 2024-08-07 11:03:35 -04:00
Eugene Yurtsev
266cd15511 ADd Skip Validation 2024-08-07 10:51:43 -04:00
Eugene Yurtsev
9debf8144e ADd Skip Validation 2024-08-07 10:51:02 -04:00
Eugene Yurtsev
78ce0ed337 Fix broken type 2024-08-07 10:23:55 -04:00
Eugene Yurtsev
4aa1932bea update 2024-08-07 09:52:33 -04:00
Eugene Yurtsev
b658295b97 update 2024-08-07 09:40:29 -04:00
Eugene Yurtsev
8c59b6a026 Merge fix 2024-08-07 09:32:25 -04:00
Eugene Yurtsev
e35b43a7a7 Fix ConfigDict to be populate by name 2024-08-07 09:16:23 -04:00
Eugene Yurtsev
7288d914a8 Add missing model rebuild and optional 2024-08-07 09:14:06 -04:00
Eugene Yurtsev
1b487e261a add missing pydantic import 2024-08-07 09:04:36 -04:00
Eugene Yurtsev
3934663db9 Merge branch 'master' into eugene/merge_pydantic_3_changes 2024-08-07 08:59:28 -04:00
Eugene Yurtsev
fb639cb49c lint 2024-08-06 22:02:31 -04:00
Eugene Yurtsev
1856387e9e Add missing imports load and dumpd 2024-08-06 17:10:42 -04:00
Eugene Yurtsev
a5ad775a90 Add Optional import 2024-08-06 17:10:18 -04:00
Eugene Yurtsev
a321401683 Update pydantic utility 2024-08-06 16:54:55 -04:00
Eugene Yurtsev
8839220a00 Restore more missing stuff 2024-08-06 16:10:59 -04:00
Eugene Yurtsev
e6b2ca4da3 x 2024-08-06 16:08:06 -04:00
Eugene Yurtsev
d0c52d1dec x 2024-08-06 16:06:44 -04:00
Eugene Yurtsev
a5fa6d1c43 x 2024-08-06 16:05:43 -04:00
Eugene Yurtsev
7f79bd6e04 x 2024-08-06 16:04:14 -04:00
Eugene Yurtsev
339985e39e merge more 2024-08-06 15:59:53 -04:00
Eugene Yurtsev
f4ecd749d5 x 2024-08-06 15:58:55 -04:00
Eugene Yurtsev
cb61c6b4bf Merge branch 'master' into eugene/merge_pydantic_3_changes 2024-08-06 15:57:37 -04:00
Eugene Yurtsev
b42c2c6cd6 Update to master 2024-08-06 15:57:35 -04:00
Eugene Yurtsev
da6633bf0d update 2024-08-06 13:08:53 -04:00
Eugene Yurtsev
0193d18bec update 2024-08-06 13:04:17 -04:00
Eugene Yurtsev
0a82192e36 update forward refs 2024-08-06 12:41:52 -04:00
Eugene Yurtsev
202f6fef95 update 2024-08-06 12:39:00 -04:00
Eugene Yurtsev
c49416e908 fix typo 2024-08-06 12:35:05 -04:00
Eugene Yurtsev
ec93ea6240 update 2024-08-06 12:33:43 -04:00
Eugene Yurtsev
add20dc9a8 update 2024-08-06 12:30:33 -04:00
Eugene Yurtsev
7799474746 MANUAL: May need to revert 2024-08-06 11:47:27 -04:00
Eugene Yurtsev
d98c1f115f update 2024-08-06 11:46:39 -04:00
Eugene Yurtsev
d97f70def4 Update 2024-08-06 11:43:25 -04:00
Eugene Yurtsev
609c6b0963 Update 2024-08-06 11:40:43 -04:00
47 changed files with 6836 additions and 4359 deletions

View File

@@ -39,7 +39,6 @@ lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test
lint lint_diff lint_package lint_tests:
./scripts/check_pydantic.sh .
./scripts/lint_imports.sh
poetry run ruff check .
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff

View File

@@ -4,10 +4,12 @@ import contextlib
import mimetypes
from io import BufferedReader, BytesIO
from pathlib import PurePath
from typing import Any, Generator, List, Literal, Mapping, Optional, Union, cast
from typing import Any, Dict, Generator, List, Literal, Optional, Union, cast
from pydantic import ConfigDict, Field, root_validator
from langchain_core.load.serializable import Serializable
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.utils.pydantic import v1_repr
PathLike = Union[str, PurePath]
@@ -110,9 +112,10 @@ class Blob(BaseMedia):
path: Optional[PathLike] = None
"""Location where the original content was found."""
class Config:
arbitrary_types_allowed = True
frozen = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
frozen=True,
)
@property
def source(self) -> Optional[str]:
@@ -128,7 +131,7 @@ class Blob(BaseMedia):
return str(self.path) if self.path else None
@root_validator(pre=True)
def check_blob_is_valid(cls, values: Mapping[str, Any]) -> Mapping[str, Any]:
def check_blob_is_valid(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Verify that either data or path is provided."""
if "data" not in values and "path" not in values:
raise ValueError("Either data or path must be provided")
@@ -293,3 +296,7 @@ class Document(BaseMedia):
return f"page_content='{self.page_content}' metadata={self.metadata}"
else:
return f"page_content='{self.page_content}'"
def __repr__(self) -> str:
# TODO(0.3): Remove this override after confirming unit tests!
return v1_repr(self)
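The Blob hunk above shows the core pydantic 2 change repeated throughout this branch: the nested `class Config` is replaced by a `model_config = ConfigDict(...)` assignment. A minimal standalone sketch of that pattern, using an illustrative model name rather than the real Blob class:

    from typing import Optional
    from pydantic import BaseModel, ConfigDict

    class BlobLike(BaseModel):
        """Illustrative stand-in for the Blob model edited above."""

        data: Optional[bytes] = None

        # pydantic 2 replacement for the removed nested `class Config`
        model_config = ConfigDict(
            arbitrary_types_allowed=True,
            frozen=True,
        )

    blob = BlobLike(data=b"hello")
    # frozen=True keeps instances immutable and hashable, as with the v1 Config
    hash(blob)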

View File

@@ -95,7 +95,11 @@ class BaseLanguageModel(
Caching is not currently supported for streaming methods of models.
"""
verbose: bool = Field(default_factory=_get_verbosity)
# Repr = False is consistent with pydantic 1 if verbose = False
# We can relax this for pydantic 2?
# TODO(Team): decide what to do here.
# Modified just to get unit tests to pass.
verbose: bool = Field(default_factory=_get_verbosity, exclude=True, repr=False)
"""Whether to print out response text."""
callbacks: Callbacks = Field(default=None, exclude=True)
"""Callbacks to add to the run trace."""
@@ -108,6 +112,9 @@ class BaseLanguageModel(
)
"""Optional encoder to use for counting tokens."""
class Config:
arbitrary_types_allowed = True
@validator("verbose", pre=True, always=True, allow_reuse=True)
def set_verbose(cls, verbose: Optional[bool]) -> bool:
"""If verbose is None, set it.

View File

@@ -9,6 +9,7 @@ from abc import ABC, abstractmethod
from operator import itemgetter
from typing import (
TYPE_CHECKING,
Annotated,
Any,
AsyncIterator,
Callable,
@@ -23,6 +24,13 @@ from typing import (
cast,
)
from pydantic import (
BaseModel,
ConfigDict,
Field,
SkipValidation,
root_validator,
)
from typing_extensions import TypedDict
from langchain_core._api import deprecated
@@ -55,11 +63,6 @@ from langchain_core.outputs import (
RunInfo,
)
from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
from langchain_core.pydantic_v1 import (
BaseModel,
Field,
root_validator,
)
from langchain_core.rate_limiters import BaseRateLimiter
from langchain_core.runnables import RunnableMap, RunnablePassthrough
from langchain_core.runnables.config import ensure_config, run_in_executor
@@ -208,16 +211,6 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
""" # noqa: E501
callback_manager: Optional[BaseCallbackManager] = deprecated(
name="callback_manager", since="0.1.7", removal="1.0", alternative="callbacks"
)(
Field(
default=None,
exclude=True,
description="Callback manager to add to the run trace.",
)
)
rate_limiter: Optional[BaseRateLimiter] = Field(default=None, exclude=True)
"An optional rate limiter to use for limiting the number of requests."
@@ -254,8 +247,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
values["callbacks"] = values.pop("callback_manager", None)
return values
class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
# --- Runnable methods ---

View File

@@ -10,9 +10,10 @@ from typing import (
cast,
)
from pydantic import BaseModel, ConfigDict # pydantic: ignore
from typing_extensions import NotRequired
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.utils.pydantic import v1_repr
class BaseSerialized(TypedDict):
@@ -80,7 +81,7 @@ def try_neq_default(value: Any, key: str, model: BaseModel) -> bool:
Exception: If the key is not in the model.
"""
try:
return model.__fields__[key].get_default() != value
return model.model_fields[key].get_default() != value
except Exception:
return True
@@ -161,16 +162,25 @@ class Serializable(BaseModel, ABC):
For example, for the class `langchain.llms.openai.OpenAI`, the id is
["langchain", "llms", "openai", "OpenAI"].
"""
return [*cls.get_lc_namespace(), cls.__name__]
# Pydantic generics change the class name. So we need to do the following
if (
"origin" in cls.__pydantic_generic_metadata__
and cls.__pydantic_generic_metadata__["origin"] is not None
):
original_name = cls.__pydantic_generic_metadata__["origin"].__name__
else:
original_name = cls.__name__
return [*cls.get_lc_namespace(), original_name]
class Config:
extra = "ignore"
model_config = ConfigDict(
extra="ignore",
)
def __repr_args__(self) -> Any:
return [
(k, v)
for k, v in super().__repr_args__()
if (k not in self.__fields__ or try_neq_default(v, k, self))
if (k not in self.model_fields or try_neq_default(v, k, self))
]
def to_json(self) -> Union[SerializedConstructor, SerializedNotImplemented]:
@@ -184,12 +194,15 @@ class Serializable(BaseModel, ABC):
secrets = dict()
# Get latest values for kwargs if there is an attribute with same name
lc_kwargs = {
k: getattr(self, k, v)
for k, v in self
if not (self.__exclude_fields__ or {}).get(k, False) # type: ignore
and _is_field_useful(self, k, v)
}
lc_kwargs = {}
for k, v in self:
if not _is_field_useful(self, k, v):
continue
# Do nothing if the field is excluded
if k in self.model_fields and self.model_fields[k].exclude:
continue
lc_kwargs[k] = getattr(self, k, v)
# Merge the lc_secrets and lc_attributes from every class in the MRO
for cls in [None, *self.__class__.mro()]:
@@ -221,8 +234,10 @@ class Serializable(BaseModel, ABC):
# that are not present in the fields.
for key in list(secrets):
value = secrets[key]
if key in this.__fields__:
secrets[this.__fields__[key].alias] = value
if key in this.model_fields:
alias = this.model_fields[key].alias
if alias is not None:
secrets[alias] = value
lc_kwargs.update(this.lc_attributes)
# include all secrets, even if not specified in kwargs
@@ -244,6 +259,10 @@ class Serializable(BaseModel, ABC):
def to_json_not_implemented(self) -> SerializedNotImplemented:
return to_json_not_implemented(self)
def __repr__(self) -> str:
# TODO(0.3): Remove this override after confirming unit tests!
return v1_repr(self)
def _is_field_useful(inst: Serializable, key: str, value: Any) -> bool:
"""Check if a field is useful as a constructor argument.
@@ -259,10 +278,26 @@ def _is_field_useful(inst: Serializable, key: str, value: Any) -> bool:
If the field is not required and the value is None, it is useful if the
default value is different from the value.
"""
field = inst.__fields__.get(key)
field = inst.model_fields.get(key)
if not field:
return False
return field.required is True or value or field.get_default() != value
if field.is_required():
return True
if value:
return True
# Value is still falsy here!
if field.default_factory is dict and isinstance(value, dict):
return False
# Value is still falsy here!
if field.default_factory is list and isinstance(value, list):
return False
# If value is falsy and does not match the default
return field.get_default() != value
def _replace_secrets(
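The serializable.py hunk above replaces `__fields__` with `model_fields` and `field.required` with `field.is_required()`, and special-cases `default_factory is dict`/`list`. A small sketch of how those pydantic 2 field accessors behave, using a made-up model:

    from pydantic import BaseModel, Field

    class Example(BaseModel):
        required_value: int
        tags: list = Field(default_factory=list)

    # Per-field metadata lives in the `model_fields` dict of FieldInfo objects.
    assert Example.model_fields["required_value"].is_required()
    assert not Example.model_fields["tags"].is_required()
    assert Example.model_fields["tags"].default_factory is list
    # get_default() returns the field default; the factory is only invoked on request.
    assert Example.model_fields["tags"].get_default(call_default_factory=True) == []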

View File

@@ -7,6 +7,7 @@ from langchain_core.pydantic_v1 import Extra, Field
from langchain_core.utils import get_bolded_text
from langchain_core.utils._merge import merge_dicts, merge_lists
from langchain_core.utils.interactive_env import is_interactive_env
from langchain_core.utils.pydantic import v1_repr
if TYPE_CHECKING:
from langchain_core.prompts.chat import ChatPromptTemplate
@@ -108,6 +109,10 @@ class BaseMessage(Serializable):
def pretty_print(self) -> None:
print(self.pretty_repr(html=is_interactive_env())) # noqa: T201
def __repr__(self) -> str:
# TODO(0.3): Remove this override after confirming unit tests!
return v1_repr(self)
def merge_content(
first_content: Union[str, List[Union[str, Dict]]],

View File

@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from typing_extensions import NotRequired, TypedDict
from langchain_core.messages.base import BaseMessage, BaseMessageChunk, merge_content
from langchain_core.pydantic_v1 import Field
from langchain_core.utils._merge import merge_dicts, merge_obj
@@ -70,6 +71,11 @@ class ToolMessage(BaseMessage):
.. versionadded:: 0.2.24
"""
additional_kwargs: dict = Field(default_factory=dict, repr=False)
"""Currently inherited from BaseMessage, but not used."""
response_metadata: dict = Field(default_factory=dict, repr=False)
"""Currently inherited from BaseMessage, but not used."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object.

View File

@@ -13,8 +13,6 @@ from typing import (
Union,
)
from typing_extensions import get_args
from langchain_core.language_models import LanguageModelOutput
from langchain_core.messages import AnyMessage, BaseMessage
from langchain_core.outputs import ChatGeneration, Generation
@@ -166,10 +164,11 @@ class BaseOutputParser(
Raises:
TypeError: If the class doesn't have an inferable OutputType.
"""
for cls in self.__class__.__orig_bases__: # type: ignore[attr-defined]
type_args = get_args(cls)
if type_args and len(type_args) == 1:
return type_args[0]
for base in self.__class__.mro():
if hasattr(base, "__pydantic_generic_metadata__"):
metadata = base.__pydantic_generic_metadata__
if "args" in metadata and len(metadata["args"]) > 0:
return metadata["args"][0]
raise TypeError(
f"Runnable {self.__class__.__name__} doesn't have an inferable OutputType. "

View File

@@ -2,10 +2,11 @@ from __future__ import annotations
import json
from json import JSONDecodeError
from typing import Any, List, Optional, Type, TypeVar, Union
from typing import Annotated, Any, List, Optional, Type, TypeVar, Union
import jsonpatch # type: ignore[import]
import pydantic # pydantic: ignore
from pydantic import SkipValidation # pydantic: ignore
from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers.format_instructions import JSON_FORMAT_INSTRUCTIONS
@@ -40,7 +41,7 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
describing the difference between the previous and the current object.
"""
pydantic_object: Optional[Type[TBaseModel]] = None # type: ignore
pydantic_object: Annotated[Optional[Type[TBaseModel]], SkipValidation()] = None # type: ignore
"""The Pydantic object to use for validation.
If None, no validation is performed."""
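The `pydantic_object` field above is wrapped in `Annotated[..., SkipValidation()]`, a pydantic 2 marker that stores the assigned value without validating it; the parsers need this because the field may hold classes (including pydantic.v1 models) that pydantic 2 validation would otherwise reject. A minimal illustration with a hypothetical schema class:

    from typing import Annotated, Optional, Type
    from pydantic import BaseModel, SkipValidation

    class AnswerSchema(BaseModel):
        answer: str

    class ParserLike(BaseModel):
        # SkipValidation: accept whatever is assigned without running validation on it.
        pydantic_object: Annotated[Optional[Type[AnswerSchema]], SkipValidation()] = None

    parser = ParserLike(pydantic_object=AnswerSchema)
    assert parser.pydantic_object is AnswerSchema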

View File

@@ -4,6 +4,7 @@ import re
from abc import abstractmethod
from collections import deque
from typing import AsyncIterator, Deque, Iterator, List, TypeVar, Union
from typing import Optional as Optional
from langchain_core.messages import BaseMessage
from langchain_core.output_parsers.transform import BaseTransformOutputParser
@@ -122,6 +123,9 @@ class ListOutputParser(BaseTransformOutputParser[List[str]]):
yield [part]
ListOutputParser.update_forward_refs()
class CommaSeparatedListOutputParser(ListOutputParser):
"""Parse the output of an LLM call to a comma-separated list."""

View File

@@ -3,6 +3,7 @@ import json
from typing import Any, Dict, List, Optional, Type, Union
import jsonpatch # type: ignore[import]
from pydantic import BaseModel, root_validator
from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import (
@@ -11,7 +12,6 @@ from langchain_core.output_parsers import (
)
from langchain_core.output_parsers.json import parse_partial_json
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.pydantic_v1 import BaseModel, root_validator
class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
@@ -263,11 +263,17 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
"""
_result = super().parse_result(result)
if self.args_only:
pydantic_args = self.pydantic_schema.parse_raw(_result) # type: ignore
if hasattr(self.pydantic_schema, "model_validate_json"):
pydantic_args = self.pydantic_schema.model_validate_json(_result)
else:
pydantic_args = self.pydantic_schema.parse_raw(_result) # type: ignore
else:
fn_name = _result["name"]
_args = _result["arguments"]
pydantic_args = self.pydantic_schema[fn_name].parse_raw(_args) # type: ignore
if hasattr(self.pydantic_schema, "model_validate_json"):
pydantic_args = self.pydantic_schema[fn_name].model_validate_json(_args)
else:
pydantic_args = self.pydantic_schema[fn_name].parse_raw(_args) # type: ignore
return pydantic_args
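The parser above now branches on `hasattr(..., "model_validate_json")` so it works whether the configured schema is a pydantic 2 model (`model_validate_json`) or a v1-style model (`parse_raw`). A hedged sketch of that compatibility check as a standalone helper (the helper name is made up):

    from typing import Any

    def validate_json_compat(schema: type, raw: str) -> Any:
        """Parse raw JSON with whichever pydantic API the model class supports."""
        if hasattr(schema, "model_validate_json"):
            # pydantic 2 models expose model_validate_json
            return schema.model_validate_json(raw)
        # pydantic v1 / pydantic.v1 models fall back to parse_raw
        return schema.parse_raw(raw)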

View File

@@ -1,7 +1,9 @@
import copy
import json
from json import JSONDecodeError
from typing import Any, Dict, List, Optional
from typing import Annotated, Any, Dict, List, Optional
from pydantic import SkipValidation, ValidationError # pydantic: ignore
from langchain_core.exceptions import OutputParserException
from langchain_core.messages import AIMessage, InvalidToolCall
@@ -13,7 +15,6 @@ from langchain_core.messages.tool import (
)
from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.pydantic_v1 import ValidationError
from langchain_core.utils.json import parse_partial_json
from langchain_core.utils.pydantic import TypeBaseModel
@@ -256,7 +257,7 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
class PydanticToolsParser(JsonOutputToolsParser):
"""Parse tools from OpenAI response."""
tools: List[TypeBaseModel]
tools: Annotated[List[TypeBaseModel], SkipValidation()]
"""The tools to parse."""
# TODO: Support more granular streaming of objects. Currently only streams once all

View File

@@ -1,7 +1,9 @@
import json
from typing import Generic, List, Type
from typing import Annotated, Generic, List, Type
from typing import Optional as Optional
import pydantic # pydantic: ignore
from pydantic import SkipValidation # pydantic: ignore
from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import JsonOutputParser
@@ -16,7 +18,7 @@ from langchain_core.utils.pydantic import (
class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
"""Parse an output using a pydantic model."""
pydantic_object: Type[TBaseModel] # type: ignore
pydantic_object: Annotated[Type[TBaseModel], SkipValidation()] # type: ignore
"""The pydantic model to parse."""
def _parse_obj(self, obj: dict) -> TBaseModel:
@@ -106,6 +108,9 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
return self.pydantic_object
PydanticOutputParser.model_rebuild()
_PYDANTIC_FORMAT_INSTRUCTIONS = """The output should be formatted as a JSON instance that conforms to the JSON schema below.
As an example, for the schema {{"properties": {{"foo": {{"title": "Foo", "description": "a list of strings", "type": "array", "items": {{"type": "string"}}}}}}, "required": ["foo"]}}

View File

@@ -1,4 +1,5 @@
from typing import List
from typing import Optional as Optional
from langchain_core.output_parsers.transform import BaseTransformOutputParser
@@ -24,3 +25,6 @@ class StrOutputParser(BaseTransformOutputParser[str]):
def parse(self, text: str) -> str:
"""Returns the input text with no changes."""
return text
StrOutputParser.model_rebuild()
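`StrOutputParser.model_rebuild()` above (and the `update_forward_refs()` call kept for the list parser) forces pydantic 2 to re-resolve deferred forward references after the module finishes importing. A small self-contained illustration of when `model_rebuild()` is needed:

    from typing import Optional
    from pydantic import BaseModel

    class Holder(BaseModel):
        # "LateType" is not defined yet, so schema building is deferred.
        value: Optional["LateType"] = None

    class LateType(BaseModel):
        name: str

    # Now that LateType exists, rebuild Holder so the forward reference resolves.
    Holder.model_rebuild()
    Holder(value=LateType(name="x"))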

View File

@@ -1,11 +1,13 @@
from __future__ import annotations
from copy import deepcopy
from typing import List, Optional
from typing import List, Optional, Union
from langchain_core.outputs.generation import Generation
from pydantic import BaseModel
from langchain_core.outputs.chat_generation import ChatGeneration, ChatGenerationChunk
from langchain_core.outputs.generation import Generation, GenerationChunk
from langchain_core.outputs.run_info import RunInfo
from langchain_core.pydantic_v1 import BaseModel
class LLMResult(BaseModel):
@@ -16,7 +18,9 @@ class LLMResult(BaseModel):
wants to return.
"""
generations: List[List[Generation]]
generations: List[
List[Union[Generation, ChatGeneration, GenerationChunk, ChatGenerationChunk]]
]
"""Generated outputs.
The first dimension of the list represents completions for different input

View File

@@ -5,6 +5,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from pathlib import Path
from typing import (
Annotated,
Any,
Dict,
List,
@@ -21,6 +22,13 @@ from typing import (
overload,
)
from pydantic import (
Field,
PositiveInt,
SkipValidation,
root_validator,
)
from langchain_core._api import deprecated
from langchain_core.load import Serializable
from langchain_core.messages import (
@@ -38,7 +46,6 @@ from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.image import ImagePromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.string import StringPromptTemplate, get_template_variables
from langchain_core.pydantic_v1 import Field, PositiveInt, root_validator
from langchain_core.utils import get_colored_text
from langchain_core.utils.interactive_env import is_interactive_env
@@ -207,8 +214,14 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
"""Get the namespace of the langchain object."""
return ["langchain", "prompts", "chat"]
def __init__(self, variable_name: str, *, optional: bool = False, **kwargs: Any):
super().__init__(variable_name=variable_name, optional=optional, **kwargs)
def __init__(
self, variable_name: str, *, optional: bool = False, **kwargs: Any
) -> None:
# mypy can't detect the init which is defined in the parent class
# b/c these are BaseModel classes.
super().__init__( # type: ignore
variable_name=variable_name, optional=optional, **kwargs
)
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format messages from kwargs.
@@ -922,7 +935,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
""" # noqa: E501
messages: List[MessageLike]
messages: Annotated[List[MessageLike], SkipValidation]
"""List of messages consisting of either message prompt templates or messages."""
validate_template: bool = False
"""Whether or not to try validating the template."""

View File

@@ -1,4 +1,5 @@
from typing import Any, Dict, List, Tuple
from typing import Optional as Optional
from langchain_core.prompt_values import PromptValue
from langchain_core.prompts.base import BasePromptTemplate
@@ -106,3 +107,6 @@ class PipelinePromptTemplate(BasePromptTemplate):
@property
def _prompt_type(self) -> str:
raise ValueError
PipelinePromptTemplate.update_forward_refs()

View File

@@ -35,7 +35,8 @@ from typing import (
overload,
)
from typing_extensions import Literal, get_args
from pydantic import BaseModel, ConfigDict, Field, RootModel
from typing_extensions import Literal, get_args, get_type_hints
from langchain_core._api import beta_decorator
from langchain_core.load.dump import dumpd
@@ -44,7 +45,6 @@ from langchain_core.load.serializable import (
SerializedConstructor,
SerializedNotImplemented,
)
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables.config import (
RunnableConfig,
_set_config_context,
@@ -83,7 +83,6 @@ from langchain_core.runnables.utils import (
)
from langchain_core.utils.aiter import aclosing, atee, py_anext
from langchain_core.utils.iter import safetee
from langchain_core.utils.pydantic import is_basemodel_subclass
if TYPE_CHECKING:
from langchain_core.callbacks.manager import (
@@ -236,25 +235,56 @@ class Runnable(Generic[Input, Output], ABC):
For a UI (and much more) checkout LangSmith: https://docs.smith.langchain.com/
""" # noqa: E501
name: Optional[str] = None
name: Optional[str]
"""The name of the Runnable. Used for debugging and tracing."""
def get_name(
self, suffix: Optional[str] = None, *, name: Optional[str] = None
) -> str:
"""Get the name of the Runnable."""
name = name or self.name or self.__class__.__name__
if suffix:
if name[0].isupper():
return name + suffix.title()
else:
return name + "_" + suffix.lower()
if name:
name_ = name
elif hasattr(self, "name") and self.name:
name_ = self.name
else:
return name
# Here we handle a case where the runnable subclass is also a pydantic
# model.
cls = self.__class__
# Then it's a pydantic sub-class, and we have to check
# whether it's a generic, and if so recover the original name.
if (
hasattr(
cls,
"__pydantic_generic_metadata__",
)
and "origin" in cls.__pydantic_generic_metadata__
and cls.__pydantic_generic_metadata__["origin"] is not None
):
name_ = cls.__pydantic_generic_metadata__["origin"].__name__
else:
name_ = cls.__name__
if suffix:
if name_[0].isupper():
return name_ + suffix.title()
else:
return name_ + "_" + suffix.lower()
else:
return name_
@property
def InputType(self) -> Type[Input]:
"""The type of input this Runnable accepts specified as a type annotation."""
# First loop through bases -- this will help with generic
# pydantic models.
for base in self.__class__.mro():
if hasattr(base, "__pydantic_generic_metadata__"):
metadata = base.__pydantic_generic_metadata__
if "args" in metadata and len(metadata["args"]) == 2:
return metadata["args"][0]
# then loop through __orig_bases__ -- this will help Runnables that do not
# inherit from pydantic
for cls in self.__class__.__orig_bases__: # type: ignore[attr-defined]
type_args = get_args(cls)
if type_args and len(type_args) == 2:
@@ -268,6 +298,14 @@ class Runnable(Generic[Input, Output], ABC):
@property
def OutputType(self) -> Type[Output]:
"""The type of output this Runnable produces specified as a type annotation."""
# First loop through bases -- this will help with generic
# pydantic models.
for base in self.__class__.mro():
if hasattr(base, "__pydantic_generic_metadata__"):
metadata = base.__pydantic_generic_metadata__
if "args" in metadata and len(metadata["args"]) == 2:
return metadata["args"][1]
for cls in self.__class__.__orig_bases__: # type: ignore[attr-defined]
type_args = get_args(cls)
if type_args and len(type_args) == 2:
@@ -302,12 +340,12 @@ class Runnable(Generic[Input, Output], ABC):
"""
root_type = self.InputType
if inspect.isclass(root_type) and is_basemodel_subclass(root_type):
if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
return root_type
return create_model(
self.get_name("Input"),
__root__=(root_type, None),
__root__=root_type,
)
@property
@@ -334,12 +372,12 @@ class Runnable(Generic[Input, Output], ABC):
"""
root_type = self.OutputType
if inspect.isclass(root_type) and is_basemodel_subclass(root_type):
if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
return root_type
return create_model(
self.get_name("Output"),
__root__=(root_type, None),
__root__=root_type,
)
@property
@@ -381,15 +419,19 @@ class Runnable(Generic[Input, Output], ABC):
else None
)
return create_model( # type: ignore[call-overload]
self.get_name("Config"),
# May need to create a typed dict instead to implement NotRequired!
all_fields = {
**({"configurable": (configurable, None)} if configurable else {}),
**{
field_name: (field_type, None)
for field_name, field_type in RunnableConfig.__annotations__.items()
for field_name, field_type in get_type_hints(RunnableConfig).items()
if field_name in [i for i in include if i != "configurable"]
},
}
model = create_model( # type: ignore[call-overload]
self.get_name("Config"), **all_fields
)
return model
def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph:
"""Return a graph representation of this Runnable."""
@@ -579,7 +621,7 @@ class Runnable(Generic[Input, Output], ABC):
"""
from langchain_core.runnables.passthrough import RunnableAssign
return self | RunnableAssign(RunnableParallel(kwargs))
return self | RunnableAssign(RunnableParallel[Dict[str, Any]](kwargs))
""" --- Public API --- """
@@ -2129,7 +2171,6 @@ class Runnable(Generic[Input, Output], ABC):
name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None),
)
iterator_ = None
try:
child_config = patch_config(config, callbacks=run_manager.get_child())
if accepts_config(transformer):
@@ -2314,7 +2355,6 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]):
"""Runnable that can be serialized to JSON."""
name: Optional[str] = None
"""The name of the Runnable. Used for debugging and tracing."""
def to_json(self) -> Union[SerializedConstructor, SerializedNotImplemented]:
"""Serialize the Runnable to JSON.
@@ -2369,10 +2409,10 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]):
from langchain_core.runnables.configurable import RunnableConfigurableFields
for key in kwargs:
if key not in self.__fields__:
if key not in self.model_fields:
raise ValueError(
f"Configuration key {key} not found in {self}: "
f"available keys are {self.__fields__.keys()}"
f"available keys are {self.model_fields.keys()}"
)
return RunnableConfigurableFields(default=self, fields=kwargs)
@@ -2447,13 +2487,13 @@ def _seq_input_schema(
return first.get_input_schema(config)
elif isinstance(first, RunnableAssign):
next_input_schema = _seq_input_schema(steps[1:], config)
if not next_input_schema.__custom_root_type__:
if not issubclass(next_input_schema, RootModel):
# it's a dict as expected
return create_model( # type: ignore[call-overload]
"RunnableSequenceInput",
**{
k: (v.annotation, v.default)
for k, v in next_input_schema.__fields__.items()
for k, v in next_input_schema.model_fields.items()
if k not in first.mapper.steps__
},
)
@@ -2474,36 +2514,36 @@ def _seq_output_schema(
elif isinstance(last, RunnableAssign):
mapper_output_schema = last.mapper.get_output_schema(config)
prev_output_schema = _seq_output_schema(steps[:-1], config)
if not prev_output_schema.__custom_root_type__:
if not issubclass(prev_output_schema, RootModel):
# it's a dict as expected
return create_model( # type: ignore[call-overload]
"RunnableSequenceOutput",
**{
**{
k: (v.annotation, v.default)
for k, v in prev_output_schema.__fields__.items()
for k, v in prev_output_schema.model_fields.items()
},
**{
k: (v.annotation, v.default)
for k, v in mapper_output_schema.__fields__.items()
for k, v in mapper_output_schema.model_fields.items()
},
},
)
elif isinstance(last, RunnablePick):
prev_output_schema = _seq_output_schema(steps[:-1], config)
if not prev_output_schema.__custom_root_type__:
if not issubclass(prev_output_schema, RootModel):
# it's a dict as expected
if isinstance(last.keys, list):
return create_model( # type: ignore[call-overload]
"RunnableSequenceOutput",
**{
k: (v.annotation, v.default)
for k, v in prev_output_schema.__fields__.items()
for k, v in prev_output_schema.model_fields.items()
if k in last.keys
},
)
else:
field = prev_output_schema.__fields__[last.keys]
field = prev_output_schema.model_fields[last.keys]
return create_model( # type: ignore[call-overload]
"RunnableSequenceOutput",
__root__=(field.annotation, field.default),
@@ -2665,8 +2705,9 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
"""
return True
class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
@property
def InputType(self) -> Type[Input]:
@@ -3403,8 +3444,9 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"]
class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
def get_name(
self, suffix: Optional[str] = None, *, name: Optional[str] = None
@@ -3451,7 +3493,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
**{
k: (v.annotation, v.default)
for step in self.steps__.values()
for k, v in step.get_input_schema(config).__fields__.items()
for k, v in step.get_input_schema(config).model_fields.items()
if k != "__root__"
},
)
@@ -3469,11 +3511,8 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
Returns:
The output schema of the Runnable.
"""
# This is correct, but pydantic typings/mypy don't think so.
return create_model( # type: ignore[call-overload]
self.get_name("Output"),
**{k: (v.OutputType, None) for k, v in self.steps__.items()},
)
fields = {k: (v.OutputType, ...) for k, v in self.steps__.items()}
return create_model(self.get_name("Output"), **fields)
@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
@@ -3883,6 +3922,8 @@ class RunnableGenerator(Runnable[Input, Output]):
atransform: Optional[
Callable[[AsyncIterator[Input]], AsyncIterator[Output]]
] = None,
*,
name: Optional[str] = None,
) -> None:
"""Initialize a RunnableGenerator.
@@ -3910,9 +3951,9 @@ class RunnableGenerator(Runnable[Input, Output]):
)
try:
self.name = func_for_name.__name__
self.name = name or func_for_name.__name__
except AttributeError:
pass
self.name = "RunnableGenerator"
@property
def InputType(self) -> Any:
@@ -4184,15 +4225,13 @@ class RunnableLambda(Runnable[Input, Output]):
if all(
item[0] == "'" and item[-1] == "'" and len(item) > 2 for item in items
):
fields = {item[1:-1]: (Any, ...) for item in items}
# It's a dict, lol
return create_model(
self.get_name("Input"),
**{item[1:-1]: (Any, None) for item in items}, # type: ignore
)
return create_model(self.get_name("Input"), **fields)
else:
return create_model(
self.get_name("Input"),
__root__=(List[Any], None),
__root__=List[Any],
)
if self.InputType != Any:
@@ -4201,7 +4240,7 @@ class RunnableLambda(Runnable[Input, Output]):
if dict_keys := get_function_first_arg_dict_keys(func):
return create_model(
self.get_name("Input"),
**{key: (Any, None) for key in dict_keys}, # type: ignore
**{key: (Any, ...) for key in dict_keys}, # type: ignore
)
return super().get_input_schema(config)
@@ -4730,8 +4769,9 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
bound: Runnable[Input, Output]
class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
@property
def InputType(self) -> Any:
@@ -4758,10 +4798,7 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
schema = self.bound.get_output_schema(config)
return create_model(
self.get_name("Output"),
__root__=(
List[schema], # type: ignore
None,
),
__root__=List[schema], # type: ignore[valid-type]
)
@property
@@ -4981,8 +5018,9 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
The type can be a pydantic model, or a type annotation (e.g., `List[str]`).
"""
class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
def __init__(
self,
@@ -5318,7 +5356,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
yield item
RunnableBindingBase.update_forward_refs(RunnableConfig=RunnableConfig)
RunnableBindingBase.model_rebuild()
class RunnableBinding(RunnableBindingBase[Input, Output]):
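Both `get_name()` above and `Serializable.lc_id()` earlier look through `__pydantic_generic_metadata__["origin"]`, because a parameterized pydantic 2 generic reports a name like `Box[int]` rather than `Box`. A small sketch of that lookup with an illustrative model:

    from typing import Generic, TypeVar
    from pydantic import BaseModel

    T = TypeVar("T")

    class Box(BaseModel, Generic[T]):
        value: T

    def original_name(cls: type) -> str:
        """Return the unparameterized class name for pydantic 2 generics."""
        meta = getattr(cls, "__pydantic_generic_metadata__", {})
        if meta.get("origin") is not None:
            return meta["origin"].__name__
        return cls.__name__

    assert Box[int].__name__ == "Box[int]"
    assert original_name(Box[int]) == "Box"
    assert original_name(Box) == "Box"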

View File

@@ -134,7 +134,17 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
runnable = coerce_to_runnable(runnable)
_branches.append((condition, runnable))
super().__init__(branches=_branches, default=default_) # type: ignore[call-arg]
super().__init__(
branches=_branches,
default=default_,
# Hard-coding a name here because RunnableBranch is a generic
# and with pydantic 2, the class name generated by pydantic would
# include the parameterized type, which is not what we want.
# e.g., we'd get RunnableBranch[Input, Output] instead of RunnableBranch
# for the name. This information is already captured in the
# input and output types.
name="RunnableBranch",
) # type: ignore[call-arg]
class Config:
arbitrary_types_allowed = True

View File

@@ -20,7 +20,8 @@ from typing import (
)
from weakref import WeakValueDictionary
from langchain_core.pydantic_v1 import BaseModel
from pydantic import BaseModel, ConfigDict
from langchain_core.runnables.base import Runnable, RunnableSerializable
from langchain_core.runnables.config import (
RunnableConfig,
@@ -58,8 +59,9 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
config: Optional[RunnableConfig] = None
class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
@classmethod
def is_lc_serializable(cls) -> bool:
@@ -373,28 +375,33 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
Returns:
List[ConfigurableFieldSpec]: The configuration specs.
"""
return get_unique_config_specs(
[
(
# TODO(0.3): This change removes field_info which isn't needed in pydantic 2
config_specs = []
for field_name, spec in self.fields.items():
if isinstance(spec, ConfigurableField):
config_specs.append(
ConfigurableFieldSpec(
id=spec.id,
name=spec.name,
description=spec.description
or self.default.__fields__[field_name].field_info.description,
or self.default.model_fields[field_name].description,
annotation=spec.annotation
or self.default.__fields__[field_name].annotation,
or self.default.model_fields[field_name].annotation,
default=getattr(self.default, field_name),
is_shared=spec.is_shared,
)
if isinstance(spec, ConfigurableField)
else make_options_spec(
spec, self.default.__fields__[field_name].field_info.description
)
else:
config_specs.append(
make_options_spec(
spec, self.default.model_fields[field_name].description
)
)
for field_name, spec in self.fields.items()
]
+ list(self.default.config_specs)
)
config_specs.extend(self.default.config_specs)
return get_unique_config_specs(config_specs)
def configurable_fields(
self, **kwargs: AnyConfigurableField
@@ -436,7 +443,7 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
init_params = {
k: v
for k, v in self.default.__dict__.items()
if k in self.default.__fields__
if k in self.default.model_fields
}
return (
self.default.__class__(**{**init_params, **configurable}),

View File

@@ -23,7 +23,7 @@ from typing import (
from uuid import UUID, uuid4
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.utils.pydantic import is_basemodel_subclass
from langchain_core.utils.pydantic import _IgnoreUnserializable, is_basemodel_subclass
if TYPE_CHECKING:
from langchain_core.runnables.base import Runnable as RunnableType
@@ -235,7 +235,9 @@ def node_data_json(
json = (
{
"type": "schema",
"data": node.data.schema(),
"data": node.data.model_json_schema(
schema_generator=_IgnoreUnserializable
),
}
if with_schemas
else {

View File

@@ -372,28 +372,25 @@ class RunnableWithMessageHistory(RunnableBindingBase):
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
super_schema = super().get_input_schema(config)
if super_schema.__custom_root_type__ or not super_schema.schema().get(
"properties"
):
from langchain_core.messages import BaseMessage
# TODO(0.3): Verify that this change was correct
# Not enough tests and unclear on why the previous implementation was
# necessary.
from langchain_core.messages import BaseMessage
fields: Dict = {}
if self.input_messages_key and self.history_messages_key:
fields[self.input_messages_key] = (
Union[str, BaseMessage, Sequence[BaseMessage]],
...,
)
elif self.input_messages_key:
fields[self.input_messages_key] = (Sequence[BaseMessage], ...)
else:
fields["__root__"] = (Sequence[BaseMessage], ...)
return create_model( # type: ignore[call-overload]
"RunnableWithChatHistoryInput",
**fields,
fields: Dict = {}
if self.input_messages_key and self.history_messages_key:
fields[self.input_messages_key] = (
Union[str, BaseMessage, Sequence[BaseMessage]],
...,
)
elif self.input_messages_key:
fields[self.input_messages_key] = (Sequence[BaseMessage], ...)
else:
return super_schema
fields["__root__"] = (Sequence[BaseMessage], ...)
return create_model( # type: ignore[call-overload]
"RunnableWithChatHistoryInput",
**fields,
)
def _is_not_async(self, *args: Sequence[Any], **kwargs: Dict[str, Any]) -> bool:
return False

View File

@@ -21,7 +21,8 @@ from typing import (
cast,
)
from langchain_core.pydantic_v1 import BaseModel
from pydantic import BaseModel, RootModel
from langchain_core.runnables.base import (
Other,
Runnable,
@@ -227,7 +228,7 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
A Runnable that merges the Dict input with the output produced by the
mapping argument.
"""
return RunnableAssign(RunnableParallel(kwargs))
return RunnableAssign(RunnableParallel[Dict[str, Any]](kwargs))
def invoke(
self, input: Other, config: Optional[RunnableConfig] = None, **kwargs: Any
@@ -419,7 +420,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
map_input_schema = self.mapper.get_input_schema(config)
if not map_input_schema.__custom_root_type__:
if not issubclass(map_input_schema, RootModel):
# ie. it's a dict
return map_input_schema
@@ -430,20 +431,22 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
) -> Type[BaseModel]:
map_input_schema = self.mapper.get_input_schema(config)
map_output_schema = self.mapper.get_output_schema(config)
if (
not map_input_schema.__custom_root_type__
and not map_output_schema.__custom_root_type__
if not issubclass(map_input_schema, RootModel) and not issubclass(
map_output_schema, RootModel
):
# ie. both are dicts
fields = {}
for name, field_info in map_input_schema.model_fields.items():
fields[name] = (field_info.annotation, field_info.default)
for name, field_info in map_output_schema.model_fields.items():
fields[name] = (field_info.annotation, field_info.default)
return create_model( # type: ignore[call-overload]
"RunnableAssignOutput",
**{
k: (v.type_, v.default)
for s in (map_input_schema, map_output_schema)
for k, v in s.__fields__.items()
},
**fields,
)
elif not map_output_schema.__custom_root_type__:
elif not issubclass(map_output_schema, RootModel):
# ie. only map output is a dict
# ie. input type is either unknown or inferred incorrectly
return map_output_schema

View File

@@ -28,12 +28,18 @@ from typing import (
Type,
TypeVar,
Union,
cast,
)
from pydantic import BaseModel, ConfigDict, RootModel # pydantic: ignore
from pydantic import create_model as _create_model_base # pydantic: ignore
from pydantic.json_schema import (
DEFAULT_REF_TEMPLATE,
GenerateJsonSchema,
JsonSchemaMode,
)
from typing_extensions import TypeGuard
from langchain_core.pydantic_v1 import BaseConfig, BaseModel
from langchain_core.pydantic_v1 import create_model as _create_model_base
from langchain_core.runnables.schema import StreamEvent
Input = TypeVar("Input", contravariant=True)
@@ -350,7 +356,7 @@ def get_function_first_arg_dict_keys(func: Callable) -> Optional[List[str]]:
tree = ast.parse(textwrap.dedent(code))
visitor = IsFunctionArgDict()
visitor.visit(tree)
return list(visitor.keys) if visitor.keys else None
return sorted(visitor.keys) if visitor.keys else None
except (SyntaxError, TypeError, OSError, SystemError):
return None
@@ -697,9 +703,57 @@ class _RootEventFilter:
return include
class _SchemaConfig(BaseConfig):
arbitrary_types_allowed = True
frozen = True
_SchemaConfig = ConfigDict(arbitrary_types_allowed=True, frozen=True)
NO_DEFAULT = object()
def create_base_class(
name: str, type_: Any, default_: object = NO_DEFAULT
) -> Type[BaseModel]:
"""Create a base class."""
def schema(
cls: Type[BaseModel],
by_alias: bool = True,
ref_template: str = DEFAULT_REF_TEMPLATE,
) -> Dict[str, Any]:
# Complains about schema not being defined in superclass
schema_ = super(cls, cls).schema( # type: ignore[misc]
by_alias=by_alias, ref_template=ref_template
)
schema_["title"] = name
return schema_
def model_json_schema(
cls: Type[BaseModel],
by_alias: bool = True,
ref_template: str = DEFAULT_REF_TEMPLATE,
schema_generator: type[GenerateJsonSchema] = GenerateJsonSchema,
mode: JsonSchemaMode = "validation",
) -> Dict[str, Any]:
# Complains about model_json_schema not being defined in superclass
schema_ = super(cls, cls).model_json_schema( # type: ignore[misc]
by_alias=by_alias,
ref_template=ref_template,
schema_generator=schema_generator,
mode=mode,
)
schema_["title"] = name
return schema_
base_class_attributes = {
"__annotations__": {"root": type_},
"model_config": ConfigDict(arbitrary_types_allowed=True),
"schema": classmethod(schema),
"model_json_schema": classmethod(model_json_schema),
"__module__": "langchain_core.runnables.utils",
}
if default_ is not NO_DEFAULT:
base_class_attributes["root"] = default_
custom_root_type = type(name, (RootModel,), base_class_attributes)
return cast(Type[BaseModel], custom_root_type)
def create_model(
@@ -715,6 +769,21 @@ def create_model(
Returns:
Type[BaseModel]: The created model.
"""
# Move this to caching path
if "__root__" in field_definitions:
if len(field_definitions) > 1:
raise NotImplementedError(
"When specifying __root__ no other "
f"fields should be provided. Got {field_definitions}"
)
arg = field_definitions["__root__"]
if isinstance(arg, tuple):
named_root_model = create_base_class(__model_name, arg[0], arg[1])
else:
named_root_model = create_base_class(__model_name, arg)
return named_root_model
try:
return _create_model_cached(__model_name, **field_definitions)
except TypeError:
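The `create_model()` wrapper above intercepts `__root__` because pydantic 2 removed custom root types in favor of `RootModel`; `create_base_class()` builds a `RootModel` subclass and patches its schema title instead. A minimal sketch of the underlying pattern with illustrative names:

    from typing import List
    from pydantic import RootModel

    # pydantic v1: create_model("Output", __root__=(List[str], None))
    # pydantic 2: the same shape is expressed as a RootModel subclass.
    class StringListOutput(RootModel[List[str]]):
        pass

    out = StringListOutput(["a", "b"])
    assert out.root == ["a", "b"]
    # issubclass(..., RootModel) is the check that replaces __custom_root_type__
    assert issubclass(StringListOutput, RootModel)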

View File

@@ -6,7 +6,7 @@ from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, List, Optional, Sequence, Union
from langchain_core.pydantic_v1 import BaseModel
from pydantic import BaseModel
class Visitor(ABC):
@@ -127,7 +127,8 @@ class Comparison(FilterDirective):
def __init__(
self, comparator: Comparator, attribute: str, value: Any, **kwargs: Any
) -> None:
super().__init__(
# super exists from BaseModel
super().__init__( # type: ignore[call-arg]
comparator=comparator, attribute=attribute, value=value, **kwargs
)
@@ -145,8 +146,11 @@ class Operation(FilterDirective):
def __init__(
self, operator: Operator, arguments: List[FilterDirective], **kwargs: Any
):
super().__init__(operator=operator, arguments=arguments, **kwargs)
) -> None:
# super exists from BaseModel
super().__init__( # type: ignore[call-arg]
operator=operator, arguments=arguments, **kwargs
)
class StructuredQuery(Expr):
@@ -165,5 +169,8 @@ class StructuredQuery(Expr):
filter: Optional[FilterDirective],
limit: Optional[int] = None,
**kwargs: Any,
):
super().__init__(query=query, filter=filter, limit=limit, **kwargs)
) -> None:
# super exists from BaseModel
super().__init__( # type: ignore[call-arg]
query=query, filter=filter, limit=limit, **kwargs
)

View File

@@ -24,7 +24,9 @@ from typing import (
get_type_hints,
)
from typing_extensions import Annotated, TypeVar, get_args, get_origin
==== BASE ====
from typing_extensions import Annotated, TypeVar, cast, get_args, get_origin
==== BASE ====
from langchain_core._api import deprecated
from langchain_core.callbacks import (
@@ -33,16 +35,26 @@ from langchain_core.callbacks import (
CallbackManager,
Callbacks,
)
from langchain_core.load import Serializable
from langchain_core.messages import ToolCall, ToolMessage
==== BASE ====
from langchain_core.load.serializable import Serializable
from langchain_core.messages.tool import ToolCall, ToolMessage
from langchain_core.prompts import (
BasePromptTemplate,
PromptTemplate,
aformat_document,
format_document,
)
from langchain_core.pydantic_v1 import (
BaseModel,
Extra,
Field,
ValidationError,
create_model,
root_validator,
validate_arguments,
)
from langchain_core.retrievers import BaseRetriever
==== BASE ====
from langchain_core.runnables import (
RunnableConfig,
RunnableSerializable,
@@ -59,6 +71,7 @@ from langchain_core.utils.function_calling import (
from langchain_core.utils.pydantic import (
TypeBaseModel,
_create_subset_model,
get_fields,
is_basemodel_subclass,
is_pydantic_v1_subclass,
is_pydantic_v2_subclass,
@@ -204,20 +217,64 @@ def create_schema_from_function(
"""
# https://docs.pydantic.dev/latest/usage/validation_decorator/
validated = validate_arguments(func, config=_SchemaConfig) # type: ignore
sig = inspect.signature(func)
# Let's ignore `self` and `cls` arguments for class and instance methods
if func.__qualname__ and "." in func.__qualname__:
# Then it likely belongs in a class namespace
in_class = True
else:
in_class = False
has_args = False
has_kwargs = False
for param in sig.parameters.values():
if param.kind == param.VAR_POSITIONAL:
has_args = True
elif param.kind == param.VAR_KEYWORD:
has_kwargs = True
inferred_model = validated.model # type: ignore
filter_args = filter_args if filter_args is not None else FILTERED_ARGS
for arg in filter_args:
if arg in inferred_model.__fields__:
del inferred_model.__fields__[arg]
if filter_args:
filter_args_ = filter_args
else:
# Handle classmethods and instance methods
existing_params: List[str] = list(sig.parameters.keys())
if existing_params and existing_params[0] in ("self", "cls") and in_class:
filter_args_ = [existing_params[0]] + list(FILTERED_ARGS)
else:
filter_args_ = list(FILTERED_ARGS)
for existing_param in existing_params:
if not include_injected and _is_injected_arg_type(
sig.parameters[existing_param].annotation
):
filter_args_.append(existing_param)
description, arg_descriptions = _infer_arg_descriptions(
func,
parse_docstring=parse_docstring,
error_on_invalid_docstring=error_on_invalid_docstring,
)
# Pydantic adds placeholder virtual fields we need to strip
valid_properties = _get_filtered_args(
inferred_model, func, filter_args=filter_args, include_injected=include_injected
)
valid_properties = []
for field in get_fields(inferred_model):
if not has_args:
if field == "args":
continue
if not has_kwargs:
if field == "kwargs":
continue
if field == "v__duplicate_kwargs": # Internal pydantic field
continue
if field not in filter_args_:
valid_properties.append(field)
return _create_subset_model(
f"{model_name}Schema",
inferred_model,
@@ -274,7 +331,10 @@ class ChildTool(BaseTool):
You can provide few-shot examples as a part of the description.
"""
args_schema: Optional[TypeBaseModel] = None
args_schema: Annotated[Optional[TypeBaseModel], SkipValidation()] = Field(
default=None, description="The tool schema."
)
"""Pydantic model class to validate and parse the tool's input arguments.
Args schema should be either:
@@ -416,7 +476,7 @@ class ChildTool(BaseTool):
input_args = self.args_schema
if isinstance(tool_input, str):
if input_args is not None:
key_ = next(iter(input_args.__fields__.keys()))
key_ = next(iter(get_fields(input_args).keys()))
input_args.validate({key_: tool_input})
return tool_input
else:

View File

@@ -2,14 +2,11 @@ from __future__ import annotations
import textwrap
from inspect import signature
from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Type, Union
from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain_core.messages import ToolCall
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import RunnableConfig, run_in_executor
from langchain_core.tools.base import (
FILTERED_ARGS,
@@ -18,13 +15,28 @@ from langchain_core.tools.base import (
create_schema_from_function,
)
from langchain_core.utils.pydantic import TypeBaseModel
from pydantic import BaseModel, Field, SkipValidation
from typing import (
Any,
Awaitable,
Callable,
Dict,
List,
Literal,
Optional,
Type,
Union,
Annotated,
)
class StructuredTool(BaseTool):
"""Tool that can operate on any number of inputs."""
description: str = ""
args_schema: TypeBaseModel = Field(..., description="The tool schema.")
args_schema: Annotated[TypeBaseModel, SkipValidation()] = Field(
..., description="The tool schema."
)
"""The input arguments' schema."""
func: Optional[Callable[..., Any]]
"""The function to run when the tool is called."""

View File

@@ -11,7 +11,6 @@ from langsmith.schemas import RunBase as BaseRunV2
from langsmith.schemas import RunTypeEnum as RunTypeEnumDep
from langchain_core._api import deprecated
from langchain_core.outputs import LLMResult
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
@@ -82,7 +81,8 @@ class LLMRun(BaseRun):
"""Class for LLMRun."""
prompts: List[str]
response: Optional[LLMResult] = None
# Temporarily removed, but we will completely remove LLMRun
# response: Optional[LLMResult] = None
@deprecated("0.1.0", alternative="Run", removal="1.0")

View File

@@ -22,11 +22,11 @@ from typing import (
cast,
)
from pydantic import BaseModel
from typing_extensions import Annotated, TypedDict, get_args, get_origin, is_typeddict
from langchain_core._api import deprecated
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
from langchain_core.pydantic_v1 import BaseModel, Field, create_model
from langchain_core.utils.json_schema import dereference_refs
from langchain_core.utils.pydantic import is_basemodel_subclass
@@ -84,7 +84,7 @@ def _rm_titles(kv: dict, prev_key: str = "") -> dict:
removal="1.0",
)
def convert_pydantic_to_openai_function(
model: Type[BaseModel],
model: Type,
*,
name: Optional[str] = None,
description: Optional[str] = None,
@@ -192,11 +192,13 @@ def convert_python_function_to_openai_function(
def _convert_typed_dict_to_openai_function(typed_dict: Type) -> FunctionDescription:
visited: Dict = {}
from pydantic.v1 import BaseModel # pydantic: ignore
model = cast(
Type[BaseModel],
_convert_any_typed_dicts_to_pydantic(typed_dict, visited=visited),
)
return convert_pydantic_to_openai_function(model)
return convert_pydantic_to_openai_function(model) # type: ignore
_MAX_TYPED_DICT_RECURSION = 25
@@ -208,6 +210,9 @@ def _convert_any_typed_dicts_to_pydantic(
visited: Dict,
depth: int = 0,
) -> Type:
from pydantic.v1 import Field as Field_v1 # pydantic: ignore
from pydantic.v1 import create_model as create_model_v1 # pydantic: ignore
if type_ in visited:
return visited[type_]
elif depth >= _MAX_TYPED_DICT_RECURSION:
@@ -241,7 +246,7 @@ def _convert_any_typed_dicts_to_pydantic(
field_kwargs["description"] = arg_desc
else:
pass
fields[arg] = (new_arg_type, Field(**field_kwargs))
fields[arg] = (new_arg_type, Field_v1(**field_kwargs))
else:
new_arg_type = _convert_any_typed_dicts_to_pydantic(
arg_type, depth=depth + 1, visited=visited
@@ -249,8 +254,8 @@ def _convert_any_typed_dicts_to_pydantic(
field_kwargs = {"default": ...}
if arg_desc := arg_descriptions.get(arg):
field_kwargs["description"] = arg_desc
fields[arg] = (new_arg_type, Field(**field_kwargs))
model = create_model(typed_dict.__name__, **fields)
fields[arg] = (new_arg_type, Field_v1(**field_kwargs))
model = create_model_v1(typed_dict.__name__, **fields)
model.__doc__ = description
visited[typed_dict] = model
return model

View File
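A hedged sketch of the conversion pattern used above, assuming pydantic v2 is installed so the v1 shims live under `pydantic.v1`; the field names are illustrative:

from pydantic.v1 import Field as Field_v1
from pydantic.v1 import create_model as create_model_v1

# Build a pydantic v1 model from (type, FieldInfo) pairs, mirroring what
# _convert_any_typed_dicts_to_pydantic does for each TypedDict key.
fields = {
    "source_path": (str, Field_v1(..., description="File to move")),
    "destination_path": (str, Field_v1(..., description="Where to move it")),
}
MoveFileArgs = create_model_v1("move_file", **fields)
assert set(MoveFileArgs.schema()["required"]) == {"source_path", "destination_path"}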

@@ -8,8 +8,9 @@ from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union, overload
import pydantic # pydantic: ignore
from langchain_core.pydantic_v1 import BaseModel, root_validator
from pydantic import BaseModel, root_validator # pydantic: ignore
from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue # pydantic: ignore
from pydantic_core import core_schema # pydantic: ignore
def get_pydantic_major_version() -> int:
@@ -27,7 +28,6 @@ PYDANTIC_MAJOR_VERSION = get_pydantic_major_version()
if PYDANTIC_MAJOR_VERSION == 1:
from pydantic.fields import FieldInfo as FieldInfoV1
PydanticBaseModel = pydantic.BaseModel
TypeBaseModel = Type[BaseModel]
elif PYDANTIC_MAJOR_VERSION == 2:
@@ -146,7 +146,7 @@ def pre_init(func: Callable) -> Any:
Dict[str, Any]: The values to initialize the model with.
"""
# Insert default values
fields = cls.__fields__
fields = cls.model_fields
for name, field_info in fields.items():
# Check if allow_population_by_field_name is enabled
# If yes, then set the field name to the alias
@@ -155,9 +155,13 @@ def pre_init(func: Callable) -> Any:
if cls.Config.allow_population_by_field_name:
if field_info.alias in values:
values[name] = values.pop(field_info.alias)
if hasattr(cls, "model_config"):
if cls.model_config.get("populate_by_name"):
if field_info.alias in values:
values[name] = values.pop(field_info.alias)
if name not in values or values[name] is None:
if not field_info.required:
if not field_info.is_required():
if field_info.default_factory is not None:
values[name] = field_info.default_factory()
else:
@@ -169,6 +173,44 @@ def pre_init(func: Callable) -> Any:
return wrapper
class _IgnoreUnserializable(GenerateJsonSchema):
"""A JSON schema generator that ignores unknown types.
https://docs.pydantic.dev/latest/concepts/json_schema/#customizing-the-json-schema-generation-process
"""
def handle_invalid_for_json_schema(
self, schema: core_schema.CoreSchema, error_info: str
) -> JsonSchemaValue:
return {}
def v1_repr(obj: BaseModel) -> str:
"""Return a repr of the object as a string.
Get a repr for the pydantic object which is consistent with pydantic.v1.
"""
if not is_basemodel_instance(obj):
raise TypeError(f"Expected a pydantic BaseModel, got {type(obj)}")
repr_ = []
for name, field in get_fields(obj).items():
value = getattr(obj, name)
if isinstance(value, BaseModel):
repr_.append(f"{name}={v1_repr(value)}")
else:
if not field.is_required():
if not value:
continue
if field.default == value:
continue
repr_.append(f"{name}={repr(value)}")
args = ", ".join(repr_)
return f"{obj.__class__.__name__}({args})"
def _create_subset_model_v1(
name: str,
model: Type[BaseModel],
@@ -178,12 +220,20 @@ def _create_subset_model_v1(
fn_description: Optional[str] = None,
) -> Type[BaseModel]:
"""Create a pydantic model with only a subset of model's fields."""
from langchain_core.pydantic_v1 import create_model
if PYDANTIC_MAJOR_VERSION == 1:
from pydantic import create_model # pydantic: ignore
elif PYDANTIC_MAJOR_VERSION == 2:
from pydantic.v1 import create_model # type: ignore # pydantic: ignore
else:
raise NotImplementedError(
f"Unsupported pydantic version: {PYDANTIC_MAJOR_VERSION}"
)
fields = {}
for field_name in field_names:
field = model.__fields__[field_name]
# Using pydantic v1, so we can access __fields__ as a dict.
field = model.__fields__[field_name] # type: ignore
t = (
# this isn't perfect but should work for most functions
field.outer_type_

View File
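A short illustration of the introspection changes above (`__fields__` becomes `model_fields`, `field_info.required` becomes `field_info.is_required()`, and `Config.allow_population_by_field_name` maps to `model_config["populate_by_name"]`); the model below is illustrative:

from pydantic import BaseModel, Field

class Example(BaseModel):
    name: str
    nickname: str = Field(default="anon")

for field_name, info in Example.model_fields.items():
    # is_required() replaces the v1 `required` attribute on FieldInfo.
    print(field_name, info.is_required())
# name True
# nickname False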

@@ -1,12 +1,12 @@
"""A fake callback handler for testing purposes."""
from itertools import chain
from pydantic import BaseModel
from typing import Any, Dict, List, Optional, Union
from uuid import UUID
from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
from langchain_core.messages import BaseMessage
from langchain_core.pydantic_v1 import BaseModel
class BaseFakeCallbackHandler(BaseModel):
@@ -256,7 +256,8 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any:
self.on_retriever_error_common()
def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler":
# Overriding since BaseModel has a __deepcopy__ method as well
def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler": # type: ignore
return self
@@ -390,5 +391,6 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None:
self.on_text_common()
def __deepcopy__(self, memo: dict) -> "FakeAsyncCallbackHandler":
# Overriding since BaseModel has a __deepcopy__ method as well
def __deepcopy__(self, memo: dict) -> "FakeAsyncCallbackHandler": # type: ignore
return self

View File

@@ -9,7 +9,6 @@ from langchain_core.language_models import GenericFakeChatModel, ParrotFakeChatM
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
from tests.unit_tests.stubs import (
AnyStr,
_AnyIdAIMessage,
_AnyIdAIMessageChunk,
_AnyIdHumanMessage,
@@ -70,8 +69,8 @@ async def test_generic_fake_chat_model_stream() -> None:
model = GenericFakeChatModel(messages=cycle([message]))
chunks = [chunk async for chunk in model.astream("meow")]
assert chunks == [
AIMessageChunk(content="", additional_kwargs={"foo": 42}, id=AnyStr()),
AIMessageChunk(content="", additional_kwargs={"bar": 24}, id=AnyStr()),
_AnyIdAIMessageChunk(content="", additional_kwargs={"foo": 42}),
_AnyIdAIMessageChunk(content="", additional_kwargs={"bar": 24}),
]
assert len({chunk.id for chunk in chunks}) == 1
@@ -89,29 +88,23 @@ async def test_generic_fake_chat_model_stream() -> None:
chunks = [chunk async for chunk in model.astream("meow")]
assert chunks == [
AIMessageChunk(
content="",
additional_kwargs={"function_call": {"name": "move_file"}},
id=AnyStr(),
_AnyIdAIMessageChunk(
content="", additional_kwargs={"function_call": {"name": "move_file"}}
),
AIMessageChunk(
_AnyIdAIMessageChunk(
content="",
additional_kwargs={
"function_call": {"arguments": '{\n "source_path": "foo"'},
},
id=AnyStr(),
),
AIMessageChunk(
content="",
additional_kwargs={"function_call": {"arguments": ","}},
id=AnyStr(),
_AnyIdAIMessageChunk(
content="", additional_kwargs={"function_call": {"arguments": ","}}
),
AIMessageChunk(
_AnyIdAIMessageChunk(
content="",
additional_kwargs={
"function_call": {"arguments": '\n "destination_path": "bar"\n}'},
},
id=AnyStr(),
),
]
assert len({chunk.id for chunk in chunks}) == 1

View File

@@ -3,6 +3,7 @@ import time
from langchain_core.caches import InMemoryCache
from langchain_core.language_models import GenericFakeChatModel
from langchain_core.rate_limiters import InMemoryRateLimiter
from typing import Optional as Optional
def test_rate_limit_invoke() -> None:
@@ -219,6 +220,8 @@ class SerializableModel(GenericFakeChatModel):
def is_lc_serializable(cls) -> bool:
return True
SerializableModel.model_rebuild()
def test_serialization_with_rate_limiter() -> None:
"""Test model serialization with rate limiter."""

View File
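A hedged sketch of why `SerializableModel.model_rebuild()` is now called after the class definition: pydantic v2 leaves a model partially built when an annotation cannot be resolved at class-creation time, and instantiation fails until the model is rebuilt (the classes below are illustrative, not from the codebase):

from pydantic import BaseModel

class Wrapper(BaseModel):
    payload: "Payload"  # forward reference that is not resolvable yet

class Payload(BaseModel):
    text: str

Wrapper.model_rebuild()  # resolves "Payload" now that it exists
assert Wrapper(payload={"text": "hi"}).payload.text == "hi"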

@@ -1,6 +1,7 @@
"""Module to test base parser implementations."""
from typing import List
from typing import Optional as Optional
from langchain_core.exceptions import OutputParserException
from langchain_core.language_models import GenericFakeChatModel
@@ -46,6 +47,8 @@ def test_base_generation_parser() -> None:
assert isinstance(content, str)
return content.swapcase() # type: ignore
StrInvertCase.update_forward_refs()
model = GenericFakeChatModel(messages=iter([AIMessage(content="hEllo")]))
chain = model | StrInvertCase()
assert chain.invoke("") == "HeLLO"

View File

@@ -3,6 +3,7 @@
dict({
'definitions': dict({
'AIMessage': dict({
'additionalProperties': True,
'description': '''
Message from an AI.
@@ -44,8 +45,16 @@
'type': 'boolean',
}),
'id': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'type': 'null',
}),
]),
'default': None,
'title': 'Id',
'type': 'string',
}),
'invalid_tool_calls': dict({
'default': list([
@@ -57,8 +66,16 @@
'type': 'array',
}),
'name': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'type': 'null',
}),
]),
'default': None,
'title': 'Name',
'type': 'string',
}),
'response_metadata': dict({
'title': 'Response Metadata',
@@ -74,6 +91,7 @@
'type': 'array',
}),
'type': dict({
'const': 'ai',
'default': 'ai',
'enum': list([
'ai',
@@ -82,7 +100,15 @@
'type': 'string',
}),
'usage_metadata': dict({
'$ref': '#/definitions/UsageMetadata',
'anyOf': list([
dict({
'$ref': '#/definitions/UsageMetadata',
}),
dict({
'type': 'null',
}),
]),
'default': None,
}),
}),
'required': list([
@@ -92,6 +118,7 @@
'type': 'object',
}),
'ChatMessage': dict({
'additionalProperties': True,
'description': 'Message that can be assigned an arbitrary speaker (i.e. role).',
'properties': dict({
'additional_kwargs': dict({
@@ -120,12 +147,28 @@
'title': 'Content',
}),
'id': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'type': 'null',
}),
]),
'default': None,
'title': 'Id',
'type': 'string',
}),
'name': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'type': 'null',
}),
]),
'default': None,
'title': 'Name',
'type': 'string',
}),
'response_metadata': dict({
'title': 'Response Metadata',
@@ -136,6 +179,7 @@
'type': 'string',
}),
'type': dict({
'const': 'chat',
'default': 'chat',
'enum': list([
'chat',
@@ -152,6 +196,7 @@
'type': 'object',
}),
'FunctionMessage': dict({
'additionalProperties': True,
'description': '''
Message for passing the result of executing a tool back to a model.
@@ -189,8 +234,16 @@
'title': 'Content',
}),
'id': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'type': 'null',
}),
]),
'default': None,
'title': 'Id',
'type': 'string',
}),
'name': dict({
'title': 'Name',
@@ -201,6 +254,7 @@
'type': 'object',
}),
'type': dict({
'const': 'function',
'default': 'function',
'enum': list([
'function',
@@ -217,6 +271,7 @@
'type': 'object',
}),
'HumanMessage': dict({
'additionalProperties': True,
'description': '''
Message from a human.
@@ -273,18 +328,35 @@
'type': 'boolean',
}),
'id': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'type': 'null',
}),
]),
'default': None,
'title': 'Id',
'type': 'string',
}),
'name': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'type': 'null',
}),
]),
'default': None,
'title': 'Name',
'type': 'string',
}),
'response_metadata': dict({
'title': 'Response Metadata',
'type': 'object',
}),
'type': dict({
'const': 'human',
'default': 'human',
'enum': list([
'human',
@@ -300,24 +372,59 @@
'type': 'object',
}),
'InvalidToolCall': dict({
'description': '''
Allowance for errors made by LLM.
Here we add an `error` key to surface errors made during generation
(e.g., invalid JSON arguments.)
''',
'properties': dict({
'args': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'type': 'null',
}),
]),
'title': 'Args',
'type': 'string',
}),
'error': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'type': 'null',
}),
]),
'title': 'Error',
'type': 'string',
}),
'id': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'type': 'null',
}),
]),
'title': 'Id',
'type': 'string',
}),
'name': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'type': 'null',
}),
]),
'title': 'Name',
'type': 'string',
}),
'type': dict({
'const': 'invalid_tool_call',
'enum': list([
'invalid_tool_call',
]),
@@ -335,6 +442,7 @@
'type': 'object',
}),
'SystemMessage': dict({
'additionalProperties': True,
'description': '''
Message for priming AI behavior.
@@ -386,18 +494,35 @@
'title': 'Content',
}),
'id': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'type': 'null',
}),
]),
'default': None,
'title': 'Id',
'type': 'string',
}),
'name': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'type': 'null',
}),
]),
'default': None,
'title': 'Name',
'type': 'string',
}),
'response_metadata': dict({
'title': 'Response Metadata',
'type': 'object',
}),
'type': dict({
'const': 'system',
'default': 'system',
'enum': list([
'system',
@@ -413,20 +538,44 @@
'type': 'object',
}),
'ToolCall': dict({
'description': '''
Represents a request to call a tool.
Example:
.. code-block:: python
{
"name": "foo",
"args": {"a": 1},
"id": "123"
}
This represents a request to call the tool named "foo" with arguments {"a": 1}
and an identifier of "123".
''',
'properties': dict({
'args': dict({
'title': 'Args',
'type': 'object',
}),
'id': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'type': 'null',
}),
]),
'title': 'Id',
'type': 'string',
}),
'name': dict({
'title': 'Name',
'type': 'string',
}),
'type': dict({
'const': 'tool_call',
'enum': list([
'tool_call',
]),
@@ -443,6 +592,7 @@
'type': 'object',
}),
'ToolMessage': dict({
'additionalProperties': True,
'description': '''
Message for passing the result of executing a tool back to a model.
@@ -513,12 +663,28 @@
'title': 'Content',
}),
'id': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'type': 'null',
}),
]),
'default': None,
'title': 'Id',
'type': 'string',
}),
'name': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'type': 'null',
}),
]),
'default': None,
'title': 'Name',
'type': 'string',
}),
'response_metadata': dict({
'title': 'Response Metadata',
@@ -538,6 +704,7 @@
'type': 'string',
}),
'type': dict({
'const': 'tool',
'default': 'tool',
'enum': list([
'tool',
@@ -554,6 +721,21 @@
'type': 'object',
}),
'UsageMetadata': dict({
'description': '''
Usage metadata for a message, such as token counts.
This is a standard representation of token usage that is consistent across models.
Example:
.. code-block:: python
{
"input_tokens": 10,
"output_tokens": 20,
"total_tokens": 30
}
''',
'properties': dict({
'input_tokens': dict({
'title': 'Input Tokens',
@@ -620,6 +802,7 @@
dict({
'definitions': dict({
'AIMessage': dict({
'additionalProperties': True,
'description': '''
Message from an AI.
@@ -661,8 +844,16 @@
'type': 'boolean',
}),
'id': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'type': 'null',
}),
]),
'default': None,
'title': 'Id',
'type': 'string',
}),
'invalid_tool_calls': dict({
'default': list([
@@ -674,8 +865,16 @@
'type': 'array',
}),
'name': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'type': 'null',
}),
]),
'default': None,
'title': 'Name',
'type': 'string',
}),
'response_metadata': dict({
'title': 'Response Metadata',
@@ -691,6 +890,7 @@
'type': 'array',
}),
'type': dict({
'const': 'ai',
'default': 'ai',
'enum': list([
'ai',
@@ -699,7 +899,15 @@
'type': 'string',
}),
'usage_metadata': dict({
'$ref': '#/definitions/UsageMetadata',
'anyOf': list([
dict({
'$ref': '#/definitions/UsageMetadata',
}),
dict({
'type': 'null',
}),
]),
'default': None,
}),
}),
'required': list([
@@ -709,6 +917,7 @@
'type': 'object',
}),
'ChatMessage': dict({
'additionalProperties': True,
'description': 'Message that can be assigned an arbitrary speaker (i.e. role).',
'properties': dict({
'additional_kwargs': dict({
@@ -737,12 +946,28 @@
'title': 'Content',
}),
'id': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'type': 'null',
}),
]),
'default': None,
'title': 'Id',
'type': 'string',
}),
'name': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'type': 'null',
}),
]),
'default': None,
'title': 'Name',
'type': 'string',
}),
'response_metadata': dict({
'title': 'Response Metadata',
@@ -753,6 +978,7 @@
'type': 'string',
}),
'type': dict({
'const': 'chat',
'default': 'chat',
'enum': list([
'chat',
@@ -769,6 +995,7 @@
'type': 'object',
}),
'FunctionMessage': dict({
'additionalProperties': True,
'description': '''
Message for passing the result of executing a tool back to a model.
@@ -806,8 +1033,16 @@
'title': 'Content',
}),
'id': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'type': 'null',
}),
]),
'default': None,
'title': 'Id',
'type': 'string',
}),
'name': dict({
'title': 'Name',
@@ -818,6 +1053,7 @@
'type': 'object',
}),
'type': dict({
'const': 'function',
'default': 'function',
'enum': list([
'function',
@@ -834,6 +1070,7 @@
'type': 'object',
}),
'HumanMessage': dict({
'additionalProperties': True,
'description': '''
Message from a human.
@@ -890,18 +1127,35 @@
'type': 'boolean',
}),
'id': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'type': 'null',
}),
]),
'default': None,
'title': 'Id',
'type': 'string',
}),
'name': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'type': 'null',
}),
]),
'default': None,
'title': 'Name',
'type': 'string',
}),
'response_metadata': dict({
'title': 'Response Metadata',
'type': 'object',
}),
'type': dict({
'const': 'human',
'default': 'human',
'enum': list([
'human',
@@ -917,24 +1171,59 @@
'type': 'object',
}),
'InvalidToolCall': dict({
'description': '''
Allowance for errors made by LLM.
Here we add an `error` key to surface errors made during generation
(e.g., invalid JSON arguments.)
''',
'properties': dict({
'args': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'type': 'null',
}),
]),
'title': 'Args',
'type': 'string',
}),
'error': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'type': 'null',
}),
]),
'title': 'Error',
'type': 'string',
}),
'id': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'type': 'null',
}),
]),
'title': 'Id',
'type': 'string',
}),
'name': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'type': 'null',
}),
]),
'title': 'Name',
'type': 'string',
}),
'type': dict({
'const': 'invalid_tool_call',
'enum': list([
'invalid_tool_call',
]),
@@ -952,6 +1241,7 @@
'type': 'object',
}),
'SystemMessage': dict({
'additionalProperties': True,
'description': '''
Message for priming AI behavior.
@@ -1003,18 +1293,35 @@
'title': 'Content',
}),
'id': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'type': 'null',
}),
]),
'default': None,
'title': 'Id',
'type': 'string',
}),
'name': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'type': 'null',
}),
]),
'default': None,
'title': 'Name',
'type': 'string',
}),
'response_metadata': dict({
'title': 'Response Metadata',
'type': 'object',
}),
'type': dict({
'const': 'system',
'default': 'system',
'enum': list([
'system',
@@ -1030,20 +1337,44 @@
'type': 'object',
}),
'ToolCall': dict({
'description': '''
Represents a request to call a tool.
Example:
.. code-block:: python
{
"name": "foo",
"args": {"a": 1},
"id": "123"
}
This represents a request to call the tool named "foo" with arguments {"a": 1}
and an identifier of "123".
''',
'properties': dict({
'args': dict({
'title': 'Args',
'type': 'object',
}),
'id': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'type': 'null',
}),
]),
'title': 'Id',
'type': 'string',
}),
'name': dict({
'title': 'Name',
'type': 'string',
}),
'type': dict({
'const': 'tool_call',
'enum': list([
'tool_call',
]),
@@ -1060,6 +1391,7 @@
'type': 'object',
}),
'ToolMessage': dict({
'additionalProperties': True,
'description': '''
Message for passing the result of executing a tool back to a model.
@@ -1130,12 +1462,28 @@
'title': 'Content',
}),
'id': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'type': 'null',
}),
]),
'default': None,
'title': 'Id',
'type': 'string',
}),
'name': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'type': 'null',
}),
]),
'default': None,
'title': 'Name',
'type': 'string',
}),
'response_metadata': dict({
'title': 'Response Metadata',
@@ -1155,6 +1503,7 @@
'type': 'string',
}),
'type': dict({
'const': 'tool',
'default': 'tool',
'enum': list([
'tool',
@@ -1171,6 +1520,21 @@
'type': 'object',
}),
'UsageMetadata': dict({
'description': '''
Usage metadata for a message, such as token counts.
This is a standard representation of token usage that is consistent across models.
Example:
.. code-block:: python
{
"input_tokens": 10,
"output_tokens": 20,
"total_tokens": 30
}
''',
'properties': dict({
'input_tokens': dict({
'title': 'Input Tokens',
@@ -1436,7 +1800,7 @@
'type': 'constructor',
})
# ---
# name: test_chat_prompt_w_msgs_placeholder_ser_des[placholder]
# name: test_chat_prompt_w_msgs_placeholder_ser_des[placeholder]
dict({
'id': list([
'langchain',

View File
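A hedged illustration of the snapshot churn above: for an `Optional[str] = None` field pydantic v2 emits an `anyOf` of the type and `null` plus an explicit `default`, and for a `Literal` discriminator it adds `const` alongside the `enum`, where v1 emitted a plain typed property. The model below is illustrative, and the exact output can vary slightly between pydantic 2.x releases:

from typing import Literal, Optional
from pydantic import BaseModel

class DemoMessage(BaseModel):
    id: Optional[str] = None
    type: Literal["ai"] = "ai"

props = DemoMessage.model_json_schema()["properties"]
print(props["id"])    # e.g. {'anyOf': [{'type': 'string'}, {'type': 'null'}], 'default': None, 'title': 'Id'}
print(props["type"])  # e.g. {'const': 'ai', 'default': 'ai', 'title': 'Type', ...}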

@@ -6,7 +6,9 @@ from typing import Any, List, Union
import pytest
from syrupy import SnapshotAssertion
from langchain_core._api.deprecation import LangChainPendingDeprecationWarning
from langchain_core._api.deprecation import (
LangChainPendingDeprecationWarning,
)
from langchain_core.load import dumpd, load
from langchain_core.messages import (
AIMessage,
@@ -810,7 +812,7 @@ def test_chat_prompt_w_msgs_placeholder_ser_des(snapshot: SnapshotAssertion) ->
prompt = ChatPromptTemplate.from_messages(
[("system", "foo"), MessagesPlaceholder("bar"), ("human", "baz")]
)
assert dumpd(MessagesPlaceholder("bar")) == snapshot(name="placholder")
assert dumpd(MessagesPlaceholder("bar")) == snapshot(name="placeholder")
assert load(dumpd(MessagesPlaceholder("bar"))) == MessagesPlaceholder("bar")
assert dumpd(prompt) == snapshot(name="chat_prompt")
assert load(dumpd(prompt)) == prompt

View File

@@ -1,6 +1,7 @@
from functools import partial
from inspect import isclass
from typing import Any, Dict, Type, Union, cast
from typing import Optional as Optional
from langchain_core.language_models import FakeListChatModel
from langchain_core.load.dump import dumps
@@ -34,6 +35,9 @@ class FakeStructuredChatModel(FakeListChatModel):
return "fake-messages-list-chat-model"
FakeStructuredChatModel.update_forward_refs()
def test_structured_prompt_pydantic() -> None:
class OutputSchema(BaseModel):
name: str

View File

@@ -1,20 +1,10 @@
"""Helper utilities for pydantic.
This module includes helper utilities to ease the migration from pydantic v1 to v2.
They're meant to be used in the following way:
1) Use utility code to help (selected) unit tests pass without modifications
2) Upgrade the unit tests to match pydantic 2
3) Stop using the utility code
"""
from typing import Any
from langchain_core.utils.pydantic import is_basemodel_subclass
# Function to replace allOf with $ref
def _replace_all_of_with_ref(schema: Any) -> None:
"""Replace allOf with $ref in the schema."""
def replace_all_of_with_ref(schema: Any) -> None:
if isinstance(schema, dict):
# If the schema has an allOf key with a single item that contains a $ref
if (
@@ -30,13 +20,13 @@ def _replace_all_of_with_ref(schema: Any) -> None:
# Recursively process nested schemas
for key, value in schema.items():
if isinstance(value, (dict, list)):
_replace_all_of_with_ref(value)
replace_all_of_with_ref(value)
elif isinstance(schema, list):
for item in schema:
_replace_all_of_with_ref(item)
replace_all_of_with_ref(item)
def _remove_bad_none_defaults(schema: Any) -> None:
def remove_all_none_default(schema: Any) -> None:
"""Remove all None defaults.
Pydantic v1 did not generate these, but Pydantic v2 does.
@@ -56,39 +46,48 @@ def _remove_bad_none_defaults(schema: Any) -> None:
break # Null type explicitly defined
else:
del value["default"]
_remove_bad_none_defaults(value)
remove_all_none_default(value)
elif isinstance(value, list):
for item in value:
_remove_bad_none_defaults(item)
remove_all_none_default(item)
elif isinstance(schema, list):
for item in schema:
_remove_bad_none_defaults(item)
remove_all_none_default(item)
def _remove_enum_description(obj: Any) -> None:
"""Remove the description from enums."""
if isinstance(obj, dict):
if "enum" in obj:
if "description" in obj and obj["description"] == "An enumeration.":
del obj["description"]
for key, value in obj.items():
_remove_enum_description(value)
elif isinstance(obj, list):
for item in obj:
_remove_enum_description(item)
def _schema(obj: Any) -> dict:
"""Get the schema of a pydantic model in the pydantic v1 style.
This will attempt to map the schema as close as possible to the pydantic v1 schema.
"""
"""Return the schema of the object."""
# Remap to old style schema
if not is_basemodel_subclass(obj):
raise TypeError(
f"Object must be a Pydantic BaseModel subclass. Got {type(obj)}"
)
if not hasattr(obj, "model_json_schema"): # V1 model
return obj.schema()
# Then we're using V2 models internally.
raise AssertionError(
"Hi there! Looks like you're attempting to upgrade to Pydantic v2. If so: \n"
"1) remove this exception\n"
"2) confirm that the old unit tests pass, and if not look for difference\n"
"3) update the unit tests to match the new schema\n"
"4) remove this utility function\n"
)
schema_ = obj.model_json_schema(ref_template="#/definitions/{model}")
if "$defs" in schema_:
schema_["definitions"] = schema_["$defs"]
del schema_["$defs"]
_replace_all_of_with_ref(schema_)
_remove_bad_none_defaults(schema_)
if "default" in schema_ and schema_["default"] is None:
del schema_["default"]
replace_all_of_with_ref(schema_)
remove_all_none_default(schema_)
_remove_enum_description(schema_)
return schema_

View File
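A small usage sketch of the renamed helpers above; the input dict is illustrative of the single-item `allOf` wrappers and bare `None` defaults that pydantic v2 schemas can contain but v1 did not:

from tests.unit_tests.pydantic_utils import remove_all_none_default, replace_all_of_with_ref

schema = {
    "properties": {
        "llm": {"allOf": [{"$ref": "#/definitions/LLM"}], "default": "default"},
        "tags": {"type": "array", "items": {"type": "string"}, "default": None},
    }
}
replace_all_of_with_ref(schema)
remove_all_none_default(schema)
print(schema["properties"]["llm"])   # expected: {'$ref': '#/definitions/LLM', 'default': 'default'}
print(schema["properties"]["tags"])  # the bare None default is expected to be dropped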

@@ -334,31 +334,9 @@
}),
dict({
'data': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'$ref': '#/definitions/AIMessage',
}),
dict({
'$ref': '#/definitions/HumanMessage',
}),
dict({
'$ref': '#/definitions/ChatMessage',
}),
dict({
'$ref': '#/definitions/SystemMessage',
}),
dict({
'$ref': '#/definitions/FunctionMessage',
}),
dict({
'$ref': '#/definitions/ToolMessage',
}),
]),
'definitions': dict({
'$defs': dict({
'AIMessage': dict({
'additionalProperties': True,
'description': '''
Message from an AI.
@@ -400,21 +378,37 @@
'type': 'boolean',
}),
'id': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'type': 'null',
}),
]),
'default': None,
'title': 'Id',
'type': 'string',
}),
'invalid_tool_calls': dict({
'default': list([
]),
'items': dict({
'$ref': '#/definitions/InvalidToolCall',
'$ref': '#/$defs/InvalidToolCall',
}),
'title': 'Invalid Tool Calls',
'type': 'array',
}),
'name': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'type': 'null',
}),
]),
'default': None,
'title': 'Name',
'type': 'string',
}),
'response_metadata': dict({
'title': 'Response Metadata',
@@ -424,12 +418,13 @@
'default': list([
]),
'items': dict({
'$ref': '#/definitions/ToolCall',
'$ref': '#/$defs/ToolCall',
}),
'title': 'Tool Calls',
'type': 'array',
}),
'type': dict({
'const': 'ai',
'default': 'ai',
'enum': list([
'ai',
@@ -438,7 +433,15 @@
'type': 'string',
}),
'usage_metadata': dict({
'$ref': '#/definitions/UsageMetadata',
'anyOf': list([
dict({
'$ref': '#/$defs/UsageMetadata',
}),
dict({
'type': 'null',
}),
]),
'default': None,
}),
}),
'required': list([
@@ -448,6 +451,7 @@
'type': 'object',
}),
'ChatMessage': dict({
'additionalProperties': True,
'description': 'Message that can be assigned an arbitrary speaker (i.e. role).',
'properties': dict({
'additional_kwargs': dict({
@@ -476,12 +480,28 @@
'title': 'Content',
}),
'id': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'type': 'null',
}),
]),
'default': None,
'title': 'Id',
'type': 'string',
}),
'name': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'type': 'null',
}),
]),
'default': None,
'title': 'Name',
'type': 'string',
}),
'response_metadata': dict({
'title': 'Response Metadata',
@@ -492,6 +512,7 @@
'type': 'string',
}),
'type': dict({
'const': 'chat',
'default': 'chat',
'enum': list([
'chat',
@@ -508,6 +529,7 @@
'type': 'object',
}),
'FunctionMessage': dict({
'additionalProperties': True,
'description': '''
Message for passing the result of executing a tool back to a model.
@@ -545,8 +567,16 @@
'title': 'Content',
}),
'id': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'type': 'null',
}),
]),
'default': None,
'title': 'Id',
'type': 'string',
}),
'name': dict({
'title': 'Name',
@@ -557,6 +587,7 @@
'type': 'object',
}),
'type': dict({
'const': 'function',
'default': 'function',
'enum': list([
'function',
@@ -573,6 +604,7 @@
'type': 'object',
}),
'HumanMessage': dict({
'additionalProperties': True,
'description': '''
Message from a human.
@@ -629,18 +661,35 @@
'type': 'boolean',
}),
'id': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'type': 'null',
}),
]),
'default': None,
'title': 'Id',
'type': 'string',
}),
'name': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'type': 'null',
}),
]),
'default': None,
'title': 'Name',
'type': 'string',
}),
'response_metadata': dict({
'title': 'Response Metadata',
'type': 'object',
}),
'type': dict({
'const': 'human',
'default': 'human',
'enum': list([
'human',
@@ -656,24 +705,59 @@
'type': 'object',
}),
'InvalidToolCall': dict({
'description': '''
Allowance for errors made by LLM.
Here we add an `error` key to surface errors made during generation
(e.g., invalid JSON arguments.)
''',
'properties': dict({
'args': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'type': 'null',
}),
]),
'title': 'Args',
'type': 'string',
}),
'error': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'type': 'null',
}),
]),
'title': 'Error',
'type': 'string',
}),
'id': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'type': 'null',
}),
]),
'title': 'Id',
'type': 'string',
}),
'name': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'type': 'null',
}),
]),
'title': 'Name',
'type': 'string',
}),
'type': dict({
'const': 'invalid_tool_call',
'enum': list([
'invalid_tool_call',
]),
@@ -691,6 +775,7 @@
'type': 'object',
}),
'SystemMessage': dict({
'additionalProperties': True,
'description': '''
Message for priming AI behavior.
@@ -742,18 +827,35 @@
'title': 'Content',
}),
'id': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'type': 'null',
}),
]),
'default': None,
'title': 'Id',
'type': 'string',
}),
'name': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'type': 'null',
}),
]),
'default': None,
'title': 'Name',
'type': 'string',
}),
'response_metadata': dict({
'title': 'Response Metadata',
'type': 'object',
}),
'type': dict({
'const': 'system',
'default': 'system',
'enum': list([
'system',
@@ -769,20 +871,44 @@
'type': 'object',
}),
'ToolCall': dict({
'description': '''
Represents a request to call a tool.
Example:
.. code-block:: python
{
"name": "foo",
"args": {"a": 1},
"id": "123"
}
This represents a request to call the tool named "foo" with arguments {"a": 1}
and an identifier of "123".
''',
'properties': dict({
'args': dict({
'title': 'Args',
'type': 'object',
}),
'id': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'type': 'null',
}),
]),
'title': 'Id',
'type': 'string',
}),
'name': dict({
'title': 'Name',
'type': 'string',
}),
'type': dict({
'const': 'tool_call',
'enum': list([
'tool_call',
]),
@@ -799,6 +925,7 @@
'type': 'object',
}),
'ToolMessage': dict({
'additionalProperties': True,
'description': '''
Message for passing the result of executing a tool back to a model.
@@ -845,6 +972,7 @@
'type': 'object',
}),
'artifact': dict({
'default': None,
'title': 'Artifact',
}),
'content': dict({
@@ -869,12 +997,28 @@
'title': 'Content',
}),
'id': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'type': 'null',
}),
]),
'default': None,
'title': 'Id',
'type': 'string',
}),
'name': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'type': 'null',
}),
]),
'default': None,
'title': 'Name',
'type': 'string',
}),
'response_metadata': dict({
'title': 'Response Metadata',
@@ -894,6 +1038,7 @@
'type': 'string',
}),
'type': dict({
'const': 'tool',
'default': 'tool',
'enum': list([
'tool',
@@ -910,6 +1055,21 @@
'type': 'object',
}),
'UsageMetadata': dict({
'description': '''
Usage metadata for a message, such as token counts.
This is a standard representation of token usage that is consistent across models.
Example:
.. code-block:: python
{
"input_tokens": 10,
"output_tokens": 20,
"total_tokens": 30
}
''',
'properties': dict({
'input_tokens': dict({
'title': 'Input Tokens',
@@ -933,6 +1093,29 @@
'type': 'object',
}),
}),
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'$ref': '#/$defs/AIMessage',
}),
dict({
'$ref': '#/$defs/HumanMessage',
}),
dict({
'$ref': '#/$defs/ChatMessage',
}),
dict({
'$ref': '#/$defs/SystemMessage',
}),
dict({
'$ref': '#/$defs/FunctionMessage',
}),
dict({
'$ref': '#/$defs/ToolMessage',
}),
]),
'title': 'RunnableParallel<as_list,as_str>Input',
}),
'id': 3,
@@ -952,6 +1135,10 @@
'title': 'As Str',
}),
}),
'required': list([
'as_list',
'as_str',
]),
'title': 'RunnableParallel<as_list,as_str>Output',
'type': 'object',
}),

View File

@@ -455,8 +455,9 @@ def test_get_input_schema_input_dict() -> None:
def test_get_input_schema_input_messages() -> None:
class RunnableWithChatHistoryInput(BaseModel):
__root__: Sequence[BaseMessage]
from pydantic import RootModel # pydantic: ignore
RunnableWithMessageHistoryInput = RootModel[Sequence[BaseMessage]]
runnable = RunnableLambda(
lambda messages: {
@@ -478,9 +479,9 @@ def test_get_input_schema_input_messages() -> None:
with_history = RunnableWithMessageHistory(
runnable, get_session_history, output_messages_key="output"
)
assert _schema(with_history.get_input_schema()) == _schema(
RunnableWithChatHistoryInput
)
expected_schema = _schema(RunnableWithMessageHistoryInput)
expected_schema["title"] = "RunnableWithChatHistoryInput"
assert _schema(with_history.get_input_schema()) == expected_schema
def test_using_custom_config_specs() -> None:

View File
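A minimal sketch of the change above: pydantic v2 replaces the implicit `__root__` field with an explicit `RootModel`. Plain `str` elements are used here instead of `BaseMessage` purely for illustration:

from typing import Sequence
from pydantic import RootModel

Names = RootModel[Sequence[str]]
wrapped = Names(["alice", "bob"])
print(wrapped.root)                       # ['alice', 'bob']
print(Names.model_json_schema()["type"])  # 'array'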

@@ -19,6 +19,7 @@ from uuid import UUID
import pytest
from freezegun import freeze_time
from pydantic import BaseModel
from pytest_mock import MockerFixture
from syrupy import SnapshotAssertion
from typing_extensions import TypedDict
@@ -57,7 +58,6 @@ from langchain_core.prompts import (
PromptTemplate,
SystemMessagePromptTemplate,
)
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import (
AddableDict,
@@ -89,7 +89,7 @@ from langchain_core.tracers import (
RunLogPatch,
)
from langchain_core.tracers.context import collect_runs
from tests.unit_tests.pydantic_utils import _schema
from tests.unit_tests.pydantic_utils import _schema, replace_all_of_with_ref
from tests.unit_tests.stubs import AnyStr, _AnyIdAIMessage, _AnyIdAIMessageChunk
@@ -313,12 +313,14 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
"metadata": {"title": "Metadata", "type": "object"},
"id": {
"title": "Id",
"type": "string",
"anyOf": [{"type": "string"}, {"type": "null"}],
"default": None,
},
"type": {
"title": "Type",
"enum": ["Document"],
"default": "Document",
"const": "Document",
"type": "string",
},
},
@@ -329,7 +331,7 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
fake_llm = FakeListLLM(responses=["a"]) # str -> List[List[str]]
assert _schema(fake_llm.input_schema) == snapshot
assert _schema(fake_llm.input_schema) == snapshot(name="fake_llm_input_schema")
assert _schema(fake_llm.output_schema) == {
"title": "FakeListLLMOutput",
"type": "string",
@@ -337,8 +339,8 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
fake_chat = FakeListChatModel(responses=["a"]) # str -> List[List[str]]
assert _schema(fake_chat.input_schema) == snapshot
assert _schema(fake_chat.output_schema) == snapshot
assert _schema(fake_chat.input_schema) == snapshot(name="fake_chat_input_schema")
assert _schema(fake_chat.output_schema) == snapshot(name="fake_chat_output_schema")
chat_prompt = ChatPromptTemplate.from_messages(
[
@@ -362,7 +364,7 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
"properties": {"name": {"title": "Name", "type": "string"}},
"required": ["name"],
}
assert _schema(prompt.output_schema) == snapshot
assert _schema(prompt.output_schema) == snapshot(name="prompt_output_schema")
prompt_mapper = PromptTemplate.from_template("Hello, {name}!").map()
@@ -379,11 +381,15 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
"type": "array",
"title": "RunnableEach<PromptTemplate>Input",
}
assert _schema(prompt_mapper.output_schema) == snapshot
assert _schema(prompt_mapper.output_schema) == snapshot(
name="prompt_mapper_output_schema"
)
list_parser = CommaSeparatedListOutputParser()
assert _schema(list_parser.input_schema) == snapshot
assert _schema(list_parser.input_schema) == snapshot(
name="list_parser_input_schema"
)
assert _schema(list_parser.output_schema) == {
"title": "CommaSeparatedListOutputParserOutput",
"type": "array",
@@ -407,19 +413,26 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
router: Runnable = RouterRunnable({})
assert _schema(router.input_schema) == {
"title": "RouterRunnableInput",
"$ref": "#/definitions/RouterInput",
"definitions": {
"RouterInput": {
"title": "RouterInput",
"type": "object",
"description": "Router input.\n"
"\n"
"Attributes:\n"
" key: The key to route "
"on.\n"
" input: The input to pass "
"to the selected Runnable.",
"properties": {
"key": {"title": "Key", "type": "string"},
"input": {"title": "Input"},
"key": {"title": "Key", "type": "string"},
},
"required": ["key", "input"],
"title": "RouterInput",
"type": "object",
}
},
"title": "RouterRunnableInput",
}
assert _schema(router.output_schema) == {"title": "RouterRunnableOutput"}
@@ -451,6 +464,20 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
"items": {"type": "string"},
},
},
"required": ["original", "as_list", "length"],
}
# Add a test for schema of runnable assign
def foo(x: int) -> int:
return x
foo_ = RunnableLambda(foo)
assert foo_.assign(bar=lambda x: "foo").get_output_schema().schema() == {
"properties": {"bar": {"title": "Bar"}, "root": {"title": "Root"}},
"required": ["root", "bar"],
"title": "RunnableAssignOutput",
"type": "object",
}
@@ -469,6 +496,7 @@ def test_passthrough_assign_schema() -> None:
"properties": {"question": {"title": "Question", "type": "string"}},
"title": "RunnableSequenceInput",
"type": "object",
"required": ["question"],
}
assert _schema(seq_w_assign.output_schema) == {
"title": "FakeListLLMOutput",
@@ -486,6 +514,7 @@ def test_passthrough_assign_schema() -> None:
"properties": {"question": {"title": "Question"}},
"title": "RunnableParallel<context>Input",
"type": "object",
"required": ["question"],
}
@@ -498,6 +527,7 @@ def test_lambda_schemas() -> None:
"title": "RunnableLambdaInput",
"type": "object",
"properties": {"hello": {"title": "Hello"}},
"required": ["hello"],
}
second_lambda = lambda x, y: (x["hello"], x["bye"], y["bah"]) # noqa: E731
@@ -505,6 +535,7 @@ def test_lambda_schemas() -> None:
"title": "RunnableLambdaInput",
"type": "object",
"properties": {"hello": {"title": "Hello"}, "bye": {"title": "Bye"}},
"required": ["bye", "hello"],
}
def get_value(input): # type: ignore[no-untyped-def]
@@ -514,6 +545,7 @@ def test_lambda_schemas() -> None:
"title": "get_value_input",
"type": "object",
"properties": {"variable_name": {"title": "Variable Name"}},
"required": ["variable_name"],
}
async def aget_value(input): # type: ignore[no-untyped-def]
@@ -526,6 +558,7 @@ def test_lambda_schemas() -> None:
"another": {"title": "Another"},
"variable_name": {"title": "Variable Name"},
},
"required": ["another", "variable_name"],
}
async def aget_values(input): # type: ignore[no-untyped-def]
@@ -542,6 +575,7 @@ def test_lambda_schemas() -> None:
"variable_name": {"title": "Variable Name"},
"yo": {"title": "Yo"},
},
"required": ["variable_name", "yo"],
}
class InputType(TypedDict):
@@ -622,6 +656,25 @@ def test_with_types_with_type_generics() -> None:
)
def test_schema_with_itemgetter() -> None:
"""Test runnable with itemgetter."""
foo = RunnableLambda(itemgetter("hello"))
assert _schema(foo.input_schema) == {
"properties": {"hello": {"title": "Hello"}},
"required": ["hello"],
"title": "RunnableLambdaInput",
"type": "object",
}
prompt = ChatPromptTemplate.from_template("what is {language}?")
chain: Runnable = {"language": itemgetter("language")} | prompt
assert _schema(chain.input_schema) == {
"properties": {"language": {"title": "Language"}},
"required": ["language"],
"title": "RunnableParallel<language>Input",
"type": "object",
}
def test_schema_complex_seq() -> None:
prompt1 = ChatPromptTemplate.from_template("what is the city {person} is from?")
prompt2 = ChatPromptTemplate.from_template(
@@ -650,6 +703,7 @@ def test_schema_complex_seq() -> None:
"person": {"title": "Person", "type": "string"},
"language": {"title": "Language"},
},
"required": ["person", "language"],
}
assert _schema(chain2.output_schema) == {
@@ -953,66 +1007,69 @@ def test_configurable_fields_prefix_keys() -> None:
chain = prompt | fake_llm
assert _schema(chain.config_schema()) == {
"title": "RunnableSequenceConfig",
"type": "object",
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
"definitions": {
"LLM": {
"title": "LLM",
"description": "An enumeration.",
"enum": ["chat", "default"],
"type": "string",
},
"Chat_Responses": {
"title": "Chat Responses",
"description": "An enumeration.",
"enum": ["hello", "bye", "helpful"],
"type": "string",
},
"Prompt_Template": {
"title": "Prompt Template",
"description": "An enumeration.",
"enum": ["hello", "good_morning"],
"title": "Chat Responses",
"type": "string",
},
"Configurable": {
"title": "Configurable",
"type": "object",
"properties": {
"prompt_template": {
"title": "Prompt Template",
"description": "The prompt template for this chain",
"default": "hello",
"allOf": [{"$ref": "#/definitions/Prompt_Template"}],
"chat_sleep": {
"anyOf": [{"type": "number"}, {"type": "null"}],
"default": None,
"title": "Chat " "Sleep",
},
"llm": {
"title": "LLM",
"$ref": "#/definitions/LLM",
"default": "default",
"allOf": [{"$ref": "#/definitions/LLM"}],
"title": "LLM",
},
# not prefixed because marked as shared
"chat_sleep": {
"title": "Chat Sleep",
"type": "number",
},
# prefixed for "chat" option
"llm==chat/responses": {
"title": "Chat Responses",
"default": ["hello", "bye"],
"type": "array",
"items": {"$ref": "#/definitions/Chat_Responses"},
},
# prefixed for "default" option
"llm==default/responses": {
"title": "LLM Responses",
"description": "A list of fake responses for this LLM",
"default": ["a"],
"title": "Chat " "Responses",
"type": "array",
},
"llm==default/responses": {
"default": ["a"],
"description": "A "
"list "
"of "
"fake "
"responses "
"for "
"this "
"LLM",
"items": {"type": "string"},
"title": "LLM " "Responses",
"type": "array",
},
"prompt_template": {
"$ref": "#/definitions/Prompt_Template",
"default": "hello",
"description": "The "
"prompt "
"template "
"for "
"this "
"chain",
"title": "Prompt " "Template",
},
},
"title": "Configurable",
"type": "object",
},
"LLM": {"enum": ["chat", "default"], "title": "LLM", "type": "string"},
"Prompt_Template": {
"enum": ["hello", "good_morning"],
"title": "Prompt Template",
"type": "string",
},
},
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
"title": "RunnableSequenceConfig",
"type": "object",
}
@@ -1061,26 +1118,22 @@ def test_configurable_fields_example() -> None:
chain_configurable = prompt | fake_llm | (lambda x: {"name": x}) | prompt | fake_llm
assert chain_configurable.invoke({"name": "John"}) == "a"
assert _schema(chain_configurable.config_schema()) == {
expected = {
"title": "RunnableSequenceConfig",
"type": "object",
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
"definitions": {
"LLM": {
"title": "LLM",
"description": "An enumeration.",
"enum": ["chat", "default"],
"type": "string",
},
"Chat_Responses": {
"description": "An enumeration.",
"enum": ["hello", "bye", "helpful"],
"title": "Chat Responses",
"type": "string",
},
"Prompt_Template": {
"description": "An enumeration.",
"enum": ["hello", "good_morning"],
"title": "Prompt Template",
"type": "string",
@@ -1118,6 +1171,10 @@ def test_configurable_fields_example() -> None:
},
}
replace_all_of_with_ref(expected)
assert _schema(chain_configurable.config_schema()) == expected
assert (
chain_configurable.with_config(configurable={"llm": "chat"}).invoke(
{"name": "John"}
@@ -3173,6 +3230,7 @@ def test_map_stream() -> None:
"hello": {"title": "Hello", "type": "string"},
"llm": {"title": "Llm", "type": "string"},
},
"required": ["llm", "hello"],
}
stream = chain_pick_two.stream({"question": "What is your name?"})
@@ -3190,6 +3248,12 @@ def test_map_stream() -> None:
{"llm": "i"},
{"chat": AIMessageChunk(content="i")},
]
if not ( # TODO(Rewrite properly) statement above
streamed_chunks[0] == {"llm": "i"}
or streamed_chunks[0] == {"chat": _AnyIdAIMessageChunk(content="i")}
):
raise AssertionError(f"Got an unexpected chunk: {streamed_chunks[0]}")
assert len(streamed_chunks) == len(llm_res) + len(chat_res)
@@ -3544,6 +3608,7 @@ def test_deep_stream_assign() -> None:
"str": {"title": "Str", "type": "string"},
"hello": {"title": "Hello", "type": "string"},
},
"required": ["str", "hello"],
}
chunks = []
@@ -3595,6 +3660,7 @@ def test_deep_stream_assign() -> None:
"str": {"title": "Str"},
"hello": {"title": "Hello", "type": "string"},
},
"required": ["str", "hello"],
}
chunks = []
@@ -3670,6 +3736,7 @@ async def test_deep_astream_assign() -> None:
"str": {"title": "Str", "type": "string"},
"hello": {"title": "Hello", "type": "string"},
},
"required": ["str", "hello"],
}
chunks = []
@@ -3721,6 +3788,7 @@ async def test_deep_astream_assign() -> None:
"str": {"title": "Str"},
"hello": {"title": "Hello", "type": "string"},
},
"required": ["str", "hello"],
}
chunks = []
@@ -4362,7 +4430,10 @@ def test_runnable_branch_init_coercion(branches: Sequence[Any]) -> None:
assert isinstance(body, Runnable)
assert isinstance(runnable.default, Runnable)
assert _schema(runnable.input_schema) == {"title": "RunnableBranchInput"}
assert _schema(runnable.input_schema) == {
"title": "RunnableBranchInput",
"type": "integer",
}
def test_runnable_branch_invoke_call_counts(mocker: MockerFixture) -> None:

View File

@@ -3,6 +3,7 @@
import sys
from itertools import cycle
from typing import Any, AsyncIterator, Dict, List, Sequence, cast
from typing import Optional as Optional
import pytest
@@ -1176,6 +1177,9 @@ class HardCodedRetriever(BaseRetriever):
return self.documents
HardCodedRetriever.update_forward_refs()
async def test_event_stream_with_retriever() -> None:
"""Test the event stream with a retriever."""
retriever = HardCodedRetriever(

View File

@@ -0,0 +1,25 @@
from pydantic import BaseModel, v1
from tests.unit_tests.pydantic_utils import _schema
def test_schemas() -> None:
"""Test schema remapping for the two pydantic versions."""
class Bar(BaseModel):
baz: int
class Foo(BaseModel):
bar: Bar
schema_2 = _schema(Foo)
class Bar(v1.BaseModel): # type: ignore[no-redef]
baz: int
class Foo(v1.BaseModel): # type: ignore[no-redef]
bar: Bar
schema_1 = _schema(Foo)
assert schema_1 == schema_2

View File

@@ -874,15 +874,13 @@ async def test_async_validation_error_handling_non_validation_error(
def test_optional_subset_model_rewrite() -> None:
class MyModel(BaseModel):
a: Optional[str]
a: Optional[str] = None
b: str
c: Optional[List[Optional[str]]]
c: Optional[List[Optional[str]]] = None
model2 = _create_subset_model("model2", MyModel, ["a", "b", "c"])
assert "a" not in _schema(model2)["required"] # should be optional
assert "b" in _schema(model2)["required"] # should be required
assert "c" not in _schema(model2)["required"] # should be optional
assert set(_schema(model2)["required"]) == {"b"}
@pytest.mark.parametrize(

View File
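The explicit `= None` defaults added above reflect a behavior change: pydantic v1 treated a bare `Optional[...]` annotation as implicitly defaulting to `None`, while pydantic v2 makes such a field required unless a default is given. A minimal illustration using the same field layout:

from typing import List, Optional
from pydantic import BaseModel

class MyModel(BaseModel):
    a: Optional[str] = None                  # explicit default -> optional
    b: str                                   # no default -> required
    c: Optional[List[Optional[str]]] = None

print(MyModel.model_json_schema()["required"])  # ['b']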

@@ -411,17 +411,17 @@ def test_multiple_tool_calls() -> None:
{
"id": messages[2].tool_call_id,
"type": "function",
"function": {"name": "FakeCall", "arguments": '{"data": "ToolCall1"}'},
"function": {"name": "FakeCall", "arguments": '{"data":"ToolCall1"}'},
},
{
"id": messages[3].tool_call_id,
"type": "function",
"function": {"name": "FakeCall", "arguments": '{"data": "ToolCall2"}'},
"function": {"name": "FakeCall", "arguments": '{"data":"ToolCall2"}'},
},
{
"id": messages[4].tool_call_id,
"type": "function",
"function": {"name": "FakeCall", "arguments": '{"data": "ToolCall3"}'},
"function": {"name": "FakeCall", "arguments": '{"data":"ToolCall3"}'},
},
]
@@ -442,7 +442,7 @@ def test_tool_outputs() -> None:
{
"id": messages[2].tool_call_id,
"type": "function",
"function": {"name": "FakeCall", "arguments": '{"data": "ToolCall1"}'},
"function": {"name": "FakeCall", "arguments": '{"data":"ToolCall1"}'},
},
]
assert messages[2].content == "Output1"
@@ -698,7 +698,8 @@ def test__convert_typed_dict_to_openai_function(
@pytest.mark.parametrize("typed_dict", [ExtensionsTypedDict, TypingTypedDict])
def test__convert_typed_dict_to_openai_function_fail(typed_dict: Type) -> None:
class Tool(typed_dict):
arg1: MutableSet # Pydantic doesn't support
arg1: MutableSet # Pydantic 2 supports this, but pydantic v1 does not.
# Error should be raised since we're using the v1 code path here
with pytest.raises(TypeError):
_convert_typed_dict_to_openai_function(Tool)

View File
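A hedged illustration of the whitespace change in the expected `arguments` strings above: `json.dumps` inserts a space after each colon by default, while pydantic v2's `model_dump_json` emits compact JSON with no separator spaces. The `FakeCall` model here is illustrative:

import json
from pydantic import BaseModel

class FakeCall(BaseModel):
    data: str

print(json.dumps({"data": "ToolCall1"}))             # {"data": "ToolCall1"}
print(FakeCall(data="ToolCall1").model_dump_json())  # {"data":"ToolCall1"}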

@@ -117,7 +117,7 @@ async def test_inmemory_filter() -> None:
# Check sync version
output = store.similarity_search("fee", filter=lambda doc: doc.metadata["id"] == 1)
assert output == [Document(page_content="foo", metadata={"id": 1}, id=AnyStr())]
assert output == [_AnyIdDocument(page_content="foo", metadata={"id": 1})]
# filter with not stored document id
output = await store.asimilarity_search(