mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-13 06:16:26 +00:00
Compare commits
99 Commits
langchain-
...
eugene/mer
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
41b17ee262 | ||
|
|
c240015491 | ||
|
|
658e0f471e | ||
|
|
33445d4bbd | ||
|
|
ac29cdfd81 | ||
|
|
9230ee402c | ||
|
|
4b5d57cd1c | ||
|
|
3d561c3e6d | ||
|
|
09f9d3e972 | ||
|
|
c1e6e7d020 | ||
|
|
6f79443ab5 | ||
|
|
58c4e1ef86 | ||
|
|
de30b04f37 | ||
|
|
f7a455299e | ||
|
|
76043abd47 | ||
|
|
61cdb9ccce | ||
|
|
1ef3fa54fc | ||
|
|
3856e3b02a | ||
|
|
035f09f20d | ||
|
|
63f7a5ab68 | ||
|
|
8447d9f6f1 | ||
|
|
95db9e9258 | ||
|
|
0d1b93774b | ||
|
|
bcce3a2865 | ||
|
|
4a478d82bd | ||
|
|
a0c3657442 | ||
|
|
72c5c28b4d | ||
|
|
fe6f2f724b | ||
|
|
88d347e90c | ||
|
|
741b50d4fd | ||
|
|
24c6825345 | ||
|
|
32824aa55c | ||
|
|
f6924653ea | ||
|
|
66e8594b89 | ||
|
|
3b9f061eac | ||
|
|
76b6ee290d | ||
|
|
22957311fe | ||
|
|
f9df75c8cc | ||
|
|
ece0ab8539 | ||
|
|
4ddd9e5f23 | ||
|
|
f8e95e5735 | ||
|
|
6515b2f77b | ||
|
|
63fde4f095 | ||
|
|
d9bb9125c1 | ||
|
|
384d9f59a3 | ||
|
|
fc0fa7e8f0 | ||
|
|
a1054d06ca | ||
|
|
c2570a7a7c | ||
|
|
97f4128bfd | ||
|
|
2434dc8f92 | ||
|
|
123d61a888 | ||
|
|
53f6f4a0c0 | ||
|
|
550bef230a | ||
|
|
5a998d36b2 | ||
|
|
72cd199efc | ||
|
|
a1d993deb1 | ||
|
|
e546e21d53 | ||
|
|
26d6426156 | ||
|
|
8dffedebd6 | ||
|
|
60adf8d6e4 | ||
|
|
d13a1ad5f5 | ||
|
|
1e5f8a494a | ||
|
|
5216131769 | ||
|
|
8bdaf858b8 | ||
|
|
c37a0ca672 | ||
|
|
266cd15511 | ||
|
|
9debf8144e | ||
|
|
78ce0ed337 | ||
|
|
4aa1932bea | ||
|
|
b658295b97 | ||
|
|
8c59b6a026 | ||
|
|
e35b43a7a7 | ||
|
|
7288d914a8 | ||
|
|
1b487e261a | ||
|
|
3934663db9 | ||
|
|
fb639cb49c | ||
|
|
1856387e9e | ||
|
|
a5ad775a90 | ||
|
|
a321401683 | ||
|
|
8839220a00 | ||
|
|
e6b2ca4da3 | ||
|
|
d0c52d1dec | ||
|
|
a5fa6d1c43 | ||
|
|
7f79bd6e04 | ||
|
|
339985e39e | ||
|
|
f4ecd749d5 | ||
|
|
cb61c6b4bf | ||
|
|
b42c2c6cd6 | ||
|
|
da6633bf0d | ||
|
|
0193d18bec | ||
|
|
0a82192e36 | ||
|
|
202f6fef95 | ||
|
|
c49416e908 | ||
|
|
ec93ea6240 | ||
|
|
add20dc9a8 | ||
|
|
7799474746 | ||
|
|
d98c1f115f | ||
|
|
d97f70def4 | ||
|
|
609c6b0963 |
@@ -39,7 +39,6 @@ lint_tests: PYTHON_FILES=tests
|
||||
lint_tests: MYPY_CACHE=.mypy_cache_test
|
||||
|
||||
lint lint_diff lint_package lint_tests:
|
||||
./scripts/check_pydantic.sh .
|
||||
./scripts/lint_imports.sh
|
||||
poetry run ruff check .
|
||||
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
|
||||
|
||||
@@ -4,10 +4,12 @@ import contextlib
|
||||
import mimetypes
|
||||
from io import BufferedReader, BytesIO
|
||||
from pathlib import PurePath
|
||||
from typing import Any, Generator, List, Literal, Mapping, Optional, Union, cast
|
||||
from typing import Any, Dict, Generator, List, Literal, Optional, Union, cast
|
||||
|
||||
from pydantic import ConfigDict, Field, root_validator
|
||||
|
||||
from langchain_core.load.serializable import Serializable
|
||||
from langchain_core.pydantic_v1 import Field, root_validator
|
||||
from langchain_core.utils.pydantic import v1_repr
|
||||
|
||||
PathLike = Union[str, PurePath]
|
||||
|
||||
@@ -110,9 +112,10 @@ class Blob(BaseMedia):
|
||||
path: Optional[PathLike] = None
|
||||
"""Location where the original content was found."""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
frozen = True
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
frozen=True,
|
||||
)
|
||||
|
||||
@property
|
||||
def source(self) -> Optional[str]:
|
||||
@@ -128,7 +131,7 @@ class Blob(BaseMedia):
|
||||
return str(self.path) if self.path else None
|
||||
|
||||
@root_validator(pre=True)
|
||||
def check_blob_is_valid(cls, values: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||
def check_blob_is_valid(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Verify that either data or path is provided."""
|
||||
if "data" not in values and "path" not in values:
|
||||
raise ValueError("Either data or path must be provided")
|
||||
@@ -293,3 +296,7 @@ class Document(BaseMedia):
|
||||
return f"page_content='{self.page_content}' metadata={self.metadata}"
|
||||
else:
|
||||
return f"page_content='{self.page_content}'"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
# TODO(0.3): Remove this override after confirming unit tests!
|
||||
return v1_repr(self)
|
||||
|
||||
@@ -95,7 +95,11 @@ class BaseLanguageModel(
|
||||
|
||||
Caching is not currently supported for streaming methods of models.
|
||||
"""
|
||||
verbose: bool = Field(default_factory=_get_verbosity)
|
||||
# Repr = False is consistent with pydantic 1 if verbose = False
|
||||
# We can relax this for pydantic 2?
|
||||
# TODO(Team): decide what to do here.
|
||||
# Modified just to get unit tests to pass.
|
||||
verbose: bool = Field(default_factory=_get_verbosity, exclude=True, repr=False)
|
||||
"""Whether to print out response text."""
|
||||
callbacks: Callbacks = Field(default=None, exclude=True)
|
||||
"""Callbacks to add to the run trace."""
|
||||
@@ -108,6 +112,9 @@ class BaseLanguageModel(
|
||||
)
|
||||
"""Optional encoder to use for counting tokens."""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@validator("verbose", pre=True, always=True, allow_reuse=True)
|
||||
def set_verbose(cls, verbose: Optional[bool]) -> bool:
|
||||
"""If verbose is None, set it.
|
||||
|
||||
@@ -9,6 +9,7 @@ from abc import ABC, abstractmethod
|
||||
from operator import itemgetter
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Annotated,
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
@@ -23,6 +24,13 @@ from typing import (
|
||||
cast,
|
||||
)
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
SkipValidation,
|
||||
root_validator,
|
||||
)
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
@@ -55,11 +63,6 @@ from langchain_core.outputs import (
|
||||
RunInfo,
|
||||
)
|
||||
from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
|
||||
from langchain_core.pydantic_v1 import (
|
||||
BaseModel,
|
||||
Field,
|
||||
root_validator,
|
||||
)
|
||||
from langchain_core.rate_limiters import BaseRateLimiter
|
||||
from langchain_core.runnables import RunnableMap, RunnablePassthrough
|
||||
from langchain_core.runnables.config import ensure_config, run_in_executor
|
||||
@@ -208,16 +211,6 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
callback_manager: Optional[BaseCallbackManager] = deprecated(
|
||||
name="callback_manager", since="0.1.7", removal="1.0", alternative="callbacks"
|
||||
)(
|
||||
Field(
|
||||
default=None,
|
||||
exclude=True,
|
||||
description="Callback manager to add to the run trace.",
|
||||
)
|
||||
)
|
||||
|
||||
rate_limiter: Optional[BaseRateLimiter] = Field(default=None, exclude=True)
|
||||
"An optional rate limiter to use for limiting the number of requests."
|
||||
|
||||
@@ -254,8 +247,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
values["callbacks"] = values.pop("callback_manager", None)
|
||||
return values
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
# --- Runnable methods ---
|
||||
|
||||
|
||||
@@ -10,9 +10,10 @@ from typing import (
|
||||
cast,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel, ConfigDict # pydantic: ignore
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.utils.pydantic import v1_repr
|
||||
|
||||
|
||||
class BaseSerialized(TypedDict):
|
||||
@@ -80,7 +81,7 @@ def try_neq_default(value: Any, key: str, model: BaseModel) -> bool:
|
||||
Exception: If the key is not in the model.
|
||||
"""
|
||||
try:
|
||||
return model.__fields__[key].get_default() != value
|
||||
return model.model_fields[key].get_default() != value
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
@@ -161,16 +162,25 @@ class Serializable(BaseModel, ABC):
|
||||
For example, for the class `langchain.llms.openai.OpenAI`, the id is
|
||||
["langchain", "llms", "openai", "OpenAI"].
|
||||
"""
|
||||
return [*cls.get_lc_namespace(), cls.__name__]
|
||||
# Pydantic generics change the class name. So we need to do the following
|
||||
if (
|
||||
"origin" in cls.__pydantic_generic_metadata__
|
||||
and cls.__pydantic_generic_metadata__["origin"] is not None
|
||||
):
|
||||
original_name = cls.__pydantic_generic_metadata__["origin"].__name__
|
||||
else:
|
||||
original_name = cls.__name__
|
||||
return [*cls.get_lc_namespace(), original_name]
|
||||
|
||||
class Config:
|
||||
extra = "ignore"
|
||||
model_config = ConfigDict(
|
||||
extra="ignore",
|
||||
)
|
||||
|
||||
def __repr_args__(self) -> Any:
|
||||
return [
|
||||
(k, v)
|
||||
for k, v in super().__repr_args__()
|
||||
if (k not in self.__fields__ or try_neq_default(v, k, self))
|
||||
if (k not in self.model_fields or try_neq_default(v, k, self))
|
||||
]
|
||||
|
||||
def to_json(self) -> Union[SerializedConstructor, SerializedNotImplemented]:
|
||||
@@ -184,12 +194,15 @@ class Serializable(BaseModel, ABC):
|
||||
|
||||
secrets = dict()
|
||||
# Get latest values for kwargs if there is an attribute with same name
|
||||
lc_kwargs = {
|
||||
k: getattr(self, k, v)
|
||||
for k, v in self
|
||||
if not (self.__exclude_fields__ or {}).get(k, False) # type: ignore
|
||||
and _is_field_useful(self, k, v)
|
||||
}
|
||||
lc_kwargs = {}
|
||||
for k, v in self:
|
||||
if not _is_field_useful(self, k, v):
|
||||
continue
|
||||
# Do nothing if the field is excluded
|
||||
if k in self.model_fields and self.model_fields[k].exclude:
|
||||
continue
|
||||
|
||||
lc_kwargs[k] = getattr(self, k, v)
|
||||
|
||||
# Merge the lc_secrets and lc_attributes from every class in the MRO
|
||||
for cls in [None, *self.__class__.mro()]:
|
||||
@@ -221,8 +234,10 @@ class Serializable(BaseModel, ABC):
|
||||
# that are not present in the fields.
|
||||
for key in list(secrets):
|
||||
value = secrets[key]
|
||||
if key in this.__fields__:
|
||||
secrets[this.__fields__[key].alias] = value
|
||||
if key in this.model_fields:
|
||||
alias = this.model_fields[key].alias
|
||||
if alias is not None:
|
||||
secrets[alias] = value
|
||||
lc_kwargs.update(this.lc_attributes)
|
||||
|
||||
# include all secrets, even if not specified in kwargs
|
||||
@@ -244,6 +259,10 @@ class Serializable(BaseModel, ABC):
|
||||
def to_json_not_implemented(self) -> SerializedNotImplemented:
|
||||
return to_json_not_implemented(self)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
# TODO(0.3): Remove this override after confirming unit tests!
|
||||
return v1_repr(self)
|
||||
|
||||
|
||||
def _is_field_useful(inst: Serializable, key: str, value: Any) -> bool:
|
||||
"""Check if a field is useful as a constructor argument.
|
||||
@@ -259,10 +278,26 @@ def _is_field_useful(inst: Serializable, key: str, value: Any) -> bool:
|
||||
If the field is not required and the value is None, it is useful if the
|
||||
default value is different from the value.
|
||||
"""
|
||||
field = inst.__fields__.get(key)
|
||||
field = inst.model_fields.get(key)
|
||||
if not field:
|
||||
return False
|
||||
return field.required is True or value or field.get_default() != value
|
||||
|
||||
if field.is_required():
|
||||
return True
|
||||
|
||||
if value:
|
||||
return True
|
||||
|
||||
# Value is still falsy here!
|
||||
if field.default_factory is dict and isinstance(value, dict):
|
||||
return False
|
||||
|
||||
# Value is still falsy here!
|
||||
if field.default_factory is list and isinstance(value, list):
|
||||
return False
|
||||
|
||||
# If value is falsy and does not match the default
|
||||
return field.get_default() != value
|
||||
|
||||
|
||||
def _replace_secrets(
|
||||
|
||||
@@ -7,6 +7,7 @@ from langchain_core.pydantic_v1 import Extra, Field
|
||||
from langchain_core.utils import get_bolded_text
|
||||
from langchain_core.utils._merge import merge_dicts, merge_lists
|
||||
from langchain_core.utils.interactive_env import is_interactive_env
|
||||
from langchain_core.utils.pydantic import v1_repr
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.prompts.chat import ChatPromptTemplate
|
||||
@@ -108,6 +109,10 @@ class BaseMessage(Serializable):
|
||||
def pretty_print(self) -> None:
|
||||
print(self.pretty_repr(html=is_interactive_env())) # noqa: T201
|
||||
|
||||
def __repr__(self) -> str:
|
||||
# TODO(0.3): Remove this override after confirming unit tests!
|
||||
return v1_repr(self)
|
||||
|
||||
|
||||
def merge_content(
|
||||
first_content: Union[str, List[Union[str, Dict]]],
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
from langchain_core.messages.base import BaseMessage, BaseMessageChunk, merge_content
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
from langchain_core.utils._merge import merge_dicts, merge_obj
|
||||
|
||||
|
||||
@@ -70,6 +71,11 @@ class ToolMessage(BaseMessage):
|
||||
.. versionadded:: 0.2.24
|
||||
"""
|
||||
|
||||
additional_kwargs: dict = Field(default_factory=dict, repr=False)
|
||||
"""Currently inherited from BaseMessage, but not used."""
|
||||
response_metadata: dict = Field(default_factory=dict, repr=False)
|
||||
"""Currently inherited from BaseMessage, but not used."""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object.
|
||||
|
||||
@@ -13,8 +13,6 @@ from typing import (
|
||||
Union,
|
||||
)
|
||||
|
||||
from typing_extensions import get_args
|
||||
|
||||
from langchain_core.language_models import LanguageModelOutput
|
||||
from langchain_core.messages import AnyMessage, BaseMessage
|
||||
from langchain_core.outputs import ChatGeneration, Generation
|
||||
@@ -166,10 +164,11 @@ class BaseOutputParser(
|
||||
Raises:
|
||||
TypeError: If the class doesn't have an inferable OutputType.
|
||||
"""
|
||||
for cls in self.__class__.__orig_bases__: # type: ignore[attr-defined]
|
||||
type_args = get_args(cls)
|
||||
if type_args and len(type_args) == 1:
|
||||
return type_args[0]
|
||||
for base in self.__class__.mro():
|
||||
if hasattr(base, "__pydantic_generic_metadata__"):
|
||||
metadata = base.__pydantic_generic_metadata__
|
||||
if "args" in metadata and len(metadata["args"]) > 0:
|
||||
return metadata["args"][0]
|
||||
|
||||
raise TypeError(
|
||||
f"Runnable {self.__class__.__name__} doesn't have an inferable OutputType. "
|
||||
|
||||
@@ -2,10 +2,11 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
from json import JSONDecodeError
|
||||
from typing import Any, List, Optional, Type, TypeVar, Union
|
||||
from typing import Annotated, Any, List, Optional, Type, TypeVar, Union
|
||||
|
||||
import jsonpatch # type: ignore[import]
|
||||
import pydantic # pydantic: ignore
|
||||
from pydantic import SkipValidation # pydantic: ignore
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.output_parsers.format_instructions import JSON_FORMAT_INSTRUCTIONS
|
||||
@@ -40,7 +41,7 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
|
||||
describing the difference between the previous and the current object.
|
||||
"""
|
||||
|
||||
pydantic_object: Optional[Type[TBaseModel]] = None # type: ignore
|
||||
pydantic_object: Annotated[Optional[Type[TBaseModel]], SkipValidation()] = None # type: ignore
|
||||
"""The Pydantic object to use for validation.
|
||||
If None, no validation is performed."""
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import re
|
||||
from abc import abstractmethod
|
||||
from collections import deque
|
||||
from typing import AsyncIterator, Deque, Iterator, List, TypeVar, Union
|
||||
from typing import Optional as Optional
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.output_parsers.transform import BaseTransformOutputParser
|
||||
@@ -122,6 +123,9 @@ class ListOutputParser(BaseTransformOutputParser[List[str]]):
|
||||
yield [part]
|
||||
|
||||
|
||||
ListOutputParser.update_forward_refs()
|
||||
|
||||
|
||||
class CommaSeparatedListOutputParser(ListOutputParser):
|
||||
"""Parse the output of an LLM call to a comma-separated list."""
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ import json
|
||||
from typing import Any, Dict, List, Optional, Type, Union
|
||||
|
||||
import jsonpatch # type: ignore[import]
|
||||
from pydantic import BaseModel, root_validator
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.output_parsers import (
|
||||
@@ -11,7 +12,6 @@ from langchain_core.output_parsers import (
|
||||
)
|
||||
from langchain_core.output_parsers.json import parse_partial_json
|
||||
from langchain_core.outputs import ChatGeneration, Generation
|
||||
from langchain_core.pydantic_v1 import BaseModel, root_validator
|
||||
|
||||
|
||||
class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
|
||||
@@ -263,11 +263,17 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
|
||||
"""
|
||||
_result = super().parse_result(result)
|
||||
if self.args_only:
|
||||
pydantic_args = self.pydantic_schema.parse_raw(_result) # type: ignore
|
||||
if hasattr(self.pydantic_schema, "model_validate_json"):
|
||||
pydantic_args = self.pydantic_schema.model_validate_json(_result)
|
||||
else:
|
||||
pydantic_args = self.pydantic_schema.parse_raw(_result) # type: ignore
|
||||
else:
|
||||
fn_name = _result["name"]
|
||||
_args = _result["arguments"]
|
||||
pydantic_args = self.pydantic_schema[fn_name].parse_raw(_args) # type: ignore
|
||||
if hasattr(self.pydantic_schema, "model_validate_json"):
|
||||
pydantic_args = self.pydantic_schema[fn_name].model_validate_json(_args)
|
||||
else:
|
||||
pydantic_args = self.pydantic_schema[fn_name].parse_raw(_args) # type: ignore
|
||||
return pydantic_args
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import copy
|
||||
import json
|
||||
from json import JSONDecodeError
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Annotated, Any, Dict, List, Optional
|
||||
|
||||
from pydantic import SkipValidation, ValidationError # pydantic: ignore
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.messages import AIMessage, InvalidToolCall
|
||||
@@ -13,7 +15,6 @@ from langchain_core.messages.tool import (
|
||||
)
|
||||
from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser
|
||||
from langchain_core.outputs import ChatGeneration, Generation
|
||||
from langchain_core.pydantic_v1 import ValidationError
|
||||
from langchain_core.utils.json import parse_partial_json
|
||||
from langchain_core.utils.pydantic import TypeBaseModel
|
||||
|
||||
@@ -256,7 +257,7 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
|
||||
class PydanticToolsParser(JsonOutputToolsParser):
|
||||
"""Parse tools from OpenAI response."""
|
||||
|
||||
tools: List[TypeBaseModel]
|
||||
tools: Annotated[List[TypeBaseModel], SkipValidation()]
|
||||
"""The tools to parse."""
|
||||
|
||||
# TODO: Support more granular streaming of objects. Currently only streams once all
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import json
|
||||
from typing import Generic, List, Type
|
||||
from typing import Annotated, Generic, List, Type
|
||||
from typing import Optional as Optional
|
||||
|
||||
import pydantic # pydantic: ignore
|
||||
from pydantic import SkipValidation # pydantic: ignore
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.output_parsers import JsonOutputParser
|
||||
@@ -16,7 +18,7 @@ from langchain_core.utils.pydantic import (
|
||||
class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
|
||||
"""Parse an output using a pydantic model."""
|
||||
|
||||
pydantic_object: Type[TBaseModel] # type: ignore
|
||||
pydantic_object: Annotated[Type[TBaseModel], SkipValidation()] # type: ignore
|
||||
"""The pydantic model to parse."""
|
||||
|
||||
def _parse_obj(self, obj: dict) -> TBaseModel:
|
||||
@@ -106,6 +108,9 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
|
||||
return self.pydantic_object
|
||||
|
||||
|
||||
PydanticOutputParser.model_rebuild()
|
||||
|
||||
|
||||
_PYDANTIC_FORMAT_INSTRUCTIONS = """The output should be formatted as a JSON instance that conforms to the JSON schema below.
|
||||
|
||||
As an example, for the schema {{"properties": {{"foo": {{"title": "Foo", "description": "a list of strings", "type": "array", "items": {{"type": "string"}}}}}}, "required": ["foo"]}}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import List
|
||||
from typing import Optional as Optional
|
||||
|
||||
from langchain_core.output_parsers.transform import BaseTransformOutputParser
|
||||
|
||||
@@ -24,3 +25,6 @@ class StrOutputParser(BaseTransformOutputParser[str]):
|
||||
def parse(self, text: str) -> str:
|
||||
"""Returns the input text with no changes."""
|
||||
return text
|
||||
|
||||
|
||||
StrOutputParser.model_rebuild()
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from langchain_core.outputs.generation import Generation
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain_core.outputs.chat_generation import ChatGeneration, ChatGenerationChunk
|
||||
from langchain_core.outputs.generation import Generation, GenerationChunk
|
||||
from langchain_core.outputs.run_info import RunInfo
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
|
||||
|
||||
class LLMResult(BaseModel):
|
||||
@@ -16,7 +18,9 @@ class LLMResult(BaseModel):
|
||||
wants to return.
|
||||
"""
|
||||
|
||||
generations: List[List[Generation]]
|
||||
generations: List[
|
||||
List[Union[Generation, ChatGeneration, GenerationChunk, ChatGenerationChunk]]
|
||||
]
|
||||
"""Generated outputs.
|
||||
|
||||
The first dimension of the list represents completions for different input
|
||||
|
||||
@@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
@@ -21,6 +22,13 @@ from typing import (
|
||||
overload,
|
||||
)
|
||||
|
||||
from pydantic import (
|
||||
Field,
|
||||
PositiveInt,
|
||||
SkipValidation,
|
||||
root_validator,
|
||||
)
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.load import Serializable
|
||||
from langchain_core.messages import (
|
||||
@@ -38,7 +46,6 @@ from langchain_core.prompts.base import BasePromptTemplate
|
||||
from langchain_core.prompts.image import ImagePromptTemplate
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.prompts.string import StringPromptTemplate, get_template_variables
|
||||
from langchain_core.pydantic_v1 import Field, PositiveInt, root_validator
|
||||
from langchain_core.utils import get_colored_text
|
||||
from langchain_core.utils.interactive_env import is_interactive_env
|
||||
|
||||
@@ -207,8 +214,14 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "prompts", "chat"]
|
||||
|
||||
def __init__(self, variable_name: str, *, optional: bool = False, **kwargs: Any):
|
||||
super().__init__(variable_name=variable_name, optional=optional, **kwargs)
|
||||
def __init__(
|
||||
self, variable_name: str, *, optional: bool = False, **kwargs: Any
|
||||
) -> None:
|
||||
# mypy can't detect the init which is defined in the parent class
|
||||
# b/c these are BaseModel classes.
|
||||
super().__init__( # type: ignore
|
||||
variable_name=variable_name, optional=optional, **kwargs
|
||||
)
|
||||
|
||||
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
"""Format messages from kwargs.
|
||||
@@ -922,7 +935,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
messages: List[MessageLike]
|
||||
messages: Annotated[List[MessageLike], SkipValidation]
|
||||
"""List of messages consisting of either message prompt templates or messages."""
|
||||
validate_template: bool = False
|
||||
"""Whether or not to try validating the template."""
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import Any, Dict, List, Tuple
|
||||
from typing import Optional as Optional
|
||||
|
||||
from langchain_core.prompt_values import PromptValue
|
||||
from langchain_core.prompts.base import BasePromptTemplate
|
||||
@@ -106,3 +107,6 @@ class PipelinePromptTemplate(BasePromptTemplate):
|
||||
@property
|
||||
def _prompt_type(self) -> str:
|
||||
raise ValueError
|
||||
|
||||
|
||||
PipelinePromptTemplate.update_forward_refs()
|
||||
|
||||
@@ -35,7 +35,8 @@ from typing import (
|
||||
overload,
|
||||
)
|
||||
|
||||
from typing_extensions import Literal, get_args
|
||||
from pydantic import BaseModel, ConfigDict, Field, RootModel
|
||||
from typing_extensions import Literal, get_args, get_type_hints
|
||||
|
||||
from langchain_core._api import beta_decorator
|
||||
from langchain_core.load.dump import dumpd
|
||||
@@ -44,7 +45,6 @@ from langchain_core.load.serializable import (
|
||||
SerializedConstructor,
|
||||
SerializedNotImplemented,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from langchain_core.runnables.config import (
|
||||
RunnableConfig,
|
||||
_set_config_context,
|
||||
@@ -83,7 +83,6 @@ from langchain_core.runnables.utils import (
|
||||
)
|
||||
from langchain_core.utils.aiter import aclosing, atee, py_anext
|
||||
from langchain_core.utils.iter import safetee
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.callbacks.manager import (
|
||||
@@ -236,25 +235,56 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
For a UI (and much more) checkout LangSmith: https://docs.smith.langchain.com/
|
||||
""" # noqa: E501
|
||||
|
||||
name: Optional[str] = None
|
||||
name: Optional[str]
|
||||
"""The name of the Runnable. Used for debugging and tracing."""
|
||||
|
||||
def get_name(
|
||||
self, suffix: Optional[str] = None, *, name: Optional[str] = None
|
||||
) -> str:
|
||||
"""Get the name of the Runnable."""
|
||||
name = name or self.name or self.__class__.__name__
|
||||
if suffix:
|
||||
if name[0].isupper():
|
||||
return name + suffix.title()
|
||||
else:
|
||||
return name + "_" + suffix.lower()
|
||||
if name:
|
||||
name_ = name
|
||||
elif hasattr(self, "name") and self.name:
|
||||
name_ = self.name
|
||||
else:
|
||||
return name
|
||||
# Here we handle a case where the runnable subclass is also a pydantic
|
||||
# model.
|
||||
cls = self.__class__
|
||||
# Then it's a pydantic sub-class, and we have to check
|
||||
# whether it's a generic, and if so recover the original name.
|
||||
if (
|
||||
hasattr(
|
||||
cls,
|
||||
"__pydantic_generic_metadata__",
|
||||
)
|
||||
and "origin" in cls.__pydantic_generic_metadata__
|
||||
and cls.__pydantic_generic_metadata__["origin"] is not None
|
||||
):
|
||||
name_ = cls.__pydantic_generic_metadata__["origin"].__name__
|
||||
else:
|
||||
name_ = cls.__name__
|
||||
|
||||
if suffix:
|
||||
if name_[0].isupper():
|
||||
return name_ + suffix.title()
|
||||
else:
|
||||
return name_ + "_" + suffix.lower()
|
||||
else:
|
||||
return name_
|
||||
|
||||
@property
|
||||
def InputType(self) -> Type[Input]:
|
||||
"""The type of input this Runnable accepts specified as a type annotation."""
|
||||
# First loop through bases -- this will help generic
|
||||
# any pydantic models.
|
||||
for base in self.__class__.mro():
|
||||
if hasattr(base, "__pydantic_generic_metadata__"):
|
||||
metadata = base.__pydantic_generic_metadata__
|
||||
if "args" in metadata and len(metadata["args"]) == 2:
|
||||
return metadata["args"][0]
|
||||
|
||||
# then loop through __orig_bases__ -- this will Runnables that do not inherit
|
||||
# from pydantic
|
||||
for cls in self.__class__.__orig_bases__: # type: ignore[attr-defined]
|
||||
type_args = get_args(cls)
|
||||
if type_args and len(type_args) == 2:
|
||||
@@ -268,6 +298,14 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
@property
|
||||
def OutputType(self) -> Type[Output]:
|
||||
"""The type of output this Runnable produces specified as a type annotation."""
|
||||
# First loop through bases -- this will help generic
|
||||
# any pydantic models.
|
||||
for base in self.__class__.mro():
|
||||
if hasattr(base, "__pydantic_generic_metadata__"):
|
||||
metadata = base.__pydantic_generic_metadata__
|
||||
if "args" in metadata and len(metadata["args"]) == 2:
|
||||
return metadata["args"][1]
|
||||
|
||||
for cls in self.__class__.__orig_bases__: # type: ignore[attr-defined]
|
||||
type_args = get_args(cls)
|
||||
if type_args and len(type_args) == 2:
|
||||
@@ -302,12 +340,12 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
"""
|
||||
root_type = self.InputType
|
||||
|
||||
if inspect.isclass(root_type) and is_basemodel_subclass(root_type):
|
||||
if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
|
||||
return root_type
|
||||
|
||||
return create_model(
|
||||
self.get_name("Input"),
|
||||
__root__=(root_type, None),
|
||||
__root__=root_type,
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -334,12 +372,12 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
"""
|
||||
root_type = self.OutputType
|
||||
|
||||
if inspect.isclass(root_type) and is_basemodel_subclass(root_type):
|
||||
if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
|
||||
return root_type
|
||||
|
||||
return create_model(
|
||||
self.get_name("Output"),
|
||||
__root__=(root_type, None),
|
||||
__root__=root_type,
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -381,15 +419,19 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
else None
|
||||
)
|
||||
|
||||
return create_model( # type: ignore[call-overload]
|
||||
self.get_name("Config"),
|
||||
# Many need to create a typed dict instead to implement NotRequired!
|
||||
all_fields = {
|
||||
**({"configurable": (configurable, None)} if configurable else {}),
|
||||
**{
|
||||
field_name: (field_type, None)
|
||||
for field_name, field_type in RunnableConfig.__annotations__.items()
|
||||
for field_name, field_type in get_type_hints(RunnableConfig).items()
|
||||
if field_name in [i for i in include if i != "configurable"]
|
||||
},
|
||||
}
|
||||
model = create_model( # type: ignore[call-overload]
|
||||
self.get_name("Config"), **all_fields
|
||||
)
|
||||
return model
|
||||
|
||||
def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph:
|
||||
"""Return a graph representation of this Runnable."""
|
||||
@@ -579,7 +621,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
"""
|
||||
from langchain_core.runnables.passthrough import RunnableAssign
|
||||
|
||||
return self | RunnableAssign(RunnableParallel(kwargs))
|
||||
return self | RunnableAssign(RunnableParallel[Dict[str, Any]](kwargs))
|
||||
|
||||
""" --- Public API --- """
|
||||
|
||||
@@ -2129,7 +2171,6 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
name=config.get("run_name") or self.get_name(),
|
||||
run_id=config.pop("run_id", None),
|
||||
)
|
||||
iterator_ = None
|
||||
try:
|
||||
child_config = patch_config(config, callbacks=run_manager.get_child())
|
||||
if accepts_config(transformer):
|
||||
@@ -2314,7 +2355,6 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]):
|
||||
"""Runnable that can be serialized to JSON."""
|
||||
|
||||
name: Optional[str] = None
|
||||
"""The name of the Runnable. Used for debugging and tracing."""
|
||||
|
||||
def to_json(self) -> Union[SerializedConstructor, SerializedNotImplemented]:
|
||||
"""Serialize the Runnable to JSON.
|
||||
@@ -2369,10 +2409,10 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]):
|
||||
from langchain_core.runnables.configurable import RunnableConfigurableFields
|
||||
|
||||
for key in kwargs:
|
||||
if key not in self.__fields__:
|
||||
if key not in self.model_fields:
|
||||
raise ValueError(
|
||||
f"Configuration key {key} not found in {self}: "
|
||||
f"available keys are {self.__fields__.keys()}"
|
||||
f"available keys are {self.model_fields.keys()}"
|
||||
)
|
||||
|
||||
return RunnableConfigurableFields(default=self, fields=kwargs)
|
||||
@@ -2447,13 +2487,13 @@ def _seq_input_schema(
|
||||
return first.get_input_schema(config)
|
||||
elif isinstance(first, RunnableAssign):
|
||||
next_input_schema = _seq_input_schema(steps[1:], config)
|
||||
if not next_input_schema.__custom_root_type__:
|
||||
if not issubclass(next_input_schema, RootModel):
|
||||
# it's a dict as expected
|
||||
return create_model( # type: ignore[call-overload]
|
||||
"RunnableSequenceInput",
|
||||
**{
|
||||
k: (v.annotation, v.default)
|
||||
for k, v in next_input_schema.__fields__.items()
|
||||
for k, v in next_input_schema.model_fields.items()
|
||||
if k not in first.mapper.steps__
|
||||
},
|
||||
)
|
||||
@@ -2474,36 +2514,36 @@ def _seq_output_schema(
|
||||
elif isinstance(last, RunnableAssign):
|
||||
mapper_output_schema = last.mapper.get_output_schema(config)
|
||||
prev_output_schema = _seq_output_schema(steps[:-1], config)
|
||||
if not prev_output_schema.__custom_root_type__:
|
||||
if not issubclass(prev_output_schema, RootModel):
|
||||
# it's a dict as expected
|
||||
return create_model( # type: ignore[call-overload]
|
||||
"RunnableSequenceOutput",
|
||||
**{
|
||||
**{
|
||||
k: (v.annotation, v.default)
|
||||
for k, v in prev_output_schema.__fields__.items()
|
||||
for k, v in prev_output_schema.model_fields.items()
|
||||
},
|
||||
**{
|
||||
k: (v.annotation, v.default)
|
||||
for k, v in mapper_output_schema.__fields__.items()
|
||||
for k, v in mapper_output_schema.model_fields.items()
|
||||
},
|
||||
},
|
||||
)
|
||||
elif isinstance(last, RunnablePick):
|
||||
prev_output_schema = _seq_output_schema(steps[:-1], config)
|
||||
if not prev_output_schema.__custom_root_type__:
|
||||
if not issubclass(prev_output_schema, RootModel):
|
||||
# it's a dict as expected
|
||||
if isinstance(last.keys, list):
|
||||
return create_model( # type: ignore[call-overload]
|
||||
"RunnableSequenceOutput",
|
||||
**{
|
||||
k: (v.annotation, v.default)
|
||||
for k, v in prev_output_schema.__fields__.items()
|
||||
for k, v in prev_output_schema.model_fields.items()
|
||||
if k in last.keys
|
||||
},
|
||||
)
|
||||
else:
|
||||
field = prev_output_schema.__fields__[last.keys]
|
||||
field = prev_output_schema.model_fields[last.keys]
|
||||
return create_model( # type: ignore[call-overload]
|
||||
"RunnableSequenceOutput",
|
||||
__root__=(field.annotation, field.default),
|
||||
@@ -2665,8 +2705,9 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
"""
|
||||
return True
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
@property
|
||||
def InputType(self) -> Type[Input]:
|
||||
@@ -3403,8 +3444,9 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "runnable"]
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
def get_name(
|
||||
self, suffix: Optional[str] = None, *, name: Optional[str] = None
|
||||
@@ -3451,7 +3493,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
|
||||
**{
|
||||
k: (v.annotation, v.default)
|
||||
for step in self.steps__.values()
|
||||
for k, v in step.get_input_schema(config).__fields__.items()
|
||||
for k, v in step.get_input_schema(config).model_fields.items()
|
||||
if k != "__root__"
|
||||
},
|
||||
)
|
||||
@@ -3469,11 +3511,8 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
|
||||
Returns:
|
||||
The output schema of the Runnable.
|
||||
"""
|
||||
# This is correct, but pydantic typings/mypy don't think so.
|
||||
return create_model( # type: ignore[call-overload]
|
||||
self.get_name("Output"),
|
||||
**{k: (v.OutputType, None) for k, v in self.steps__.items()},
|
||||
)
|
||||
fields = {k: (v.OutputType, ...) for k, v in self.steps__.items()}
|
||||
return create_model(self.get_name("Output"), **fields)
|
||||
|
||||
@property
|
||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||
@@ -3883,6 +3922,8 @@ class RunnableGenerator(Runnable[Input, Output]):
|
||||
atransform: Optional[
|
||||
Callable[[AsyncIterator[Input]], AsyncIterator[Output]]
|
||||
] = None,
|
||||
*,
|
||||
name: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Initialize a RunnableGenerator.
|
||||
|
||||
@@ -3910,9 +3951,9 @@ class RunnableGenerator(Runnable[Input, Output]):
|
||||
)
|
||||
|
||||
try:
|
||||
self.name = func_for_name.__name__
|
||||
self.name = name or func_for_name.__name__
|
||||
except AttributeError:
|
||||
pass
|
||||
self.name = "RunnableGenerator"
|
||||
|
||||
@property
|
||||
def InputType(self) -> Any:
|
||||
@@ -4184,15 +4225,13 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
if all(
|
||||
item[0] == "'" and item[-1] == "'" and len(item) > 2 for item in items
|
||||
):
|
||||
fields = {item[1:-1]: (Any, ...) for item in items}
|
||||
# It's a dict, lol
|
||||
return create_model(
|
||||
self.get_name("Input"),
|
||||
**{item[1:-1]: (Any, None) for item in items}, # type: ignore
|
||||
)
|
||||
return create_model(self.get_name("Input"), **fields)
|
||||
else:
|
||||
return create_model(
|
||||
self.get_name("Input"),
|
||||
__root__=(List[Any], None),
|
||||
__root__=List[Any],
|
||||
)
|
||||
|
||||
if self.InputType != Any:
|
||||
@@ -4201,7 +4240,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
if dict_keys := get_function_first_arg_dict_keys(func):
|
||||
return create_model(
|
||||
self.get_name("Input"),
|
||||
**{key: (Any, None) for key in dict_keys}, # type: ignore
|
||||
**{key: (Any, ...) for key in dict_keys}, # type: ignore
|
||||
)
|
||||
|
||||
return super().get_input_schema(config)
|
||||
@@ -4730,8 +4769,9 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
|
||||
|
||||
bound: Runnable[Input, Output]
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
@property
|
||||
def InputType(self) -> Any:
|
||||
@@ -4758,10 +4798,7 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
|
||||
schema = self.bound.get_output_schema(config)
|
||||
return create_model(
|
||||
self.get_name("Output"),
|
||||
__root__=(
|
||||
List[schema], # type: ignore
|
||||
None,
|
||||
),
|
||||
__root__=List[schema], # type: ignore[valid-type]
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -4981,8 +5018,9 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
|
||||
The type can be a pydantic model, or a type annotation (e.g., `List[str]`).
|
||||
"""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -5318,7 +5356,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
|
||||
yield item
|
||||
|
||||
|
||||
RunnableBindingBase.update_forward_refs(RunnableConfig=RunnableConfig)
|
||||
RunnableBindingBase.model_rebuild()
|
||||
|
||||
|
||||
class RunnableBinding(RunnableBindingBase[Input, Output]):
|
||||
|
||||
@@ -134,7 +134,17 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
|
||||
runnable = coerce_to_runnable(runnable)
|
||||
_branches.append((condition, runnable))
|
||||
|
||||
super().__init__(branches=_branches, default=default_) # type: ignore[call-arg]
|
||||
super().__init__(
|
||||
branches=_branches,
|
||||
default=default_,
|
||||
# Hard-coding a name here because RunnableBranch is a generic
|
||||
# and with pydantic 2, the class name with pydantic will capture
|
||||
# include the parameterized type, which is not what we want.
|
||||
# e.g., we'd get RunnableBranch[Input, Output] instead of RunnableBranch
|
||||
# for the name. This information is already captured in the
|
||||
# input and output types.
|
||||
name="RunnableBranch",
|
||||
) # type: ignore[call-arg]
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@@ -20,7 +20,8 @@ from typing import (
|
||||
)
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from langchain_core.runnables.base import Runnable, RunnableSerializable
|
||||
from langchain_core.runnables.config import (
|
||||
RunnableConfig,
|
||||
@@ -58,8 +59,9 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
|
||||
|
||||
config: Optional[RunnableConfig] = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
@@ -373,28 +375,33 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
|
||||
Returns:
|
||||
List[ConfigurableFieldSpec]: The configuration specs.
|
||||
"""
|
||||
return get_unique_config_specs(
|
||||
[
|
||||
(
|
||||
# TODO(0.3): This change removes field_info which isn't needed in pydantic 2
|
||||
config_specs = []
|
||||
|
||||
for field_name, spec in self.fields.items():
|
||||
if isinstance(spec, ConfigurableField):
|
||||
config_specs.append(
|
||||
ConfigurableFieldSpec(
|
||||
id=spec.id,
|
||||
name=spec.name,
|
||||
description=spec.description
|
||||
or self.default.__fields__[field_name].field_info.description,
|
||||
or self.default.model_fields[field_name].description,
|
||||
annotation=spec.annotation
|
||||
or self.default.__fields__[field_name].annotation,
|
||||
or self.default.model_fields[field_name].annotation,
|
||||
default=getattr(self.default, field_name),
|
||||
is_shared=spec.is_shared,
|
||||
)
|
||||
if isinstance(spec, ConfigurableField)
|
||||
else make_options_spec(
|
||||
spec, self.default.__fields__[field_name].field_info.description
|
||||
)
|
||||
else:
|
||||
config_specs.append(
|
||||
make_options_spec(
|
||||
spec, self.default.model_fields[field_name].description
|
||||
)
|
||||
)
|
||||
for field_name, spec in self.fields.items()
|
||||
]
|
||||
+ list(self.default.config_specs)
|
||||
)
|
||||
|
||||
config_specs.extend(self.default.config_specs)
|
||||
|
||||
return get_unique_config_specs(config_specs)
|
||||
|
||||
def configurable_fields(
|
||||
self, **kwargs: AnyConfigurableField
|
||||
@@ -436,7 +443,7 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
|
||||
init_params = {
|
||||
k: v
|
||||
for k, v in self.default.__dict__.items()
|
||||
if k in self.default.__fields__
|
||||
if k in self.default.model_fields
|
||||
}
|
||||
return (
|
||||
self.default.__class__(**{**init_params, **configurable}),
|
||||
|
||||
@@ -23,7 +23,7 @@ from typing import (
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
from langchain_core.utils.pydantic import _IgnoreUnserializable, is_basemodel_subclass
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.runnables.base import Runnable as RunnableType
|
||||
@@ -235,7 +235,9 @@ def node_data_json(
|
||||
json = (
|
||||
{
|
||||
"type": "schema",
|
||||
"data": node.data.schema(),
|
||||
"data": node.data.model_json_schema(
|
||||
schema_generator=_IgnoreUnserializable
|
||||
),
|
||||
}
|
||||
if with_schemas
|
||||
else {
|
||||
|
||||
@@ -372,28 +372,25 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
||||
def get_input_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
super_schema = super().get_input_schema(config)
|
||||
if super_schema.__custom_root_type__ or not super_schema.schema().get(
|
||||
"properties"
|
||||
):
|
||||
from langchain_core.messages import BaseMessage
|
||||
# TODO(0.3): Verify that this change was correct
|
||||
# Not enough tests and unclear on why the previous implementation was
|
||||
# necessary.
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
fields: Dict = {}
|
||||
if self.input_messages_key and self.history_messages_key:
|
||||
fields[self.input_messages_key] = (
|
||||
Union[str, BaseMessage, Sequence[BaseMessage]],
|
||||
...,
|
||||
)
|
||||
elif self.input_messages_key:
|
||||
fields[self.input_messages_key] = (Sequence[BaseMessage], ...)
|
||||
else:
|
||||
fields["__root__"] = (Sequence[BaseMessage], ...)
|
||||
return create_model( # type: ignore[call-overload]
|
||||
"RunnableWithChatHistoryInput",
|
||||
**fields,
|
||||
fields: Dict = {}
|
||||
if self.input_messages_key and self.history_messages_key:
|
||||
fields[self.input_messages_key] = (
|
||||
Union[str, BaseMessage, Sequence[BaseMessage]],
|
||||
...,
|
||||
)
|
||||
elif self.input_messages_key:
|
||||
fields[self.input_messages_key] = (Sequence[BaseMessage], ...)
|
||||
else:
|
||||
return super_schema
|
||||
fields["__root__"] = (Sequence[BaseMessage], ...)
|
||||
return create_model( # type: ignore[call-overload]
|
||||
"RunnableWithChatHistoryInput",
|
||||
**fields,
|
||||
)
|
||||
|
||||
def _is_not_async(self, *args: Sequence[Any], **kwargs: Dict[str, Any]) -> bool:
|
||||
return False
|
||||
|
||||
@@ -21,7 +21,8 @@ from typing import (
|
||||
cast,
|
||||
)
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from pydantic import BaseModel, RootModel
|
||||
|
||||
from langchain_core.runnables.base import (
|
||||
Other,
|
||||
Runnable,
|
||||
@@ -227,7 +228,7 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
|
||||
A Runnable that merges the Dict input with the output produced by the
|
||||
mapping argument.
|
||||
"""
|
||||
return RunnableAssign(RunnableParallel(kwargs))
|
||||
return RunnableAssign(RunnableParallel[Dict[str, Any]](kwargs))
|
||||
|
||||
def invoke(
|
||||
self, input: Other, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
@@ -419,7 +420,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
map_input_schema = self.mapper.get_input_schema(config)
|
||||
if not map_input_schema.__custom_root_type__:
|
||||
if not issubclass(map_input_schema, RootModel):
|
||||
# ie. it's a dict
|
||||
return map_input_schema
|
||||
|
||||
@@ -430,20 +431,22 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||
) -> Type[BaseModel]:
|
||||
map_input_schema = self.mapper.get_input_schema(config)
|
||||
map_output_schema = self.mapper.get_output_schema(config)
|
||||
if (
|
||||
not map_input_schema.__custom_root_type__
|
||||
and not map_output_schema.__custom_root_type__
|
||||
if not issubclass(map_input_schema, RootModel) and not issubclass(
|
||||
map_output_schema, RootModel
|
||||
):
|
||||
# ie. both are dicts
|
||||
fields = {}
|
||||
|
||||
for name, field_info in map_input_schema.model_fields.items():
|
||||
fields[name] = (field_info.annotation, field_info.default)
|
||||
|
||||
for name, field_info in map_output_schema.model_fields.items():
|
||||
fields[name] = (field_info.annotation, field_info.default)
|
||||
|
||||
return create_model( # type: ignore[call-overload]
|
||||
"RunnableAssignOutput",
|
||||
**{
|
||||
k: (v.type_, v.default)
|
||||
for s in (map_input_schema, map_output_schema)
|
||||
for k, v in s.__fields__.items()
|
||||
},
|
||||
**fields,
|
||||
)
|
||||
elif not map_output_schema.__custom_root_type__:
|
||||
elif not issubclass(map_output_schema, RootModel):
|
||||
# ie. only map output is a dict
|
||||
# ie. input type is either unknown or inferred incorrectly
|
||||
return map_output_schema
|
||||
|
||||
@@ -28,12 +28,18 @@ from typing import (
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, RootModel # pydantic: ignore
|
||||
from pydantic import create_model as _create_model_base # pydantic :ignore
|
||||
from pydantic.json_schema import (
|
||||
DEFAULT_REF_TEMPLATE,
|
||||
GenerateJsonSchema,
|
||||
JsonSchemaMode,
|
||||
)
|
||||
from typing_extensions import TypeGuard
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseConfig, BaseModel
|
||||
from langchain_core.pydantic_v1 import create_model as _create_model_base
|
||||
from langchain_core.runnables.schema import StreamEvent
|
||||
|
||||
Input = TypeVar("Input", contravariant=True)
|
||||
@@ -350,7 +356,7 @@ def get_function_first_arg_dict_keys(func: Callable) -> Optional[List[str]]:
|
||||
tree = ast.parse(textwrap.dedent(code))
|
||||
visitor = IsFunctionArgDict()
|
||||
visitor.visit(tree)
|
||||
return list(visitor.keys) if visitor.keys else None
|
||||
return sorted(visitor.keys) if visitor.keys else None
|
||||
except (SyntaxError, TypeError, OSError, SystemError):
|
||||
return None
|
||||
|
||||
@@ -697,9 +703,57 @@ class _RootEventFilter:
|
||||
return include
|
||||
|
||||
|
||||
class _SchemaConfig(BaseConfig):
|
||||
arbitrary_types_allowed = True
|
||||
frozen = True
|
||||
_SchemaConfig = ConfigDict(arbitrary_types_allowed=True, frozen=True)
|
||||
|
||||
NO_DEFAULT = object()
|
||||
|
||||
|
||||
def create_base_class(
|
||||
name: str, type_: Any, default_: object = NO_DEFAULT
|
||||
) -> Type[BaseModel]:
|
||||
"""Create a base class."""
|
||||
|
||||
def schema(
|
||||
cls: Type[BaseModel],
|
||||
by_alias: bool = True,
|
||||
ref_template: str = DEFAULT_REF_TEMPLATE,
|
||||
) -> Dict[str, Any]:
|
||||
# Complains about schema not being defined in superclass
|
||||
schema_ = super(cls, cls).schema( # type: ignore[misc]
|
||||
by_alias=by_alias, ref_template=ref_template
|
||||
)
|
||||
schema_["title"] = name
|
||||
return schema_
|
||||
|
||||
def model_json_schema(
|
||||
cls: Type[BaseModel],
|
||||
by_alias: bool = True,
|
||||
ref_template: str = DEFAULT_REF_TEMPLATE,
|
||||
schema_generator: type[GenerateJsonSchema] = GenerateJsonSchema,
|
||||
mode: JsonSchemaMode = "validation",
|
||||
) -> Dict[str, Any]:
|
||||
# Complains about model_json_schema not being defined in superclass
|
||||
schema_ = super(cls, cls).model_json_schema( # type: ignore[misc]
|
||||
by_alias=by_alias,
|
||||
ref_template=ref_template,
|
||||
schema_generator=schema_generator,
|
||||
mode=mode,
|
||||
)
|
||||
schema_["title"] = name
|
||||
return schema_
|
||||
|
||||
base_class_attributes = {
|
||||
"__annotations__": {"root": type_},
|
||||
"model_config": ConfigDict(arbitrary_types_allowed=True),
|
||||
"schema": classmethod(schema),
|
||||
"model_json_schema": classmethod(model_json_schema),
|
||||
"__module__": "langchain_core.runnables.utils",
|
||||
}
|
||||
|
||||
if default_ is not NO_DEFAULT:
|
||||
base_class_attributes["root"] = default_
|
||||
custom_root_type = type(name, (RootModel,), base_class_attributes)
|
||||
return cast(Type[BaseModel], custom_root_type)
|
||||
|
||||
|
||||
def create_model(
|
||||
@@ -715,6 +769,21 @@ def create_model(
|
||||
Returns:
|
||||
Type[BaseModel]: The created model.
|
||||
"""
|
||||
|
||||
# Move this to caching path
|
||||
if "__root__" in field_definitions:
|
||||
if len(field_definitions) > 1:
|
||||
raise NotImplementedError(
|
||||
"When specifying __root__ no other "
|
||||
f"fields should be provided. Got {field_definitions}"
|
||||
)
|
||||
|
||||
arg = field_definitions["__root__"]
|
||||
if isinstance(arg, tuple):
|
||||
named_root_model = create_base_class(__model_name, arg[0], arg[1])
|
||||
else:
|
||||
named_root_model = create_base_class(__model_name, arg)
|
||||
return named_root_model
|
||||
try:
|
||||
return _create_model_cached(__model_name, **field_definitions)
|
||||
except TypeError:
|
||||
|
||||
@@ -6,7 +6,7 @@ from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Any, List, Optional, Sequence, Union
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Visitor(ABC):
|
||||
@@ -127,7 +127,8 @@ class Comparison(FilterDirective):
|
||||
def __init__(
|
||||
self, comparator: Comparator, attribute: str, value: Any, **kwargs: Any
|
||||
) -> None:
|
||||
super().__init__(
|
||||
# super exists from BaseModel
|
||||
super().__init__( # type: ignore[call-arg]
|
||||
comparator=comparator, attribute=attribute, value=value, **kwargs
|
||||
)
|
||||
|
||||
@@ -145,8 +146,11 @@ class Operation(FilterDirective):
|
||||
|
||||
def __init__(
|
||||
self, operator: Operator, arguments: List[FilterDirective], **kwargs: Any
|
||||
):
|
||||
super().__init__(operator=operator, arguments=arguments, **kwargs)
|
||||
) -> None:
|
||||
# super exists from BaseModel
|
||||
super().__init__( # type: ignore[call-arg]
|
||||
operator=operator, arguments=arguments, **kwargs
|
||||
)
|
||||
|
||||
|
||||
class StructuredQuery(Expr):
|
||||
@@ -165,5 +169,8 @@ class StructuredQuery(Expr):
|
||||
filter: Optional[FilterDirective],
|
||||
limit: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
super().__init__(query=query, filter=filter, limit=limit, **kwargs)
|
||||
) -> None:
|
||||
# super exists from BaseModel
|
||||
super().__init__( # type: ignore[call-arg]
|
||||
query=query, filter=filter, limit=limit, **kwargs
|
||||
)
|
||||
|
||||
@@ -24,7 +24,9 @@ from typing import (
|
||||
get_type_hints,
|
||||
)
|
||||
|
||||
from typing_extensions import Annotated, TypeVar, get_args, get_origin
|
||||
==== BASE ====
|
||||
from typing_extensions import Annotated, TypeVar, cast, get_args, get_origin
|
||||
==== BASE ====
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.callbacks import (
|
||||
@@ -33,16 +35,26 @@ from langchain_core.callbacks import (
|
||||
CallbackManager,
|
||||
Callbacks,
|
||||
)
|
||||
from langchain_core.load import Serializable
|
||||
from langchain_core.messages import ToolCall, ToolMessage
|
||||
==== BASE ====
|
||||
from langchain_core.load.serializable import Serializable
|
||||
from langchain_core.messages.tool import ToolCall, ToolMessage
|
||||
from langchain_core.prompts import (
|
||||
BasePromptTemplate,
|
||||
PromptTemplate,
|
||||
aformat_document,
|
||||
format_document,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import (
|
||||
BaseModel,
|
||||
Extra,
|
||||
Field,
|
||||
ValidationError,
|
||||
create_model,
|
||||
root_validator,
|
||||
validate_arguments,
|
||||
)
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
==== BASE ====
|
||||
from langchain_core.runnables import (
|
||||
RunnableConfig,
|
||||
RunnableSerializable,
|
||||
@@ -59,6 +71,7 @@ from langchain_core.utils.function_calling import (
|
||||
from langchain_core.utils.pydantic import (
|
||||
TypeBaseModel,
|
||||
_create_subset_model,
|
||||
get_fields,
|
||||
is_basemodel_subclass,
|
||||
is_pydantic_v1_subclass,
|
||||
is_pydantic_v2_subclass,
|
||||
@@ -204,20 +217,64 @@ def create_schema_from_function(
|
||||
"""
|
||||
# https://docs.pydantic.dev/latest/usage/validation_decorator/
|
||||
validated = validate_arguments(func, config=_SchemaConfig) # type: ignore
|
||||
|
||||
sig = inspect.signature(func)
|
||||
|
||||
# Let's ignore `self` and `cls` arguments for class and instance methods
|
||||
if func.__qualname__ and "." in func.__qualname__:
|
||||
# Then it likely belongs in a class namespace
|
||||
in_class = True
|
||||
else:
|
||||
in_class = False
|
||||
|
||||
has_args = False
|
||||
has_kwargs = False
|
||||
|
||||
for param in sig.parameters.values():
|
||||
if param.kind == param.VAR_POSITIONAL:
|
||||
has_args = True
|
||||
elif param.kind == param.VAR_KEYWORD:
|
||||
has_kwargs = True
|
||||
|
||||
inferred_model = validated.model # type: ignore
|
||||
filter_args = filter_args if filter_args is not None else FILTERED_ARGS
|
||||
for arg in filter_args:
|
||||
if arg in inferred_model.__fields__:
|
||||
del inferred_model.__fields__[arg]
|
||||
|
||||
if filter_args:
|
||||
filter_args_ = filter_args
|
||||
else:
|
||||
# Handle classmethods and instance methods
|
||||
existing_params: List[str] = list(sig.parameters.keys())
|
||||
if existing_params and existing_params[0] in ("self", "cls") and in_class:
|
||||
filter_args_ = [existing_params[0]] + list(FILTERED_ARGS)
|
||||
else:
|
||||
filter_args_ = list(FILTERED_ARGS)
|
||||
|
||||
for existing_param in existing_params:
|
||||
if not include_injected and _is_injected_arg_type(
|
||||
sig.parameters[existing_param].annotation
|
||||
):
|
||||
filter_args_.append(existing_param)
|
||||
|
||||
description, arg_descriptions = _infer_arg_descriptions(
|
||||
func,
|
||||
parse_docstring=parse_docstring,
|
||||
error_on_invalid_docstring=error_on_invalid_docstring,
|
||||
)
|
||||
# Pydantic adds placeholder virtual fields we need to strip
|
||||
valid_properties = _get_filtered_args(
|
||||
inferred_model, func, filter_args=filter_args, include_injected=include_injected
|
||||
)
|
||||
valid_properties = []
|
||||
for field in get_fields(inferred_model):
|
||||
if not has_args:
|
||||
if field == "args":
|
||||
continue
|
||||
if not has_kwargs:
|
||||
if field == "kwargs":
|
||||
continue
|
||||
|
||||
if field == "v__duplicate_kwargs": # Internal pydantic field
|
||||
continue
|
||||
|
||||
if field not in filter_args_:
|
||||
valid_properties.append(field)
|
||||
|
||||
return _create_subset_model(
|
||||
f"{model_name}Schema",
|
||||
inferred_model,
|
||||
@@ -274,7 +331,10 @@ class ChildTool(BaseTool):
|
||||
|
||||
You can provide few-shot examples as a part of the description.
|
||||
"""
|
||||
args_schema: Optional[TypeBaseModel] = None
|
||||
|
||||
args_schema: Annotated[Optional[TypeBaseModel], SkipValidation()] = Field(
|
||||
default=None, description="The tool schema."
|
||||
)
|
||||
"""Pydantic model class to validate and parse the tool's input arguments.
|
||||
|
||||
Args schema should be either:
|
||||
@@ -416,7 +476,7 @@ class ChildTool(BaseTool):
|
||||
input_args = self.args_schema
|
||||
if isinstance(tool_input, str):
|
||||
if input_args is not None:
|
||||
key_ = next(iter(input_args.__fields__.keys()))
|
||||
key_ = next(iter(get_fields(input_args).keys()))
|
||||
input_args.validate({key_: tool_input})
|
||||
return tool_input
|
||||
else:
|
||||
|
||||
@@ -2,14 +2,11 @@ from __future__ import annotations
|
||||
|
||||
import textwrap
|
||||
from inspect import signature
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Type, Union
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForToolRun,
|
||||
CallbackManagerForToolRun,
|
||||
)
|
||||
from langchain_core.messages import ToolCall
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from langchain_core.runnables import RunnableConfig, run_in_executor
|
||||
from langchain_core.tools.base import (
|
||||
FILTERED_ARGS,
|
||||
@@ -18,13 +15,28 @@ from langchain_core.tools.base import (
|
||||
create_schema_from_function,
|
||||
)
|
||||
from langchain_core.utils.pydantic import TypeBaseModel
|
||||
from pydantic import BaseModel, Field, SkipValidation
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Type,
|
||||
Union,
|
||||
Annotated,
|
||||
)
|
||||
|
||||
|
||||
class StructuredTool(BaseTool):
|
||||
"""Tool that can operate on any number of inputs."""
|
||||
|
||||
description: str = ""
|
||||
args_schema: TypeBaseModel = Field(..., description="The tool schema.")
|
||||
args_schema: Annotated[TypeBaseModel, SkipValidation()] = Field(
|
||||
..., description="The tool schema."
|
||||
)
|
||||
"""The input arguments' schema."""
|
||||
func: Optional[Callable[..., Any]]
|
||||
"""The function to run when the tool is called."""
|
||||
|
||||
@@ -11,7 +11,6 @@ from langsmith.schemas import RunBase as BaseRunV2
|
||||
from langsmith.schemas import RunTypeEnum as RunTypeEnumDep
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.outputs import LLMResult
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
|
||||
|
||||
|
||||
@@ -82,7 +81,8 @@ class LLMRun(BaseRun):
|
||||
"""Class for LLMRun."""
|
||||
|
||||
prompts: List[str]
|
||||
response: Optional[LLMResult] = None
|
||||
# Temporarily, remove but we will completely remove LLMRun
|
||||
# response: Optional[LLMResult] = None
|
||||
|
||||
|
||||
@deprecated("0.1.0", alternative="Run", removal="1.0")
|
||||
|
||||
@@ -22,11 +22,11 @@ from typing import (
|
||||
cast,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Annotated, TypedDict, get_args, get_origin, is_typeddict
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, create_model
|
||||
from langchain_core.utils.json_schema import dereference_refs
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
|
||||
@@ -84,7 +84,7 @@ def _rm_titles(kv: dict, prev_key: str = "") -> dict:
|
||||
removal="1.0",
|
||||
)
|
||||
def convert_pydantic_to_openai_function(
|
||||
model: Type[BaseModel],
|
||||
model: Type,
|
||||
*,
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
@@ -192,11 +192,13 @@ def convert_python_function_to_openai_function(
|
||||
|
||||
def _convert_typed_dict_to_openai_function(typed_dict: Type) -> FunctionDescription:
|
||||
visited: Dict = {}
|
||||
from pydantic.v1 import BaseModel # pydantic: ignore
|
||||
|
||||
model = cast(
|
||||
Type[BaseModel],
|
||||
_convert_any_typed_dicts_to_pydantic(typed_dict, visited=visited),
|
||||
)
|
||||
return convert_pydantic_to_openai_function(model)
|
||||
return convert_pydantic_to_openai_function(model) # type: ignore
|
||||
|
||||
|
||||
_MAX_TYPED_DICT_RECURSION = 25
|
||||
@@ -208,6 +210,9 @@ def _convert_any_typed_dicts_to_pydantic(
|
||||
visited: Dict,
|
||||
depth: int = 0,
|
||||
) -> Type:
|
||||
from pydantic.v1 import Field as Field_v1 # pydantic: ignore
|
||||
from pydantic.v1 import create_model as create_model_v1 # pydantic: ignore
|
||||
|
||||
if type_ in visited:
|
||||
return visited[type_]
|
||||
elif depth >= _MAX_TYPED_DICT_RECURSION:
|
||||
@@ -241,7 +246,7 @@ def _convert_any_typed_dicts_to_pydantic(
|
||||
field_kwargs["description"] = arg_desc
|
||||
else:
|
||||
pass
|
||||
fields[arg] = (new_arg_type, Field(**field_kwargs))
|
||||
fields[arg] = (new_arg_type, Field_v1(**field_kwargs))
|
||||
else:
|
||||
new_arg_type = _convert_any_typed_dicts_to_pydantic(
|
||||
arg_type, depth=depth + 1, visited=visited
|
||||
@@ -249,8 +254,8 @@ def _convert_any_typed_dicts_to_pydantic(
|
||||
field_kwargs = {"default": ...}
|
||||
if arg_desc := arg_descriptions.get(arg):
|
||||
field_kwargs["description"] = arg_desc
|
||||
fields[arg] = (new_arg_type, Field(**field_kwargs))
|
||||
model = create_model(typed_dict.__name__, **fields)
|
||||
fields[arg] = (new_arg_type, Field_v1(**field_kwargs))
|
||||
model = create_model_v1(typed_dict.__name__, **fields)
|
||||
model.__doc__ = description
|
||||
visited[typed_dict] = model
|
||||
return model
|
||||
|
||||
@@ -8,8 +8,9 @@ from functools import wraps
|
||||
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union, overload
|
||||
|
||||
import pydantic # pydantic: ignore
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel, root_validator
|
||||
from pydantic import BaseModel, root_validator # pydantic: ignore
|
||||
from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue # pydantic: ignore
|
||||
from pydantic_core import core_schema # pydantic: ignore
|
||||
|
||||
|
||||
def get_pydantic_major_version() -> int:
|
||||
@@ -27,7 +28,6 @@ PYDANTIC_MAJOR_VERSION = get_pydantic_major_version()
|
||||
|
||||
if PYDANTIC_MAJOR_VERSION == 1:
|
||||
from pydantic.fields import FieldInfo as FieldInfoV1
|
||||
|
||||
PydanticBaseModel = pydantic.BaseModel
|
||||
TypeBaseModel = Type[BaseModel]
|
||||
elif PYDANTIC_MAJOR_VERSION == 2:
|
||||
@@ -146,7 +146,7 @@ def pre_init(func: Callable) -> Any:
|
||||
Dict[str, Any]: The values to initialize the model with.
|
||||
"""
|
||||
# Insert default values
|
||||
fields = cls.__fields__
|
||||
fields = cls.model_fields
|
||||
for name, field_info in fields.items():
|
||||
# Check if allow_population_by_field_name is enabled
|
||||
# If yes, then set the field name to the alias
|
||||
@@ -155,9 +155,13 @@ def pre_init(func: Callable) -> Any:
|
||||
if cls.Config.allow_population_by_field_name:
|
||||
if field_info.alias in values:
|
||||
values[name] = values.pop(field_info.alias)
|
||||
if hasattr(cls, "model_config"):
|
||||
if cls.model_config.get("populate_by_name"):
|
||||
if field_info.alias in values:
|
||||
values[name] = values.pop(field_info.alias)
|
||||
|
||||
if name not in values or values[name] is None:
|
||||
if not field_info.required:
|
||||
if not field_info.is_required():
|
||||
if field_info.default_factory is not None:
|
||||
values[name] = field_info.default_factory()
|
||||
else:
|
||||
@@ -169,6 +173,44 @@ def pre_init(func: Callable) -> Any:
|
||||
return wrapper
|
||||
|
||||
|
||||
class _IgnoreUnserializable(GenerateJsonSchema):
|
||||
"""A JSON schema generator that ignores unknown types.
|
||||
|
||||
https://docs.pydantic.dev/latest/concepts/json_schema/#customizing-the-json-schema-generation-process
|
||||
"""
|
||||
|
||||
def handle_invalid_for_json_schema(
|
||||
self, schema: core_schema.CoreSchema, error_info: str
|
||||
) -> JsonSchemaValue:
|
||||
return {}
|
||||
|
||||
|
||||
def v1_repr(obj: BaseModel) -> str:
|
||||
"""Return the schema of the object as a string.
|
||||
|
||||
Get a repr for the pydantic object which is consistent with pydantic.v1.
|
||||
"""
|
||||
if not is_basemodel_instance(obj):
|
||||
raise TypeError(f"Expected a pydantic BaseModel, got {type(obj)}")
|
||||
repr_ = []
|
||||
for name, field in get_fields(obj).items():
|
||||
value = getattr(obj, name)
|
||||
|
||||
if isinstance(value, BaseModel):
|
||||
repr_.append(f"{name}={v1_repr(value)}")
|
||||
else:
|
||||
if not field.is_required():
|
||||
if not value:
|
||||
continue
|
||||
if field.default == value:
|
||||
continue
|
||||
|
||||
repr_.append(f"{name}={repr(value)}")
|
||||
|
||||
args = ", ".join(repr_)
|
||||
return f"{obj.__class__.__name__}({args})"
|
||||
|
||||
|
||||
def _create_subset_model_v1(
|
||||
name: str,
|
||||
model: Type[BaseModel],
|
||||
@@ -178,12 +220,20 @@ def _create_subset_model_v1(
|
||||
fn_description: Optional[str] = None,
|
||||
) -> Type[BaseModel]:
|
||||
"""Create a pydantic model with only a subset of model's fields."""
|
||||
from langchain_core.pydantic_v1 import create_model
|
||||
if PYDANTIC_MAJOR_VERSION == 1:
|
||||
from pydantic import create_model # pydantic: ignore
|
||||
elif PYDANTIC_MAJOR_VERSION == 2:
|
||||
from pydantic.v1 import create_model # type: ignore # pydantic: ignore
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Unsupported pydantic version: {PYDANTIC_MAJOR_VERSION}"
|
||||
)
|
||||
|
||||
fields = {}
|
||||
|
||||
for field_name in field_names:
|
||||
field = model.__fields__[field_name]
|
||||
# Using pydantic v1 so can access __fields__ as a dict.
|
||||
field = model.__fields__[field_name] # type: ignore
|
||||
t = (
|
||||
# this isn't perfect but should work for most functions
|
||||
field.outer_type_
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
"""A fake callback handler for testing purposes."""
|
||||
|
||||
from itertools import chain
|
||||
from pydantic import BaseModel
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
|
||||
|
||||
class BaseFakeCallbackHandler(BaseModel):
|
||||
@@ -256,7 +256,8 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
||||
) -> Any:
|
||||
self.on_retriever_error_common()
|
||||
|
||||
def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler":
|
||||
# Overriding since BaseModel has __deepcopy__ method as well
|
||||
def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler": # type: ignore
|
||||
return self
|
||||
|
||||
|
||||
@@ -390,5 +391,6 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
||||
) -> None:
|
||||
self.on_text_common()
|
||||
|
||||
def __deepcopy__(self, memo: dict) -> "FakeAsyncCallbackHandler":
|
||||
# Overriding since BaseModel has __deepcopy__ method as well
|
||||
def __deepcopy__(self, memo: dict) -> "FakeAsyncCallbackHandler": # type: ignore
|
||||
return self
|
||||
|
||||
@@ -9,7 +9,6 @@ from langchain_core.language_models import GenericFakeChatModel, ParrotFakeChatM
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
|
||||
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
|
||||
from tests.unit_tests.stubs import (
|
||||
AnyStr,
|
||||
_AnyIdAIMessage,
|
||||
_AnyIdAIMessageChunk,
|
||||
_AnyIdHumanMessage,
|
||||
@@ -70,8 +69,8 @@ async def test_generic_fake_chat_model_stream() -> None:
|
||||
model = GenericFakeChatModel(messages=cycle([message]))
|
||||
chunks = [chunk async for chunk in model.astream("meow")]
|
||||
assert chunks == [
|
||||
AIMessageChunk(content="", additional_kwargs={"foo": 42}, id=AnyStr()),
|
||||
AIMessageChunk(content="", additional_kwargs={"bar": 24}, id=AnyStr()),
|
||||
_AnyIdAIMessageChunk(content="", additional_kwargs={"foo": 42}),
|
||||
_AnyIdAIMessageChunk(content="", additional_kwargs={"bar": 24}),
|
||||
]
|
||||
assert len({chunk.id for chunk in chunks}) == 1
|
||||
|
||||
@@ -89,29 +88,23 @@ async def test_generic_fake_chat_model_stream() -> None:
|
||||
chunks = [chunk async for chunk in model.astream("meow")]
|
||||
|
||||
assert chunks == [
|
||||
AIMessageChunk(
|
||||
content="",
|
||||
additional_kwargs={"function_call": {"name": "move_file"}},
|
||||
id=AnyStr(),
|
||||
_AnyIdAIMessageChunk(
|
||||
content="", additional_kwargs={"function_call": {"name": "move_file"}}
|
||||
),
|
||||
AIMessageChunk(
|
||||
_AnyIdAIMessageChunk(
|
||||
content="",
|
||||
additional_kwargs={
|
||||
"function_call": {"arguments": '{\n "source_path": "foo"'},
|
||||
},
|
||||
id=AnyStr(),
|
||||
),
|
||||
AIMessageChunk(
|
||||
content="",
|
||||
additional_kwargs={"function_call": {"arguments": ","}},
|
||||
id=AnyStr(),
|
||||
_AnyIdAIMessageChunk(
|
||||
content="", additional_kwargs={"function_call": {"arguments": ","}}
|
||||
),
|
||||
AIMessageChunk(
|
||||
_AnyIdAIMessageChunk(
|
||||
content="",
|
||||
additional_kwargs={
|
||||
"function_call": {"arguments": '\n "destination_path": "bar"\n}'},
|
||||
},
|
||||
id=AnyStr(),
|
||||
),
|
||||
]
|
||||
assert len({chunk.id for chunk in chunks}) == 1
|
||||
|
||||
@@ -3,6 +3,7 @@ import time
|
||||
from langchain_core.caches import InMemoryCache
|
||||
from langchain_core.language_models import GenericFakeChatModel
|
||||
from langchain_core.rate_limiters import InMemoryRateLimiter
|
||||
from typing import Optional as Optional
|
||||
|
||||
|
||||
def test_rate_limit_invoke() -> None:
|
||||
@@ -219,6 +220,8 @@ class SerializableModel(GenericFakeChatModel):
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
return True
|
||||
|
||||
SerializableModel.model_rebuild()
|
||||
|
||||
|
||||
def test_serialization_with_rate_limiter() -> None:
|
||||
"""Test model serialization with rate limiter."""
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Module to test base parser implementations."""
|
||||
|
||||
from typing import List
|
||||
from typing import Optional as Optional
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.language_models import GenericFakeChatModel
|
||||
@@ -46,6 +47,8 @@ def test_base_generation_parser() -> None:
|
||||
assert isinstance(content, str)
|
||||
return content.swapcase() # type: ignore
|
||||
|
||||
StrInvertCase.update_forward_refs()
|
||||
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="hEllo")]))
|
||||
chain = model | StrInvertCase()
|
||||
assert chain.invoke("") == "HeLLO"
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
dict({
|
||||
'definitions': dict({
|
||||
'AIMessage': dict({
|
||||
'additionalProperties': True,
|
||||
'description': '''
|
||||
Message from an AI.
|
||||
|
||||
@@ -44,8 +45,16 @@
|
||||
'type': 'boolean',
|
||||
}),
|
||||
'id': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
dict({
|
||||
'type': 'null',
|
||||
}),
|
||||
]),
|
||||
'default': None,
|
||||
'title': 'Id',
|
||||
'type': 'string',
|
||||
}),
|
||||
'invalid_tool_calls': dict({
|
||||
'default': list([
|
||||
@@ -57,8 +66,16 @@
|
||||
'type': 'array',
|
||||
}),
|
||||
'name': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
dict({
|
||||
'type': 'null',
|
||||
}),
|
||||
]),
|
||||
'default': None,
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'response_metadata': dict({
|
||||
'title': 'Response Metadata',
|
||||
@@ -74,6 +91,7 @@
|
||||
'type': 'array',
|
||||
}),
|
||||
'type': dict({
|
||||
'const': 'ai',
|
||||
'default': 'ai',
|
||||
'enum': list([
|
||||
'ai',
|
||||
@@ -82,7 +100,15 @@
|
||||
'type': 'string',
|
||||
}),
|
||||
'usage_metadata': dict({
|
||||
'$ref': '#/definitions/UsageMetadata',
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'$ref': '#/definitions/UsageMetadata',
|
||||
}),
|
||||
dict({
|
||||
'type': 'null',
|
||||
}),
|
||||
]),
|
||||
'default': None,
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
@@ -92,6 +118,7 @@
|
||||
'type': 'object',
|
||||
}),
|
||||
'ChatMessage': dict({
|
||||
'additionalProperties': True,
|
||||
'description': 'Message that can be assigned an arbitrary speaker (i.e. role).',
|
||||
'properties': dict({
|
||||
'additional_kwargs': dict({
|
||||
@@ -120,12 +147,28 @@
|
||||
'title': 'Content',
|
||||
}),
|
||||
'id': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
dict({
|
||||
'type': 'null',
|
||||
}),
|
||||
]),
|
||||
'default': None,
|
||||
'title': 'Id',
|
||||
'type': 'string',
|
||||
}),
|
||||
'name': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
dict({
|
||||
'type': 'null',
|
||||
}),
|
||||
]),
|
||||
'default': None,
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'response_metadata': dict({
|
||||
'title': 'Response Metadata',
|
||||
@@ -136,6 +179,7 @@
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'const': 'chat',
|
||||
'default': 'chat',
|
||||
'enum': list([
|
||||
'chat',
|
||||
@@ -152,6 +196,7 @@
|
||||
'type': 'object',
|
||||
}),
|
||||
'FunctionMessage': dict({
|
||||
'additionalProperties': True,
|
||||
'description': '''
|
||||
Message for passing the result of executing a tool back to a model.
|
||||
|
||||
@@ -189,8 +234,16 @@
|
||||
'title': 'Content',
|
||||
}),
|
||||
'id': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
dict({
|
||||
'type': 'null',
|
||||
}),
|
||||
]),
|
||||
'default': None,
|
||||
'title': 'Id',
|
||||
'type': 'string',
|
||||
}),
|
||||
'name': dict({
|
||||
'title': 'Name',
|
||||
@@ -201,6 +254,7 @@
|
||||
'type': 'object',
|
||||
}),
|
||||
'type': dict({
|
||||
'const': 'function',
|
||||
'default': 'function',
|
||||
'enum': list([
|
||||
'function',
|
||||
@@ -217,6 +271,7 @@
|
||||
'type': 'object',
|
||||
}),
|
||||
'HumanMessage': dict({
|
||||
'additionalProperties': True,
|
||||
'description': '''
|
||||
Message from a human.
|
||||
|
||||
@@ -273,18 +328,35 @@
|
||||
'type': 'boolean',
|
||||
}),
|
||||
'id': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
dict({
|
||||
'type': 'null',
|
||||
}),
|
||||
]),
|
||||
'default': None,
|
||||
'title': 'Id',
|
||||
'type': 'string',
|
||||
}),
|
||||
'name': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
dict({
|
||||
'type': 'null',
|
||||
}),
|
||||
]),
|
||||
'default': None,
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'response_metadata': dict({
|
||||
'title': 'Response Metadata',
|
||||
'type': 'object',
|
||||
}),
|
||||
'type': dict({
|
||||
'const': 'human',
|
||||
'default': 'human',
|
||||
'enum': list([
|
||||
'human',
|
||||
@@ -300,24 +372,59 @@
|
||||
'type': 'object',
|
||||
}),
|
||||
'InvalidToolCall': dict({
|
||||
'description': '''
|
||||
Allowance for errors made by LLM.
|
||||
|
||||
Here we add an `error` key to surface errors made during generation
|
||||
(e.g., invalid JSON arguments.)
|
||||
''',
|
||||
'properties': dict({
|
||||
'args': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
dict({
|
||||
'type': 'null',
|
||||
}),
|
||||
]),
|
||||
'title': 'Args',
|
||||
'type': 'string',
|
||||
}),
|
||||
'error': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
dict({
|
||||
'type': 'null',
|
||||
}),
|
||||
]),
|
||||
'title': 'Error',
|
||||
'type': 'string',
|
||||
}),
|
||||
'id': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
dict({
|
||||
'type': 'null',
|
||||
}),
|
||||
]),
|
||||
'title': 'Id',
|
||||
'type': 'string',
|
||||
}),
|
||||
'name': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
dict({
|
||||
'type': 'null',
|
||||
}),
|
||||
]),
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'const': 'invalid_tool_call',
|
||||
'enum': list([
|
||||
'invalid_tool_call',
|
||||
]),
|
||||
@@ -335,6 +442,7 @@
|
||||
'type': 'object',
|
||||
}),
|
||||
'SystemMessage': dict({
|
||||
'additionalProperties': True,
|
||||
'description': '''
|
||||
Message for priming AI behavior.
|
||||
|
||||
@@ -386,18 +494,35 @@
|
||||
'title': 'Content',
|
||||
}),
|
||||
'id': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
dict({
|
||||
'type': 'null',
|
||||
}),
|
||||
]),
|
||||
'default': None,
|
||||
'title': 'Id',
|
||||
'type': 'string',
|
||||
}),
|
||||
'name': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
dict({
|
||||
'type': 'null',
|
||||
}),
|
||||
]),
|
||||
'default': None,
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'response_metadata': dict({
|
||||
'title': 'Response Metadata',
|
||||
'type': 'object',
|
||||
}),
|
||||
'type': dict({
|
||||
'const': 'system',
|
||||
'default': 'system',
|
||||
'enum': list([
|
||||
'system',
|
||||
@@ -413,20 +538,44 @@
|
||||
'type': 'object',
|
||||
}),
|
||||
'ToolCall': dict({
|
||||
'description': '''
|
||||
Represents a request to call a tool.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
{
|
||||
"name": "foo",
|
||||
"args": {"a": 1},
|
||||
"id": "123"
|
||||
}
|
||||
|
||||
This represents a request to call the tool named "foo" with arguments {"a": 1}
|
||||
and an identifier of "123".
|
||||
''',
|
||||
'properties': dict({
|
||||
'args': dict({
|
||||
'title': 'Args',
|
||||
'type': 'object',
|
||||
}),
|
||||
'id': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
dict({
|
||||
'type': 'null',
|
||||
}),
|
||||
]),
|
||||
'title': 'Id',
|
||||
'type': 'string',
|
||||
}),
|
||||
'name': dict({
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'const': 'tool_call',
|
||||
'enum': list([
|
||||
'tool_call',
|
||||
]),
|
||||
@@ -443,6 +592,7 @@
|
||||
'type': 'object',
|
||||
}),
|
||||
'ToolMessage': dict({
|
||||
'additionalProperties': True,
|
||||
'description': '''
|
||||
Message for passing the result of executing a tool back to a model.
|
||||
|
||||
@@ -513,12 +663,28 @@
|
||||
'title': 'Content',
|
||||
}),
|
||||
'id': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
dict({
|
||||
'type': 'null',
|
||||
}),
|
||||
]),
|
||||
'default': None,
|
||||
'title': 'Id',
|
||||
'type': 'string',
|
||||
}),
|
||||
'name': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
dict({
|
||||
'type': 'null',
|
||||
}),
|
||||
]),
|
||||
'default': None,
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'response_metadata': dict({
|
||||
'title': 'Response Metadata',
|
||||
@@ -538,6 +704,7 @@
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'const': 'tool',
|
||||
'default': 'tool',
|
||||
'enum': list([
|
||||
'tool',
|
||||
@@ -554,6 +721,21 @@
|
||||
'type': 'object',
|
||||
}),
|
||||
'UsageMetadata': dict({
|
||||
'description': '''
|
||||
Usage metadata for a message, such as token counts.
|
||||
|
||||
This is a standard representation of token usage that is consistent across models.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
{
|
||||
"input_tokens": 10,
|
||||
"output_tokens": 20,
|
||||
"total_tokens": 30
|
||||
}
|
||||
''',
|
||||
'properties': dict({
|
||||
'input_tokens': dict({
|
||||
'title': 'Input Tokens',
|
||||
@@ -620,6 +802,7 @@
|
||||
dict({
|
||||
'definitions': dict({
|
||||
'AIMessage': dict({
|
||||
'additionalProperties': True,
|
||||
'description': '''
|
||||
Message from an AI.
|
||||
|
||||
@@ -661,8 +844,16 @@
|
||||
'type': 'boolean',
|
||||
}),
|
||||
'id': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
dict({
|
||||
'type': 'null',
|
||||
}),
|
||||
]),
|
||||
'default': None,
|
||||
'title': 'Id',
|
||||
'type': 'string',
|
||||
}),
|
||||
'invalid_tool_calls': dict({
|
||||
'default': list([
|
||||
@@ -674,8 +865,16 @@
|
||||
'type': 'array',
|
||||
}),
|
||||
'name': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
dict({
|
||||
'type': 'null',
|
||||
}),
|
||||
]),
|
||||
'default': None,
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'response_metadata': dict({
|
||||
'title': 'Response Metadata',
|
||||
@@ -691,6 +890,7 @@
|
||||
'type': 'array',
|
||||
}),
|
||||
'type': dict({
|
||||
'const': 'ai',
|
||||
'default': 'ai',
|
||||
'enum': list([
|
||||
'ai',
|
||||
@@ -699,7 +899,15 @@
|
||||
'type': 'string',
|
||||
}),
|
||||
'usage_metadata': dict({
|
||||
'$ref': '#/definitions/UsageMetadata',
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'$ref': '#/definitions/UsageMetadata',
|
||||
}),
|
||||
dict({
|
||||
'type': 'null',
|
||||
}),
|
||||
]),
|
||||
'default': None,
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
@@ -709,6 +917,7 @@
|
||||
'type': 'object',
|
||||
}),
|
||||
'ChatMessage': dict({
|
||||
'additionalProperties': True,
|
||||
'description': 'Message that can be assigned an arbitrary speaker (i.e. role).',
|
||||
'properties': dict({
|
||||
'additional_kwargs': dict({
|
||||
@@ -737,12 +946,28 @@
|
||||
'title': 'Content',
|
||||
}),
|
||||
'id': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
dict({
|
||||
'type': 'null',
|
||||
}),
|
||||
]),
|
||||
'default': None,
|
||||
'title': 'Id',
|
||||
'type': 'string',
|
||||
}),
|
||||
'name': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
dict({
|
||||
'type': 'null',
|
||||
}),
|
||||
]),
|
||||
'default': None,
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'response_metadata': dict({
|
||||
'title': 'Response Metadata',
|
||||
@@ -753,6 +978,7 @@
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'const': 'chat',
|
||||
'default': 'chat',
|
||||
'enum': list([
|
||||
'chat',
|
||||
@@ -769,6 +995,7 @@
|
||||
'type': 'object',
|
||||
}),
|
||||
'FunctionMessage': dict({
|
||||
'additionalProperties': True,
|
||||
'description': '''
|
||||
Message for passing the result of executing a tool back to a model.
|
||||
|
||||
@@ -806,8 +1033,16 @@
|
||||
'title': 'Content',
|
||||
}),
|
||||
'id': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
dict({
|
||||
'type': 'null',
|
||||
}),
|
||||
]),
|
||||
'default': None,
|
||||
'title': 'Id',
|
||||
'type': 'string',
|
||||
}),
|
||||
'name': dict({
|
||||
'title': 'Name',
|
||||
@@ -818,6 +1053,7 @@
|
||||
'type': 'object',
|
||||
}),
|
||||
'type': dict({
|
||||
'const': 'function',
|
||||
'default': 'function',
|
||||
'enum': list([
|
||||
'function',
|
||||
@@ -834,6 +1070,7 @@
|
||||
'type': 'object',
|
||||
}),
|
||||
'HumanMessage': dict({
|
||||
'additionalProperties': True,
|
||||
'description': '''
|
||||
Message from a human.
|
||||
|
||||
@@ -890,18 +1127,35 @@
|
||||
'type': 'boolean',
|
||||
}),
|
||||
'id': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
dict({
|
||||
'type': 'null',
|
||||
}),
|
||||
]),
|
||||
'default': None,
|
||||
'title': 'Id',
|
||||
'type': 'string',
|
||||
}),
|
||||
'name': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
dict({
|
||||
'type': 'null',
|
||||
}),
|
||||
]),
|
||||
'default': None,
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'response_metadata': dict({
|
||||
'title': 'Response Metadata',
|
||||
'type': 'object',
|
||||
}),
|
||||
'type': dict({
|
||||
'const': 'human',
|
||||
'default': 'human',
|
||||
'enum': list([
|
||||
'human',
|
||||
@@ -917,24 +1171,59 @@
|
||||
'type': 'object',
|
||||
}),
|
||||
'InvalidToolCall': dict({
|
||||
'description': '''
|
||||
Allowance for errors made by LLM.
|
||||
|
||||
Here we add an `error` key to surface errors made during generation
|
||||
(e.g., invalid JSON arguments.)
|
||||
''',
|
||||
'properties': dict({
|
||||
'args': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
dict({
|
||||
'type': 'null',
|
||||
}),
|
||||
]),
|
||||
'title': 'Args',
|
||||
'type': 'string',
|
||||
}),
|
||||
'error': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
dict({
|
||||
'type': 'null',
|
||||
}),
|
||||
]),
|
||||
'title': 'Error',
|
||||
'type': 'string',
|
||||
}),
|
||||
'id': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
dict({
|
||||
'type': 'null',
|
||||
}),
|
||||
]),
|
||||
'title': 'Id',
|
||||
'type': 'string',
|
||||
}),
|
||||
'name': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
dict({
|
||||
'type': 'null',
|
||||
}),
|
||||
]),
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'const': 'invalid_tool_call',
|
||||
'enum': list([
|
||||
'invalid_tool_call',
|
||||
]),
|
||||
@@ -952,6 +1241,7 @@
|
||||
'type': 'object',
|
||||
}),
|
||||
'SystemMessage': dict({
|
||||
'additionalProperties': True,
|
||||
'description': '''
|
||||
Message for priming AI behavior.
|
||||
|
||||
@@ -1003,18 +1293,35 @@
|
||||
'title': 'Content',
|
||||
}),
|
||||
'id': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
dict({
|
||||
'type': 'null',
|
||||
}),
|
||||
]),
|
||||
'default': None,
|
||||
'title': 'Id',
|
||||
'type': 'string',
|
||||
}),
|
||||
'name': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
dict({
|
||||
'type': 'null',
|
||||
}),
|
||||
]),
|
||||
'default': None,
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'response_metadata': dict({
|
||||
'title': 'Response Metadata',
|
||||
'type': 'object',
|
||||
}),
|
||||
'type': dict({
|
||||
'const': 'system',
|
||||
'default': 'system',
|
||||
'enum': list([
|
||||
'system',
|
||||
@@ -1030,20 +1337,44 @@
|
||||
'type': 'object',
|
||||
}),
|
||||
'ToolCall': dict({
|
||||
'description': '''
|
||||
Represents a request to call a tool.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
{
|
||||
"name": "foo",
|
||||
"args": {"a": 1},
|
||||
"id": "123"
|
||||
}
|
||||
|
||||
This represents a request to call the tool named "foo" with arguments {"a": 1}
|
||||
and an identifier of "123".
|
||||
''',
|
||||
'properties': dict({
|
||||
'args': dict({
|
||||
'title': 'Args',
|
||||
'type': 'object',
|
||||
}),
|
||||
'id': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
dict({
|
||||
'type': 'null',
|
||||
}),
|
||||
]),
|
||||
'title': 'Id',
|
||||
'type': 'string',
|
||||
}),
|
||||
'name': dict({
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'const': 'tool_call',
|
||||
'enum': list([
|
||||
'tool_call',
|
||||
]),
|
||||
@@ -1060,6 +1391,7 @@
|
||||
'type': 'object',
|
||||
}),
|
||||
'ToolMessage': dict({
|
||||
'additionalProperties': True,
|
||||
'description': '''
|
||||
Message for passing the result of executing a tool back to a model.
|
||||
|
||||
@@ -1130,12 +1462,28 @@
|
||||
'title': 'Content',
|
||||
}),
|
||||
'id': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
dict({
|
||||
'type': 'null',
|
||||
}),
|
||||
]),
|
||||
'default': None,
|
||||
'title': 'Id',
|
||||
'type': 'string',
|
||||
}),
|
||||
'name': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
dict({
|
||||
'type': 'null',
|
||||
}),
|
||||
]),
|
||||
'default': None,
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'response_metadata': dict({
|
||||
'title': 'Response Metadata',
|
||||
@@ -1155,6 +1503,7 @@
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'const': 'tool',
|
||||
'default': 'tool',
|
||||
'enum': list([
|
||||
'tool',
|
||||
@@ -1171,6 +1520,21 @@
|
||||
'type': 'object',
|
||||
}),
|
||||
'UsageMetadata': dict({
|
||||
'description': '''
|
||||
Usage metadata for a message, such as token counts.
|
||||
|
||||
This is a standard representation of token usage that is consistent across models.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
{
|
||||
"input_tokens": 10,
|
||||
"output_tokens": 20,
|
||||
"total_tokens": 30
|
||||
}
|
||||
''',
|
||||
'properties': dict({
|
||||
'input_tokens': dict({
|
||||
'title': 'Input Tokens',
|
||||
@@ -1436,7 +1800,7 @@
|
||||
'type': 'constructor',
|
||||
})
|
||||
# ---
|
||||
# name: test_chat_prompt_w_msgs_placeholder_ser_des[placholder]
|
||||
# name: test_chat_prompt_w_msgs_placeholder_ser_des[placeholder]
|
||||
dict({
|
||||
'id': list([
|
||||
'langchain',
|
||||
|
||||
@@ -6,7 +6,9 @@ from typing import Any, List, Union
|
||||
import pytest
|
||||
from syrupy import SnapshotAssertion
|
||||
|
||||
from langchain_core._api.deprecation import LangChainPendingDeprecationWarning
|
||||
from langchain_core._api.deprecation import (
|
||||
LangChainPendingDeprecationWarning,
|
||||
)
|
||||
from langchain_core.load import dumpd, load
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
@@ -810,7 +812,7 @@ def test_chat_prompt_w_msgs_placeholder_ser_des(snapshot: SnapshotAssertion) ->
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[("system", "foo"), MessagesPlaceholder("bar"), ("human", "baz")]
|
||||
)
|
||||
assert dumpd(MessagesPlaceholder("bar")) == snapshot(name="placholder")
|
||||
assert dumpd(MessagesPlaceholder("bar")) == snapshot(name="placeholder")
|
||||
assert load(dumpd(MessagesPlaceholder("bar"))) == MessagesPlaceholder("bar")
|
||||
assert dumpd(prompt) == snapshot(name="chat_prompt")
|
||||
assert load(dumpd(prompt)) == prompt
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from functools import partial
|
||||
from inspect import isclass
|
||||
from typing import Any, Dict, Type, Union, cast
|
||||
from typing import Optional as Optional
|
||||
|
||||
from langchain_core.language_models import FakeListChatModel
|
||||
from langchain_core.load.dump import dumps
|
||||
@@ -34,6 +35,9 @@ class FakeStructuredChatModel(FakeListChatModel):
|
||||
return "fake-messages-list-chat-model"
|
||||
|
||||
|
||||
FakeStructuredChatModel.update_forward_refs()
|
||||
|
||||
|
||||
def test_structured_prompt_pydantic() -> None:
|
||||
class OutputSchema(BaseModel):
|
||||
name: str
|
||||
|
||||
@@ -1,20 +1,10 @@
|
||||
"""Helper utilities for pydantic.
|
||||
|
||||
This module includes helper utilities to ease the migration from pydantic v1 to v2.
|
||||
|
||||
They're meant to be used in the following way:
|
||||
|
||||
1) Use utility code to help (selected) unit tests pass without modifications
|
||||
2) Upgrade the unit tests to match pydantic 2
|
||||
3) Stop using the utility code
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
|
||||
|
||||
# Function to replace allOf with $ref
|
||||
def _replace_all_of_with_ref(schema: Any) -> None:
|
||||
"""Replace allOf with $ref in the schema."""
|
||||
def replace_all_of_with_ref(schema: Any) -> None:
|
||||
if isinstance(schema, dict):
|
||||
# If the schema has an allOf key with a single item that contains a $ref
|
||||
if (
|
||||
@@ -30,13 +20,13 @@ def _replace_all_of_with_ref(schema: Any) -> None:
|
||||
# Recursively process nested schemas
|
||||
for key, value in schema.items():
|
||||
if isinstance(value, (dict, list)):
|
||||
_replace_all_of_with_ref(value)
|
||||
replace_all_of_with_ref(value)
|
||||
elif isinstance(schema, list):
|
||||
for item in schema:
|
||||
_replace_all_of_with_ref(item)
|
||||
replace_all_of_with_ref(item)
|
||||
|
||||
|
||||
def _remove_bad_none_defaults(schema: Any) -> None:
|
||||
def remove_all_none_default(schema: Any) -> None:
|
||||
"""Removing all none defaults.
|
||||
|
||||
Pydantic v1 did not generate these, but Pydantic v2 does.
|
||||
@@ -56,39 +46,48 @@ def _remove_bad_none_defaults(schema: Any) -> None:
|
||||
break # Null type explicitly defined
|
||||
else:
|
||||
del value["default"]
|
||||
_remove_bad_none_defaults(value)
|
||||
remove_all_none_default(value)
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
_remove_bad_none_defaults(item)
|
||||
remove_all_none_default(item)
|
||||
elif isinstance(schema, list):
|
||||
for item in schema:
|
||||
_remove_bad_none_defaults(item)
|
||||
remove_all_none_default(item)
|
||||
|
||||
|
||||
def _remove_enum_description(obj: Any) -> None:
|
||||
"""Remove the description from enums."""
|
||||
if isinstance(obj, dict):
|
||||
if "enum" in obj:
|
||||
if "description" in obj and obj["description"] == "An enumeration.":
|
||||
del obj["description"]
|
||||
for key, value in obj.items():
|
||||
_remove_enum_description(value)
|
||||
elif isinstance(obj, list):
|
||||
for item in obj:
|
||||
_remove_enum_description(item)
|
||||
|
||||
|
||||
def _schema(obj: Any) -> dict:
|
||||
"""Get the schema of a pydantic model in the pydantic v1 style.
|
||||
|
||||
This will attempt to map the schema as close as possible to the pydantic v1 schema.
|
||||
"""
|
||||
"""Return the schema of the object."""
|
||||
# Remap to old style schema
|
||||
if not is_basemodel_subclass(obj):
|
||||
raise TypeError(
|
||||
f"Object must be a Pydantic BaseModel subclass. Got {type(obj)}"
|
||||
)
|
||||
if not hasattr(obj, "model_json_schema"): # V1 model
|
||||
return obj.schema()
|
||||
|
||||
# Then we're using V2 models internally.
|
||||
raise AssertionError(
|
||||
"Hi there! Looks like you're attempting to upgrade to Pydantic v2. If so: \n"
|
||||
"1) remove this exception\n"
|
||||
"2) confirm that the old unit tests pass, and if not look for difference\n"
|
||||
"3) update the unit tests to match the new schema\n"
|
||||
"4) remove this utility function\n"
|
||||
)
|
||||
|
||||
schema_ = obj.model_json_schema(ref_template="#/definitions/{model}")
|
||||
if "$defs" in schema_:
|
||||
schema_["definitions"] = schema_["$defs"]
|
||||
del schema_["$defs"]
|
||||
|
||||
_replace_all_of_with_ref(schema_)
|
||||
_remove_bad_none_defaults(schema_)
|
||||
if "default" in schema_ and schema_["default"] is None:
|
||||
del schema_["default"]
|
||||
|
||||
replace_all_of_with_ref(schema_)
|
||||
remove_all_none_default(schema_)
|
||||
_remove_enum_description(schema_)
|
||||
|
||||
return schema_
|
||||
|
||||
@@ -334,31 +334,9 @@
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
dict({
|
||||
'$ref': '#/definitions/AIMessage',
|
||||
}),
|
||||
dict({
|
||||
'$ref': '#/definitions/HumanMessage',
|
||||
}),
|
||||
dict({
|
||||
'$ref': '#/definitions/ChatMessage',
|
||||
}),
|
||||
dict({
|
||||
'$ref': '#/definitions/SystemMessage',
|
||||
}),
|
||||
dict({
|
||||
'$ref': '#/definitions/FunctionMessage',
|
||||
}),
|
||||
dict({
|
||||
'$ref': '#/definitions/ToolMessage',
|
||||
}),
|
||||
]),
|
||||
'definitions': dict({
|
||||
'$defs': dict({
|
||||
'AIMessage': dict({
|
||||
'additionalProperties': True,
|
||||
'description': '''
|
||||
Message from an AI.
|
||||
|
||||
@@ -400,21 +378,37 @@
|
||||
'type': 'boolean',
|
||||
}),
|
||||
'id': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
dict({
|
||||
'type': 'null',
|
||||
}),
|
||||
]),
|
||||
'default': None,
|
||||
'title': 'Id',
|
||||
'type': 'string',
|
||||
}),
|
||||
'invalid_tool_calls': dict({
|
||||
'default': list([
|
||||
]),
|
||||
'items': dict({
|
||||
'$ref': '#/definitions/InvalidToolCall',
|
||||
'$ref': '#/$defs/InvalidToolCall',
|
||||
}),
|
||||
'title': 'Invalid Tool Calls',
|
||||
'type': 'array',
|
||||
}),
|
||||
'name': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
dict({
|
||||
'type': 'null',
|
||||
}),
|
||||
]),
|
||||
'default': None,
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'response_metadata': dict({
|
||||
'title': 'Response Metadata',
|
||||
@@ -424,12 +418,13 @@
|
||||
'default': list([
|
||||
]),
|
||||
'items': dict({
|
||||
'$ref': '#/definitions/ToolCall',
|
||||
'$ref': '#/$defs/ToolCall',
|
||||
}),
|
||||
'title': 'Tool Calls',
|
||||
'type': 'array',
|
||||
}),
|
||||
'type': dict({
|
||||
'const': 'ai',
|
||||
'default': 'ai',
|
||||
'enum': list([
|
||||
'ai',
|
||||
@@ -438,7 +433,15 @@
|
||||
'type': 'string',
|
||||
}),
|
||||
'usage_metadata': dict({
|
||||
'$ref': '#/definitions/UsageMetadata',
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'$ref': '#/$defs/UsageMetadata',
|
||||
}),
|
||||
dict({
|
||||
'type': 'null',
|
||||
}),
|
||||
]),
|
||||
'default': None,
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
@@ -448,6 +451,7 @@
|
||||
'type': 'object',
|
||||
}),
|
||||
'ChatMessage': dict({
|
||||
'additionalProperties': True,
|
||||
'description': 'Message that can be assigned an arbitrary speaker (i.e. role).',
|
||||
'properties': dict({
|
||||
'additional_kwargs': dict({
|
||||
@@ -476,12 +480,28 @@
|
||||
'title': 'Content',
|
||||
}),
|
||||
'id': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
dict({
|
||||
'type': 'null',
|
||||
}),
|
||||
]),
|
||||
'default': None,
|
||||
'title': 'Id',
|
||||
'type': 'string',
|
||||
}),
|
||||
'name': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
dict({
|
||||
'type': 'null',
|
||||
}),
|
||||
]),
|
||||
'default': None,
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'response_metadata': dict({
|
||||
'title': 'Response Metadata',
|
||||
@@ -492,6 +512,7 @@
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'const': 'chat',
|
||||
'default': 'chat',
|
||||
'enum': list([
|
||||
'chat',
|
||||
@@ -508,6 +529,7 @@
|
||||
'type': 'object',
|
||||
}),
|
||||
'FunctionMessage': dict({
|
||||
'additionalProperties': True,
|
||||
'description': '''
|
||||
Message for passing the result of executing a tool back to a model.
|
||||
|
||||
@@ -545,8 +567,16 @@
|
||||
'title': 'Content',
|
||||
}),
|
||||
'id': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
dict({
|
||||
'type': 'null',
|
||||
}),
|
||||
]),
|
||||
'default': None,
|
||||
'title': 'Id',
|
||||
'type': 'string',
|
||||
}),
|
||||
'name': dict({
|
||||
'title': 'Name',
|
||||
@@ -557,6 +587,7 @@
|
||||
'type': 'object',
|
||||
}),
|
||||
'type': dict({
|
||||
'const': 'function',
|
||||
'default': 'function',
|
||||
'enum': list([
|
||||
'function',
|
||||
@@ -573,6 +604,7 @@
|
||||
'type': 'object',
|
||||
}),
|
||||
'HumanMessage': dict({
|
||||
'additionalProperties': True,
|
||||
'description': '''
|
||||
Message from a human.
|
||||
|
||||
@@ -629,18 +661,35 @@
|
||||
'type': 'boolean',
|
||||
}),
|
||||
'id': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
dict({
|
||||
'type': 'null',
|
||||
}),
|
||||
]),
|
||||
'default': None,
|
||||
'title': 'Id',
|
||||
'type': 'string',
|
||||
}),
|
||||
'name': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
dict({
|
||||
'type': 'null',
|
||||
}),
|
||||
]),
|
||||
'default': None,
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'response_metadata': dict({
|
||||
'title': 'Response Metadata',
|
||||
'type': 'object',
|
||||
}),
|
||||
'type': dict({
|
||||
'const': 'human',
|
||||
'default': 'human',
|
||||
'enum': list([
|
||||
'human',
|
||||
@@ -656,24 +705,59 @@
|
||||
'type': 'object',
|
||||
}),
|
||||
'InvalidToolCall': dict({
|
||||
'description': '''
|
||||
Allowance for errors made by LLM.
|
||||
|
||||
Here we add an `error` key to surface errors made during generation
|
||||
(e.g., invalid JSON arguments.)
|
||||
''',
|
||||
'properties': dict({
|
||||
'args': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
dict({
|
||||
'type': 'null',
|
||||
}),
|
||||
]),
|
||||
'title': 'Args',
|
||||
'type': 'string',
|
||||
}),
|
||||
'error': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
dict({
|
||||
'type': 'null',
|
||||
}),
|
||||
]),
|
||||
'title': 'Error',
|
||||
'type': 'string',
|
||||
}),
|
||||
'id': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
dict({
|
||||
'type': 'null',
|
||||
}),
|
||||
]),
|
||||
'title': 'Id',
|
||||
'type': 'string',
|
||||
}),
|
||||
'name': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
dict({
|
||||
'type': 'null',
|
||||
}),
|
||||
]),
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'const': 'invalid_tool_call',
|
||||
'enum': list([
|
||||
'invalid_tool_call',
|
||||
]),
|
||||
@@ -691,6 +775,7 @@
|
||||
'type': 'object',
|
||||
}),
|
||||
'SystemMessage': dict({
|
||||
'additionalProperties': True,
|
||||
'description': '''
|
||||
Message for priming AI behavior.
|
||||
|
||||
@@ -742,18 +827,35 @@
|
||||
'title': 'Content',
|
||||
}),
|
||||
'id': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
dict({
|
||||
'type': 'null',
|
||||
}),
|
||||
]),
|
||||
'default': None,
|
||||
'title': 'Id',
|
||||
'type': 'string',
|
||||
}),
|
||||
'name': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
dict({
|
||||
'type': 'null',
|
||||
}),
|
||||
]),
|
||||
'default': None,
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'response_metadata': dict({
|
||||
'title': 'Response Metadata',
|
||||
'type': 'object',
|
||||
}),
|
||||
'type': dict({
|
||||
'const': 'system',
|
||||
'default': 'system',
|
||||
'enum': list([
|
||||
'system',
|
||||
@@ -769,20 +871,44 @@
|
||||
'type': 'object',
|
||||
}),
|
||||
'ToolCall': dict({
|
||||
'description': '''
|
||||
Represents a request to call a tool.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
{
|
||||
"name": "foo",
|
||||
"args": {"a": 1},
|
||||
"id": "123"
|
||||
}
|
||||
|
||||
This represents a request to call the tool named "foo" with arguments {"a": 1}
|
||||
and an identifier of "123".
|
||||
''',
|
||||
'properties': dict({
|
||||
'args': dict({
|
||||
'title': 'Args',
|
||||
'type': 'object',
|
||||
}),
|
||||
'id': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
dict({
|
||||
'type': 'null',
|
||||
}),
|
||||
]),
|
||||
'title': 'Id',
|
||||
'type': 'string',
|
||||
}),
|
||||
'name': dict({
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'const': 'tool_call',
|
||||
'enum': list([
|
||||
'tool_call',
|
||||
]),
|
||||
@@ -799,6 +925,7 @@
|
||||
'type': 'object',
|
||||
}),
|
||||
'ToolMessage': dict({
|
||||
'additionalProperties': True,
|
||||
'description': '''
|
||||
Message for passing the result of executing a tool back to a model.
|
||||
|
||||
@@ -845,6 +972,7 @@
|
||||
'type': 'object',
|
||||
}),
|
||||
'artifact': dict({
|
||||
'default': None,
|
||||
'title': 'Artifact',
|
||||
}),
|
||||
'content': dict({
|
||||
@@ -869,12 +997,28 @@
|
||||
'title': 'Content',
|
||||
}),
|
||||
'id': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
dict({
|
||||
'type': 'null',
|
||||
}),
|
||||
]),
|
||||
'default': None,
|
||||
'title': 'Id',
|
||||
'type': 'string',
|
||||
}),
|
||||
'name': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
dict({
|
||||
'type': 'null',
|
||||
}),
|
||||
]),
|
||||
'default': None,
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'response_metadata': dict({
|
||||
'title': 'Response Metadata',
|
||||
@@ -894,6 +1038,7 @@
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'const': 'tool',
|
||||
'default': 'tool',
|
||||
'enum': list([
|
||||
'tool',
|
||||
@@ -910,6 +1055,21 @@
|
||||
'type': 'object',
|
||||
}),
|
||||
'UsageMetadata': dict({
|
||||
'description': '''
|
||||
Usage metadata for a message, such as token counts.
|
||||
|
||||
This is a standard representation of token usage that is consistent across models.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
{
|
||||
"input_tokens": 10,
|
||||
"output_tokens": 20,
|
||||
"total_tokens": 30
|
||||
}
|
||||
''',
|
||||
'properties': dict({
|
||||
'input_tokens': dict({
|
||||
'title': 'Input Tokens',
|
||||
@@ -933,6 +1093,29 @@
|
||||
'type': 'object',
|
||||
}),
|
||||
}),
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
dict({
|
||||
'$ref': '#/$defs/AIMessage',
|
||||
}),
|
||||
dict({
|
||||
'$ref': '#/$defs/HumanMessage',
|
||||
}),
|
||||
dict({
|
||||
'$ref': '#/$defs/ChatMessage',
|
||||
}),
|
||||
dict({
|
||||
'$ref': '#/$defs/SystemMessage',
|
||||
}),
|
||||
dict({
|
||||
'$ref': '#/$defs/FunctionMessage',
|
||||
}),
|
||||
dict({
|
||||
'$ref': '#/$defs/ToolMessage',
|
||||
}),
|
||||
]),
|
||||
'title': 'RunnableParallel<as_list,as_str>Input',
|
||||
}),
|
||||
'id': 3,
|
||||
@@ -952,6 +1135,10 @@
|
||||
'title': 'As Str',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'as_list',
|
||||
'as_str',
|
||||
]),
|
||||
'title': 'RunnableParallel<as_list,as_str>Output',
|
||||
'type': 'object',
|
||||
}),
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -455,8 +455,9 @@ def test_get_input_schema_input_dict() -> None:
|
||||
|
||||
|
||||
def test_get_input_schema_input_messages() -> None:
|
||||
class RunnableWithChatHistoryInput(BaseModel):
|
||||
__root__: Sequence[BaseMessage]
|
||||
from pydantic import RootModel # pydantic: ignore
|
||||
|
||||
RunnableWithMessageHistoryInput = RootModel[Sequence[BaseMessage]]
|
||||
|
||||
runnable = RunnableLambda(
|
||||
lambda messages: {
|
||||
@@ -478,9 +479,9 @@ def test_get_input_schema_input_messages() -> None:
|
||||
with_history = RunnableWithMessageHistory(
|
||||
runnable, get_session_history, output_messages_key="output"
|
||||
)
|
||||
assert _schema(with_history.get_input_schema()) == _schema(
|
||||
RunnableWithChatHistoryInput
|
||||
)
|
||||
expected_schema = _schema(RunnableWithMessageHistoryInput)
|
||||
expected_schema["title"] = "RunnableWithChatHistoryInput"
|
||||
assert _schema(with_history.get_input_schema()) == expected_schema
|
||||
|
||||
|
||||
def test_using_custom_config_specs() -> None:
|
||||
|
||||
@@ -19,6 +19,7 @@ from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from freezegun import freeze_time
|
||||
from pydantic import BaseModel
|
||||
from pytest_mock import MockerFixture
|
||||
from syrupy import SnapshotAssertion
|
||||
from typing_extensions import TypedDict
|
||||
@@ -57,7 +58,6 @@ from langchain_core.prompts import (
|
||||
PromptTemplate,
|
||||
SystemMessagePromptTemplate,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langchain_core.runnables import (
|
||||
AddableDict,
|
||||
@@ -89,7 +89,7 @@ from langchain_core.tracers import (
|
||||
RunLogPatch,
|
||||
)
|
||||
from langchain_core.tracers.context import collect_runs
|
||||
from tests.unit_tests.pydantic_utils import _schema
|
||||
from tests.unit_tests.pydantic_utils import _schema, replace_all_of_with_ref
|
||||
from tests.unit_tests.stubs import AnyStr, _AnyIdAIMessage, _AnyIdAIMessageChunk
|
||||
|
||||
|
||||
@@ -313,12 +313,14 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
|
||||
"metadata": {"title": "Metadata", "type": "object"},
|
||||
"id": {
|
||||
"title": "Id",
|
||||
"type": "string",
|
||||
"anyOf": [{"type": "string"}, {"type": "null"}],
|
||||
"default": None,
|
||||
},
|
||||
"type": {
|
||||
"title": "Type",
|
||||
"enum": ["Document"],
|
||||
"default": "Document",
|
||||
"const": "Document",
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
@@ -329,7 +331,7 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
|
||||
|
||||
fake_llm = FakeListLLM(responses=["a"]) # str -> List[List[str]]
|
||||
|
||||
assert _schema(fake_llm.input_schema) == snapshot
|
||||
assert _schema(fake_llm.input_schema) == snapshot(name="fake_llm_input_schema")
|
||||
assert _schema(fake_llm.output_schema) == {
|
||||
"title": "FakeListLLMOutput",
|
||||
"type": "string",
|
||||
@@ -337,8 +339,8 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
|
||||
|
||||
fake_chat = FakeListChatModel(responses=["a"]) # str -> List[List[str]]
|
||||
|
||||
assert _schema(fake_chat.input_schema) == snapshot
|
||||
assert _schema(fake_chat.output_schema) == snapshot
|
||||
assert _schema(fake_chat.input_schema) == snapshot(name="fake_chat_input_schema")
|
||||
assert _schema(fake_chat.output_schema) == snapshot(name="fake_chat_output_schema")
|
||||
|
||||
chat_prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
@@ -362,7 +364,7 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
|
||||
"properties": {"name": {"title": "Name", "type": "string"}},
|
||||
"required": ["name"],
|
||||
}
|
||||
assert _schema(prompt.output_schema) == snapshot
|
||||
assert _schema(prompt.output_schema) == snapshot(name="prompt_output_schema")
|
||||
|
||||
prompt_mapper = PromptTemplate.from_template("Hello, {name}!").map()
|
||||
|
||||
@@ -379,11 +381,15 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
|
||||
"type": "array",
|
||||
"title": "RunnableEach<PromptTemplate>Input",
|
||||
}
|
||||
assert _schema(prompt_mapper.output_schema) == snapshot
|
||||
assert _schema(prompt_mapper.output_schema) == snapshot(
|
||||
name="prompt_mapper_output_schema"
|
||||
)
|
||||
|
||||
list_parser = CommaSeparatedListOutputParser()
|
||||
|
||||
assert _schema(list_parser.input_schema) == snapshot
|
||||
assert _schema(list_parser.input_schema) == snapshot(
|
||||
name="list_parser_input_schema"
|
||||
)
|
||||
assert _schema(list_parser.output_schema) == {
|
||||
"title": "CommaSeparatedListOutputParserOutput",
|
||||
"type": "array",
|
||||
@@ -407,19 +413,26 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
|
||||
router: Runnable = RouterRunnable({})
|
||||
|
||||
assert _schema(router.input_schema) == {
|
||||
"title": "RouterRunnableInput",
|
||||
"$ref": "#/definitions/RouterInput",
|
||||
"definitions": {
|
||||
"RouterInput": {
|
||||
"title": "RouterInput",
|
||||
"type": "object",
|
||||
"description": "Router input.\n"
|
||||
"\n"
|
||||
"Attributes:\n"
|
||||
" key: The key to route "
|
||||
"on.\n"
|
||||
" input: The input to pass "
|
||||
"to the selected Runnable.",
|
||||
"properties": {
|
||||
"key": {"title": "Key", "type": "string"},
|
||||
"input": {"title": "Input"},
|
||||
"key": {"title": "Key", "type": "string"},
|
||||
},
|
||||
"required": ["key", "input"],
|
||||
"title": "RouterInput",
|
||||
"type": "object",
|
||||
}
|
||||
},
|
||||
"title": "RouterRunnableInput",
|
||||
}
|
||||
assert _schema(router.output_schema) == {"title": "RouterRunnableOutput"}
|
||||
|
||||
@@ -451,6 +464,20 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
|
||||
"items": {"type": "string"},
|
||||
},
|
||||
},
|
||||
"required": ["original", "as_list", "length"],
|
||||
}
|
||||
|
||||
# Add a test for schema of runnable assign
|
||||
def foo(x: int) -> int:
|
||||
return x
|
||||
|
||||
foo_ = RunnableLambda(foo)
|
||||
|
||||
assert foo_.assign(bar=lambda x: "foo").get_output_schema().schema() == {
|
||||
"properties": {"bar": {"title": "Bar"}, "root": {"title": "Root"}},
|
||||
"required": ["root", "bar"],
|
||||
"title": "RunnableAssignOutput",
|
||||
"type": "object",
|
||||
}
|
||||
|
||||
|
||||
@@ -469,6 +496,7 @@ def test_passthrough_assign_schema() -> None:
|
||||
"properties": {"question": {"title": "Question", "type": "string"}},
|
||||
"title": "RunnableSequenceInput",
|
||||
"type": "object",
|
||||
"required": ["question"],
|
||||
}
|
||||
assert _schema(seq_w_assign.output_schema) == {
|
||||
"title": "FakeListLLMOutput",
|
||||
@@ -486,6 +514,7 @@ def test_passthrough_assign_schema() -> None:
|
||||
"properties": {"question": {"title": "Question"}},
|
||||
"title": "RunnableParallel<context>Input",
|
||||
"type": "object",
|
||||
"required": ["question"],
|
||||
}
|
||||
|
||||
|
||||
@@ -498,6 +527,7 @@ def test_lambda_schemas() -> None:
|
||||
"title": "RunnableLambdaInput",
|
||||
"type": "object",
|
||||
"properties": {"hello": {"title": "Hello"}},
|
||||
"required": ["hello"],
|
||||
}
|
||||
|
||||
second_lambda = lambda x, y: (x["hello"], x["bye"], y["bah"]) # noqa: E731
|
||||
@@ -505,6 +535,7 @@ def test_lambda_schemas() -> None:
|
||||
"title": "RunnableLambdaInput",
|
||||
"type": "object",
|
||||
"properties": {"hello": {"title": "Hello"}, "bye": {"title": "Bye"}},
|
||||
"required": ["bye", "hello"],
|
||||
}
|
||||
|
||||
def get_value(input): # type: ignore[no-untyped-def]
|
||||
@@ -514,6 +545,7 @@ def test_lambda_schemas() -> None:
|
||||
"title": "get_value_input",
|
||||
"type": "object",
|
||||
"properties": {"variable_name": {"title": "Variable Name"}},
|
||||
"required": ["variable_name"],
|
||||
}
|
||||
|
||||
async def aget_value(input): # type: ignore[no-untyped-def]
|
||||
@@ -526,6 +558,7 @@ def test_lambda_schemas() -> None:
|
||||
"another": {"title": "Another"},
|
||||
"variable_name": {"title": "Variable Name"},
|
||||
},
|
||||
"required": ["another", "variable_name"],
|
||||
}
|
||||
|
||||
async def aget_values(input): # type: ignore[no-untyped-def]
|
||||
@@ -542,6 +575,7 @@ def test_lambda_schemas() -> None:
|
||||
"variable_name": {"title": "Variable Name"},
|
||||
"yo": {"title": "Yo"},
|
||||
},
|
||||
"required": ["variable_name", "yo"],
|
||||
}
|
||||
|
||||
class InputType(TypedDict):
|
||||
@@ -622,6 +656,25 @@ def test_with_types_with_type_generics() -> None:
|
||||
)
|
||||
|
||||
|
||||
def test_schema_with_itemgetter() -> None:
|
||||
"""Test runnable with itemgetter."""
|
||||
foo = RunnableLambda(itemgetter("hello"))
|
||||
assert _schema(foo.input_schema) == {
|
||||
"properties": {"hello": {"title": "Hello"}},
|
||||
"required": ["hello"],
|
||||
"title": "RunnableLambdaInput",
|
||||
"type": "object",
|
||||
}
|
||||
prompt = ChatPromptTemplate.from_template("what is {language}?")
|
||||
chain: Runnable = {"language": itemgetter("language")} | prompt
|
||||
assert _schema(chain.input_schema) == {
|
||||
"properties": {"language": {"title": "Language"}},
|
||||
"required": ["language"],
|
||||
"title": "RunnableParallel<language>Input",
|
||||
"type": "object",
|
||||
}
|
||||
|
||||
|
||||
def test_schema_complex_seq() -> None:
|
||||
prompt1 = ChatPromptTemplate.from_template("what is the city {person} is from?")
|
||||
prompt2 = ChatPromptTemplate.from_template(
|
||||
@@ -650,6 +703,7 @@ def test_schema_complex_seq() -> None:
|
||||
"person": {"title": "Person", "type": "string"},
|
||||
"language": {"title": "Language"},
|
||||
},
|
||||
"required": ["person", "language"],
|
||||
}
|
||||
|
||||
assert _schema(chain2.output_schema) == {
|
||||
@@ -953,66 +1007,69 @@ def test_configurable_fields_prefix_keys() -> None:
|
||||
chain = prompt | fake_llm
|
||||
|
||||
assert _schema(chain.config_schema()) == {
|
||||
"title": "RunnableSequenceConfig",
|
||||
"type": "object",
|
||||
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
|
||||
"definitions": {
|
||||
"LLM": {
|
||||
"title": "LLM",
|
||||
"description": "An enumeration.",
|
||||
"enum": ["chat", "default"],
|
||||
"type": "string",
|
||||
},
|
||||
"Chat_Responses": {
|
||||
"title": "Chat Responses",
|
||||
"description": "An enumeration.",
|
||||
"enum": ["hello", "bye", "helpful"],
|
||||
"type": "string",
|
||||
},
|
||||
"Prompt_Template": {
|
||||
"title": "Prompt Template",
|
||||
"description": "An enumeration.",
|
||||
"enum": ["hello", "good_morning"],
|
||||
"title": "Chat Responses",
|
||||
"type": "string",
|
||||
},
|
||||
"Configurable": {
|
||||
"title": "Configurable",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"prompt_template": {
|
||||
"title": "Prompt Template",
|
||||
"description": "The prompt template for this chain",
|
||||
"default": "hello",
|
||||
"allOf": [{"$ref": "#/definitions/Prompt_Template"}],
|
||||
"chat_sleep": {
|
||||
"anyOf": [{"type": "number"}, {"type": "null"}],
|
||||
"default": None,
|
||||
"title": "Chat " "Sleep",
|
||||
},
|
||||
"llm": {
|
||||
"title": "LLM",
|
||||
"$ref": "#/definitions/LLM",
|
||||
"default": "default",
|
||||
"allOf": [{"$ref": "#/definitions/LLM"}],
|
||||
"title": "LLM",
|
||||
},
|
||||
# not prefixed because marked as shared
|
||||
"chat_sleep": {
|
||||
"title": "Chat Sleep",
|
||||
"type": "number",
|
||||
},
|
||||
# prefixed for "chat" option
|
||||
"llm==chat/responses": {
|
||||
"title": "Chat Responses",
|
||||
"default": ["hello", "bye"],
|
||||
"type": "array",
|
||||
"items": {"$ref": "#/definitions/Chat_Responses"},
|
||||
},
|
||||
# prefixed for "default" option
|
||||
"llm==default/responses": {
|
||||
"title": "LLM Responses",
|
||||
"description": "A list of fake responses for this LLM",
|
||||
"default": ["a"],
|
||||
"title": "Chat " "Responses",
|
||||
"type": "array",
|
||||
},
|
||||
"llm==default/responses": {
|
||||
"default": ["a"],
|
||||
"description": "A "
|
||||
"list "
|
||||
"of "
|
||||
"fake "
|
||||
"responses "
|
||||
"for "
|
||||
"this "
|
||||
"LLM",
|
||||
"items": {"type": "string"},
|
||||
"title": "LLM " "Responses",
|
||||
"type": "array",
|
||||
},
|
||||
"prompt_template": {
|
||||
"$ref": "#/definitions/Prompt_Template",
|
||||
"default": "hello",
|
||||
"description": "The "
|
||||
"prompt "
|
||||
"template "
|
||||
"for "
|
||||
"this "
|
||||
"chain",
|
||||
"title": "Prompt " "Template",
|
||||
},
|
||||
},
|
||||
"title": "Configurable",
|
||||
"type": "object",
|
||||
},
|
||||
"LLM": {"enum": ["chat", "default"], "title": "LLM", "type": "string"},
|
||||
"Prompt_Template": {
|
||||
"enum": ["hello", "good_morning"],
|
||||
"title": "Prompt Template",
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
|
||||
"title": "RunnableSequenceConfig",
|
||||
"type": "object",
|
||||
}
|
||||
|
||||
|
||||
@@ -1061,26 +1118,22 @@ def test_configurable_fields_example() -> None:
|
||||
chain_configurable = prompt | fake_llm | (lambda x: {"name": x}) | prompt | fake_llm
|
||||
|
||||
assert chain_configurable.invoke({"name": "John"}) == "a"
|
||||
|
||||
assert _schema(chain_configurable.config_schema()) == {
|
||||
expected = {
|
||||
"title": "RunnableSequenceConfig",
|
||||
"type": "object",
|
||||
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
|
||||
"definitions": {
|
||||
"LLM": {
|
||||
"title": "LLM",
|
||||
"description": "An enumeration.",
|
||||
"enum": ["chat", "default"],
|
||||
"type": "string",
|
||||
},
|
||||
"Chat_Responses": {
|
||||
"description": "An enumeration.",
|
||||
"enum": ["hello", "bye", "helpful"],
|
||||
"title": "Chat Responses",
|
||||
"type": "string",
|
||||
},
|
||||
"Prompt_Template": {
|
||||
"description": "An enumeration.",
|
||||
"enum": ["hello", "good_morning"],
|
||||
"title": "Prompt Template",
|
||||
"type": "string",
|
||||
@@ -1118,6 +1171,10 @@ def test_configurable_fields_example() -> None:
|
||||
},
|
||||
}
|
||||
|
||||
replace_all_of_with_ref(expected)
|
||||
|
||||
assert _schema(chain_configurable.config_schema()) == expected
|
||||
|
||||
assert (
|
||||
chain_configurable.with_config(configurable={"llm": "chat"}).invoke(
|
||||
{"name": "John"}
|
||||
@@ -3173,6 +3230,7 @@ def test_map_stream() -> None:
|
||||
"hello": {"title": "Hello", "type": "string"},
|
||||
"llm": {"title": "Llm", "type": "string"},
|
||||
},
|
||||
"required": ["llm", "hello"],
|
||||
}
|
||||
|
||||
stream = chain_pick_two.stream({"question": "What is your name?"})
|
||||
@@ -3190,6 +3248,12 @@ def test_map_stream() -> None:
|
||||
{"llm": "i"},
|
||||
{"chat": AIMessageChunk(content="i")},
|
||||
]
|
||||
if not ( # TODO(Rewrite properly) statement above
|
||||
streamed_chunks[0] == {"llm": "i"}
|
||||
or {"chat": _AnyIdAIMessageChunk(content="i")}
|
||||
):
|
||||
raise AssertionError(f"Got an unexpected chunk: {streamed_chunks[0]}")
|
||||
|
||||
assert len(streamed_chunks) == len(llm_res) + len(chat_res)
|
||||
|
||||
|
||||
@@ -3544,6 +3608,7 @@ def test_deep_stream_assign() -> None:
|
||||
"str": {"title": "Str", "type": "string"},
|
||||
"hello": {"title": "Hello", "type": "string"},
|
||||
},
|
||||
"required": ["str", "hello"],
|
||||
}
|
||||
|
||||
chunks = []
|
||||
@@ -3595,6 +3660,7 @@ def test_deep_stream_assign() -> None:
|
||||
"str": {"title": "Str"},
|
||||
"hello": {"title": "Hello", "type": "string"},
|
||||
},
|
||||
"required": ["str", "hello"],
|
||||
}
|
||||
|
||||
chunks = []
|
||||
@@ -3670,6 +3736,7 @@ async def test_deep_astream_assign() -> None:
|
||||
"str": {"title": "Str", "type": "string"},
|
||||
"hello": {"title": "Hello", "type": "string"},
|
||||
},
|
||||
"required": ["str", "hello"],
|
||||
}
|
||||
|
||||
chunks = []
|
||||
@@ -3721,6 +3788,7 @@ async def test_deep_astream_assign() -> None:
|
||||
"str": {"title": "Str"},
|
||||
"hello": {"title": "Hello", "type": "string"},
|
||||
},
|
||||
"required": ["str", "hello"],
|
||||
}
|
||||
|
||||
chunks = []
|
||||
@@ -4362,7 +4430,10 @@ def test_runnable_branch_init_coercion(branches: Sequence[Any]) -> None:
|
||||
assert isinstance(body, Runnable)
|
||||
|
||||
assert isinstance(runnable.default, Runnable)
|
||||
assert _schema(runnable.input_schema) == {"title": "RunnableBranchInput"}
|
||||
assert _schema(runnable.input_schema) == {
|
||||
"title": "RunnableBranchInput",
|
||||
"type": "integer",
|
||||
}
|
||||
|
||||
|
||||
def test_runnable_branch_invoke_call_counts(mocker: MockerFixture) -> None:
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import sys
|
||||
from itertools import cycle
|
||||
from typing import Any, AsyncIterator, Dict, List, Sequence, cast
|
||||
from typing import Optional as Optional
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -1176,6 +1177,9 @@ class HardCodedRetriever(BaseRetriever):
|
||||
return self.documents
|
||||
|
||||
|
||||
HardCodedRetriever.update_forward_refs()
|
||||
|
||||
|
||||
async def test_event_stream_with_retriever() -> None:
|
||||
"""Test the event stream with a retriever."""
|
||||
retriever = HardCodedRetriever(
|
||||
|
||||
25
libs/core/tests/unit_tests/test_json_schema_remapping.py
Normal file
25
libs/core/tests/unit_tests/test_json_schema_remapping.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from pydantic import BaseModel, v1
|
||||
|
||||
from tests.unit_tests.pydantic_utils import _schema
|
||||
|
||||
|
||||
def test_schemas() -> None:
|
||||
"""Test schema remapping for the two pydantic versions."""
|
||||
|
||||
class Bar(BaseModel):
|
||||
baz: int
|
||||
|
||||
class Foo(BaseModel):
|
||||
bar: Bar
|
||||
|
||||
schema_2 = _schema(Foo)
|
||||
|
||||
class Bar(v1.BaseModel): # type: ignore[no-redef]
|
||||
baz: int
|
||||
|
||||
class Foo(v1.BaseModel): # type: ignore[no-redef]
|
||||
bar: Bar
|
||||
|
||||
schema_1 = _schema(Foo)
|
||||
|
||||
assert schema_1 == schema_2
|
||||
@@ -874,15 +874,13 @@ async def test_async_validation_error_handling_non_validation_error(
|
||||
|
||||
def test_optional_subset_model_rewrite() -> None:
|
||||
class MyModel(BaseModel):
|
||||
a: Optional[str]
|
||||
a: Optional[str] = None
|
||||
b: str
|
||||
c: Optional[List[Optional[str]]]
|
||||
c: Optional[List[Optional[str]]] = None
|
||||
|
||||
model2 = _create_subset_model("model2", MyModel, ["a", "b", "c"])
|
||||
|
||||
assert "a" not in _schema(model2)["required"] # should be optional
|
||||
assert "b" in _schema(model2)["required"] # should be required
|
||||
assert "c" not in _schema(model2)["required"] # should be optional
|
||||
assert set(_schema(model2)["required"]) == {"b"}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
@@ -411,17 +411,17 @@ def test_multiple_tool_calls() -> None:
|
||||
{
|
||||
"id": messages[2].tool_call_id,
|
||||
"type": "function",
|
||||
"function": {"name": "FakeCall", "arguments": '{"data": "ToolCall1"}'},
|
||||
"function": {"name": "FakeCall", "arguments": '{"data":"ToolCall1"}'},
|
||||
},
|
||||
{
|
||||
"id": messages[3].tool_call_id,
|
||||
"type": "function",
|
||||
"function": {"name": "FakeCall", "arguments": '{"data": "ToolCall2"}'},
|
||||
"function": {"name": "FakeCall", "arguments": '{"data":"ToolCall2"}'},
|
||||
},
|
||||
{
|
||||
"id": messages[4].tool_call_id,
|
||||
"type": "function",
|
||||
"function": {"name": "FakeCall", "arguments": '{"data": "ToolCall3"}'},
|
||||
"function": {"name": "FakeCall", "arguments": '{"data":"ToolCall3"}'},
|
||||
},
|
||||
]
|
||||
|
||||
@@ -442,7 +442,7 @@ def test_tool_outputs() -> None:
|
||||
{
|
||||
"id": messages[2].tool_call_id,
|
||||
"type": "function",
|
||||
"function": {"name": "FakeCall", "arguments": '{"data": "ToolCall1"}'},
|
||||
"function": {"name": "FakeCall", "arguments": '{"data":"ToolCall1"}'},
|
||||
},
|
||||
]
|
||||
assert messages[2].content == "Output1"
|
||||
@@ -698,7 +698,8 @@ def test__convert_typed_dict_to_openai_function(
|
||||
@pytest.mark.parametrize("typed_dict", [ExtensionsTypedDict, TypingTypedDict])
|
||||
def test__convert_typed_dict_to_openai_function_fail(typed_dict: Type) -> None:
|
||||
class Tool(typed_dict):
|
||||
arg1: MutableSet # Pydantic doesn't support
|
||||
arg1: MutableSet # Pydantic 2 supports this, but pydantic v1 does not.
|
||||
|
||||
# Error should be raised since we're using v1 code path here
|
||||
with pytest.raises(TypeError):
|
||||
_convert_typed_dict_to_openai_function(Tool)
|
||||
|
||||
@@ -117,7 +117,7 @@ async def test_inmemory_filter() -> None:
|
||||
|
||||
# Check sync version
|
||||
output = store.similarity_search("fee", filter=lambda doc: doc.metadata["id"] == 1)
|
||||
assert output == [Document(page_content="foo", metadata={"id": 1}, id=AnyStr())]
|
||||
assert output == [_AnyIdDocument(page_content="foo", metadata={"id": 1})]
|
||||
|
||||
# filter with not stored document id
|
||||
output = await store.asimilarity_search(
|
||||
|
||||
Reference in New Issue
Block a user