Compare commits

...

97 Commits

Author SHA1 Message Date
Eugene Yurtsev
3981cda448 Core upgraded to pydantic 2 2024-09-03 14:54:21 -04:00
Erick Friis
979505eb03 x 2024-08-08 09:21:33 -07:00
Erick Friis
1882a139b7 x 2024-08-08 09:18:51 -07:00
Eugene Yurtsev
627757e808 Bump pydantic 2024-08-08 11:56:17 -04:00
Eugene Yurtsev
06da7da547 Apply pydantic v0.2 changes 2024-08-08 11:55:06 -04:00
Eugene Yurtsev
3d561c3e6d model_validate_json 2024-08-08 11:34:23 -04:00
Eugene Yurtsev
09f9d3e972 get_fields 2024-08-08 11:31:57 -04:00
Eugene Yurtsev
c1e6e7d020 get_fields 2024-08-08 11:31:54 -04:00
Eugene Yurtsev
6f79443ab5 get_fields 2024-08-08 11:31:41 -04:00
Eugene Yurtsev
58c4e1ef86 Add name to Runnable Generator 2024-08-08 11:20:08 -04:00
Eugene Yurtsev
de30b04f37 Fix type issue 2024-08-08 11:12:31 -04:00
Eugene Yurtsev
f7a455299e Add more types 2024-08-07 22:12:29 -04:00
Eugene Yurtsev
76043abd47 add unit test 2024-08-07 17:23:54 -04:00
Eugene Yurtsev
61cdb9ccce Fix output schema 2024-08-07 17:23:44 -04:00
Eugene Yurtsev
1ef3fa54fc Fix types 2024-08-07 17:16:28 -04:00
Eugene Yurtsev
3856e3b02a Fix types 2024-08-07 17:14:59 -04:00
Eugene Yurtsev
035f09f20d Fix types 2024-08-07 17:14:47 -04:00
Eugene Yurtsev
63f7a5ab68 Replace __fields__ with model_fields 2024-08-07 17:11:14 -04:00
Eugene Yurtsev
8447d9f6f1 Replace __fields__ with model_fields 2024-08-07 17:10:07 -04:00
Eugene Yurtsev
95db9e9258 Replace __fields__ with model_fields 2024-08-07 17:07:15 -04:00
Eugene Yurtsev
0d1b93774b Resolve linting 2024-08-07 17:06:46 -04:00
Eugene Yurtsev
bcce3a2865 Resolve linting 2024-08-07 17:03:30 -04:00
Eugene Yurtsev
4a478d82bd Resolve linting 2024-08-07 17:02:15 -04:00
Eugene Yurtsev
a0c3657442 Resolve linting 2024-08-07 17:01:29 -04:00
Eugene Yurtsev
72c5c28b4d Resolve linting 2024-08-07 17:01:02 -04:00
Eugene Yurtsev
fe6f2f724b Use model_fields 2024-08-07 16:08:30 -04:00
Eugene Yurtsev
88d347e90c Remove pydantic lint in core 2024-08-07 16:06:40 -04:00
Eugene Yurtsev
741b50d4fd Fix serializer 2024-08-07 15:57:57 -04:00
Eugene Yurtsev
24c6825345 Fix serialization test 2024-08-07 15:57:41 -04:00
Eugene Yurtsev
32824aa55c Handle lint 2024-08-07 15:47:56 -04:00
Eugene Yurtsev
f6924653ea Handle lint 2024-08-07 15:47:51 -04:00
Eugene Yurtsev
66e8594b89 Handle lint 2024-08-07 15:46:37 -04:00
Eugene Yurtsev
3b9f061eac Handle lint 2024-08-07 15:45:21 -04:00
Eugene Yurtsev
76b6ee290d Replace __fields__ with model_fields 2024-08-07 15:44:31 -04:00
Eugene Yurtsev
22957311fe Add more tests for serializable 2024-08-07 15:40:59 -04:00
Eugene Yurtsev
f9df75c8cc Add more tests for serializable 2024-08-07 15:37:21 -04:00
Eugene Yurtsev
ece0ab8539 Add more tests for serializable 2024-08-07 15:18:16 -04:00
Eugene Yurtsev
4ddd9e5f23 lint 2024-08-07 15:05:52 -04:00
Eugene Yurtsev
f8e95e5735 lint 2024-08-07 15:04:02 -04:00
Eugene Yurtsev
6515b2f77b Linting fixes 2024-08-07 15:03:31 -04:00
Eugene Yurtsev
63fde4f095 Linting fixes 2024-08-07 13:59:22 -04:00
Eugene Yurtsev
d9bb9125c1 Linting fixes 2024-08-07 13:55:56 -04:00
Eugene Yurtsev
384d9f59a3 Linting fixes 2024-08-07 13:55:38 -04:00
Eugene Yurtsev
fc0fa7e8f0 Add missing import 2024-08-07 13:50:14 -04:00
Eugene Yurtsev
a1054d06ca Add missing import 2024-08-07 13:48:43 -04:00
Eugene Yurtsev
c2570a7a7c lint 2024-08-07 13:47:43 -04:00
Eugene Yurtsev
97f4128bfd Add missing imports 2024-08-07 13:47:26 -04:00
Eugene Yurtsev
2434dc8f92 update snapshots 2024-08-07 13:46:18 -04:00
Eugene Yurtsev
123d61a888 Add missing imports 2024-08-07 13:43:44 -04:00
Eugene Yurtsev
53f6f4a0c0 Mark explicitly with # pydantic: ignore 2024-08-07 13:41:17 -04:00
Eugene Yurtsev
550bef230a Merge branch 'master' into eugene/merge_pydantic_3_changes 2024-08-07 13:28:46 -04:00
Eugene Yurtsev
5a998d36b2 Convert to v1 model for now 2024-08-07 12:09:42 -04:00
Eugene Yurtsev
72cd199efc Fix create_subset_model_v1 2024-08-07 11:58:10 -04:00
Eugene Yurtsev
a1d993deb1 Remove deprecated comment 2024-08-07 11:54:21 -04:00
Eugene Yurtsev
e546e21d53 Update unit test for pydantic 2 2024-08-07 11:52:28 -04:00
Eugene Yurtsev
26d6426156 Fix extra space in repr 2024-08-07 11:48:11 -04:00
Eugene Yurtsev
8dffedebd6 Add Skip Validation() 2024-08-07 11:38:28 -04:00
Eugene Yurtsev
60adf8d6e4 Handle is_injected_arg_type 2024-08-07 11:36:56 -04:00
Eugene Yurtsev
d13a1ad5f5 Use _AnyIDDocument 2024-08-07 11:27:35 -04:00
Eugene Yurtsev
1e5f8a494a Add SkipValidation() 2024-08-07 11:25:21 -04:00
Eugene Yurtsev
5216131769 Fixed something? 2024-08-07 11:18:52 -04:00
Eugene Yurtsev
8bdaf858b8 Use is_basemodel_instance 2024-08-07 11:03:53 -04:00
Eugene Yurtsev
c37a0ca672 Use is_basemodel_subclass 2024-08-07 11:03:35 -04:00
Eugene Yurtsev
266cd15511 ADd Skip Validation 2024-08-07 10:51:43 -04:00
Eugene Yurtsev
9debf8144e ADd Skip Validation 2024-08-07 10:51:02 -04:00
Eugene Yurtsev
78ce0ed337 Fix broken type 2024-08-07 10:23:55 -04:00
Eugene Yurtsev
4aa1932bea update 2024-08-07 09:52:33 -04:00
Eugene Yurtsev
b658295b97 update 2024-08-07 09:40:29 -04:00
Eugene Yurtsev
8c59b6a026 Merge fix 2024-08-07 09:32:25 -04:00
Eugene Yurtsev
e35b43a7a7 Fix ConfigDict to be populate by name 2024-08-07 09:16:23 -04:00
Eugene Yurtsev
7288d914a8 Add missing model rebuild and optional 2024-08-07 09:14:06 -04:00
Eugene Yurtsev
1b487e261a add missing pydantic import 2024-08-07 09:04:36 -04:00
Eugene Yurtsev
3934663db9 Merge branch 'master' into eugene/merge_pydantic_3_changes 2024-08-07 08:59:28 -04:00
Eugene Yurtsev
fb639cb49c lint 2024-08-06 22:02:31 -04:00
Eugene Yurtsev
1856387e9e Add missing imports load and dumpd 2024-08-06 17:10:42 -04:00
Eugene Yurtsev
a5ad775a90 Add Optional import 2024-08-06 17:10:18 -04:00
Eugene Yurtsev
a321401683 Update pydantic utility 2024-08-06 16:54:55 -04:00
Eugene Yurtsev
8839220a00 Restore more missing stuff 2024-08-06 16:10:59 -04:00
Eugene Yurtsev
e6b2ca4da3 x 2024-08-06 16:08:06 -04:00
Eugene Yurtsev
d0c52d1dec x 2024-08-06 16:06:44 -04:00
Eugene Yurtsev
a5fa6d1c43 x 2024-08-06 16:05:43 -04:00
Eugene Yurtsev
7f79bd6e04 x 2024-08-06 16:04:14 -04:00
Eugene Yurtsev
339985e39e merge more 2024-08-06 15:59:53 -04:00
Eugene Yurtsev
f4ecd749d5 x 2024-08-06 15:58:55 -04:00
Eugene Yurtsev
cb61c6b4bf Merge branch 'master' into eugene/merge_pydantic_3_changes 2024-08-06 15:57:37 -04:00
Eugene Yurtsev
b42c2c6cd6 Update to master 2024-08-06 15:57:35 -04:00
Eugene Yurtsev
da6633bf0d update 2024-08-06 13:08:53 -04:00
Eugene Yurtsev
0193d18bec update 2024-08-06 13:04:17 -04:00
Eugene Yurtsev
0a82192e36 update forward refs 2024-08-06 12:41:52 -04:00
Eugene Yurtsev
202f6fef95 update 2024-08-06 12:39:00 -04:00
Eugene Yurtsev
c49416e908 fix typo 2024-08-06 12:35:05 -04:00
Eugene Yurtsev
ec93ea6240 update 2024-08-06 12:33:43 -04:00
Eugene Yurtsev
add20dc9a8 update 2024-08-06 12:30:33 -04:00
Eugene Yurtsev
7799474746 MANUAL: May need to revert 2024-08-06 11:47:27 -04:00
Eugene Yurtsev
d98c1f115f update 2024-08-06 11:46:39 -04:00
Eugene Yurtsev
d97f70def4 Update 2024-08-06 11:43:25 -04:00
Eugene Yurtsev
609c6b0963 Update 2024-08-06 11:40:43 -04:00
136 changed files with 15953 additions and 11525 deletions

View File

@@ -13,7 +13,7 @@ tests:
poetry run pytest $(TEST_FILE)
test_watch:
poetry run ptw --snapshot-update --now . -- -vv tests/unit_tests
poetry run ptw --snapshot-update --now . -- -vv $(TEST_FILE)
test_profile:
poetry run pytest -vv tests/unit_tests/ --profile-svg
@@ -39,17 +39,14 @@ lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test
lint lint_diff lint_package lint_tests:
./scripts/check_pydantic.sh .
./scripts/lint_imports.sh
poetry run ruff check .
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff check $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff check --select I $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || poetry run mypy $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
format format_diff:
poetry run ruff format $(PYTHON_FILES)
poetry run ruff check --select I --fix $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff check --select I --fix $(PYTHON_FILES)
spell_check:
poetry run codespell --toml pyproject.toml

View File

@@ -1,3 +1,13 @@
"""``langchain-core`` defines the base abstractions for the LangChain ecosystem.
The interfaces for core components like chat models, LLMs, vector stores, retrievers,
and more are defined here. The universal invocation protocol (Runnables) along with
a syntax for combining components (LangChain Expression Language) are also defined here.
No third-party integrations are defined here. The dependencies are kept purposefully
very lightweight.
"""
from importlib import metadata
from langchain_core._api import (

View File

@@ -270,7 +270,7 @@ def warn_beta(
message += f" {addendum}"
warning = LangChainBetaWarning(message)
warnings.warn(warning, category=LangChainBetaWarning, stacklevel=2)
warnings.warn(warning, category=LangChainBetaWarning, stacklevel=4)
def surface_langchain_beta_warnings() -> None:
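The hunk above bumps ``stacklevel`` from 2 to 4 so the beta warning is attributed to the end user's call site rather than to LangChain's internal wrapper frames. A minimal sketch of how ``stacklevel`` walks up the call stack (the function names here are hypothetical, not part of this diff):

import warnings

def _emit():
    # stacklevel=1 would point at this line; stacklevel=2 at _emit's caller;
    # stacklevel=4 walks two frames further up, to the end user's code.
    warnings.warn("beta feature", UserWarning, stacklevel=4)

def _internal_wrapper():
    _emit()

def public_api():
    _internal_wrapper()

public_api()  # with stacklevel=4 the warning is reported against this line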

View File

@@ -14,7 +14,17 @@ import contextlib
import functools
import inspect
import warnings
from typing import Any, Callable, Generator, Type, TypeVar, Union, cast
from typing import (
Any,
Callable,
Generator,
Type,
TypeVar,
Union,
cast,
)
from typing_extensions import ParamSpec
from langchain_core._api.internal import is_caller_internal
@@ -30,7 +40,8 @@ class LangChainPendingDeprecationWarning(PendingDeprecationWarning):
# PUBLIC API
T = TypeVar("T", bound=Union[Type, Callable[..., Any]])
# Last Any should be FieldInfoV1 but this leads to circular imports
T = TypeVar("T", bound=Union[Type, Callable[..., Any], Any])
def _validate_deprecation_params(
@@ -133,6 +144,7 @@ def deprecated(
_package: str = package,
) -> T:
"""Implementation of the decorator returned by `deprecated`."""
from langchain_core.utils.pydantic import FieldInfoV1
def emit_warning() -> None:
"""Emit the warning."""
@@ -207,50 +219,73 @@ def deprecated(
)
return cast(T, obj)
elif isinstance(obj, FieldInfoV1):
wrapped = None
if not _obj_type:
_obj_type = "attribute"
if not _name:
raise ValueError(f"Field {obj} must have a name to be deprecated.")
old_doc = obj.description
def finalize(wrapper: Callable[..., Any], new_doc: str) -> T:
return cast(
T,
FieldInfoV1(
default=obj.default,
default_factory=obj.default_factory,
description=new_doc,
alias=obj.alias,
exclude=obj.exclude,
),
)
elif isinstance(obj, property):
if not _obj_type:
_obj_type = "attribute"
wrapped = None
_name = _name or obj.fget.__qualname__
_name = _name or cast(Union[Type, Callable], obj.fget).__qualname__
old_doc = obj.__doc__
class _deprecated_property(property):
"""A deprecated property."""
def __init__(self, fget=None, fset=None, fdel=None, doc=None):
def __init__(self, fget=None, fset=None, fdel=None, doc=None): # type: ignore[no-untyped-def]
super().__init__(fget, fset, fdel, doc)
self.__orig_fget = fget
self.__orig_fset = fset
self.__orig_fdel = fdel
def __get__(self, instance, owner=None):
def __get__(self, instance, owner=None): # type: ignore[no-untyped-def]
if instance is not None or owner is not None:
emit_warning()
return self.fget(instance)
def __set__(self, instance, value):
def __set__(self, instance, value): # type: ignore[no-untyped-def]
if instance is not None:
emit_warning()
return self.fset(instance, value)
def __delete__(self, instance):
def __delete__(self, instance): # type: ignore[no-untyped-def]
if instance is not None:
emit_warning()
return self.fdel(instance)
def __set_name__(self, owner, set_name):
def __set_name__(self, owner, set_name): # type: ignore[no-untyped-def]
nonlocal _name
if _name == "<lambda>":
_name = set_name
def finalize(wrapper: Callable[..., Any], new_doc: str) -> Any:
def finalize(wrapper: Callable[..., Any], new_doc: str) -> T:
"""Finalize the property."""
return _deprecated_property(
fget=obj.fget, fset=obj.fset, fdel=obj.fdel, doc=new_doc
return cast(
T,
_deprecated_property(
fget=obj.fget, fset=obj.fset, fdel=obj.fdel, doc=new_doc
),
)
else:
_name = _name or obj.__qualname__
_name = _name or cast(Union[Type, Callable], obj).__qualname__
if not _obj_type:
# edge case: when a function is within another function
# within a test, this will call it a "method" not a "function"
@@ -409,7 +444,7 @@ def warn_deprecated(
LangChainPendingDeprecationWarning if pending else LangChainDeprecationWarning
)
warning = warning_cls(message)
warnings.warn(warning, category=LangChainDeprecationWarning, stacklevel=2)
warnings.warn(warning, category=LangChainDeprecationWarning, stacklevel=4)
def surface_langchain_deprecation_warnings() -> None:
@@ -423,3 +458,51 @@ def surface_langchain_deprecation_warnings() -> None:
"default",
category=LangChainDeprecationWarning,
)
_P = ParamSpec("_P")
_R = TypeVar("_R")
def rename_parameter(
*,
since: str,
removal: str,
old: str,
new: str,
) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
"""Decorator indicating that parameter *old* of *func* is renamed to *new*.
The actual implementation of *func* should use *new*, not *old*. If *old*
is passed to *func*, a DeprecationWarning is emitted and its value is
forwarded to *new*; passing both *old* and *new* raises a TypeError.
Example:
.. code-block:: python
@rename_parameter(since="2.0", removal="3.0", old="bad_name", new="good_name")
def func(good_name): ...
"""
def decorator(f: Callable[_P, _R]) -> Callable[_P, _R]:
@functools.wraps(f)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
if new in kwargs and old in kwargs:
raise TypeError(
f"{f.__name__}() got multiple values for argument {new!r}"
)
if old in kwargs:
warn_deprecated(
since,
removal=removal,
message=f"The parameter `{old}` of `{f.__name__}` was "
f"deprecated in {since} and will be removed "
f"in {removal} Use `{new}` instead.",
)
kwargs[new] = kwargs.pop(old)
return f(*args, **kwargs)
return wrapper
return decorator
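A short usage sketch of the new ``rename_parameter`` decorator added above, assuming it stays importable from ``langchain_core._api.deprecation`` where this diff defines it (the version strings are placeholders):

from langchain_core._api.deprecation import rename_parameter

@rename_parameter(since="0.2.0", removal="1.0", old="txt", new="text")
def summarize(text: str) -> str:
    return text[:10]

summarize(text="hello world")  # normal call, no warning
summarize(txt="hello world")   # emits a LangChainDeprecationWarning, value forwarded to `text`
# summarize(txt="a", text="b") would raise TypeError: multiple values for 'text'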

View File

@@ -18,6 +18,8 @@ from typing import (
Union,
)
from pydantic import ConfigDict
from langchain_core._api.beta_decorator import beta
from langchain_core.runnables.base import (
Runnable,
@@ -229,8 +231,9 @@ class ContextSet(RunnableSerializable):
keys: Mapping[str, Optional[Runnable]]
class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
def __init__(
self,
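The ``class Config`` to ``model_config = ConfigDict(...)`` swap shown here is the most common mechanical change in this PR. A minimal side-by-side sketch of the pattern (``MyModel`` is a hypothetical stand-in, not a class from this diff):

from pydantic import BaseModel, ConfigDict

# pydantic v1 style being removed:
#     class MyModel(BaseModel):
#         class Config:
#             arbitrary_types_allowed = True

# pydantic v2 style being introduced:
class MyModel(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)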

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, TypeVar, Union
from uuid import UUID
@@ -13,6 +14,8 @@ if TYPE_CHECKING:
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk, LLMResult
_LOGGER = logging.getLogger(__name__)
class RetrieverManagerMixin:
"""Mixin for Retriever callbacks."""
@@ -911,15 +914,67 @@ class BaseCallbackManager(CallbackManagerMixin):
def copy(self: T) -> T:
"""Copy the callback manager."""
return self.__class__(
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
handlers=self.handlers.copy(),
inheritable_handlers=self.inheritable_handlers.copy(),
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
metadata=self.metadata,
inheritable_metadata=self.inheritable_metadata,
tags=self.tags.copy(),
inheritable_tags=self.inheritable_tags.copy(),
metadata=self.metadata.copy(),
inheritable_metadata=self.inheritable_metadata.copy(),
)
def merge(self: T, other: BaseCallbackManager) -> T:
"""Merge the callback manager with another callback manager.
May be overridden in subclasses. Primarily used internally
within merge_configs.
Returns:
BaseCallbackManager: The merged callback manager of the same type
as the current object.
Example: Merging two callback managers.
.. code-block:: python
from langchain_core.callbacks.manager import CallbackManager, trace_as_chain_group
from langchain_core.callbacks.stdout import StdOutCallbackHandler
manager = CallbackManager(handlers=[StdOutCallbackHandler()], tags=["tag2"])
with trace_as_chain_group("My Group Name", tags=["tag1"]) as group_manager:
merged_manager = group_manager.merge(manager)
print(merged_manager.handlers)
# [
# <langchain_core.callbacks.stdout.StdOutCallbackHandler object at ...>,
# <langchain_core.callbacks.streaming_stdout.StreamingStdOutCallbackHandler object at ...>,
# ]
print(merged_manager.tags)
# ['tag2', 'tag1']
""" # noqa: E501
manager = self.__class__(
parent_run_id=self.parent_run_id or other.parent_run_id,
handlers=[],
inheritable_handlers=[],
tags=list(set(self.tags + other.tags)),
inheritable_tags=list(set(self.inheritable_tags + other.inheritable_tags)),
metadata={
**self.metadata,
**other.metadata,
},
)
handlers = self.handlers + other.handlers
inheritable_handlers = self.inheritable_handlers + other.inheritable_handlers
for handler in handlers:
manager.add_handler(handler)
for handler in inheritable_handlers:
manager.add_handler(handler, inherit=True)
return manager
@property
def is_async(self) -> bool:
"""Whether the callback manager is async."""

View File

@@ -1612,16 +1612,75 @@ class CallbackManagerForChainGroup(CallbackManager):
def copy(self) -> CallbackManagerForChainGroup:
"""Copy the callback manager."""
return self.__class__(
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
handlers=self.handlers.copy(),
inheritable_handlers=self.inheritable_handlers.copy(),
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
metadata=self.metadata,
inheritable_metadata=self.inheritable_metadata,
tags=self.tags.copy(),
inheritable_tags=self.inheritable_tags.copy(),
metadata=self.metadata.copy(),
inheritable_metadata=self.inheritable_metadata.copy(),
parent_run_manager=self.parent_run_manager,
)
def merge(
self: CallbackManagerForChainGroup, other: BaseCallbackManager
) -> CallbackManagerForChainGroup:
"""Merge the group callback manager with another callback manager.
Overrides the merge method in the base class to ensure that the
parent run manager is preserved. Keeps the parent_run_manager
from the current object.
Returns:
CallbackManagerForChainGroup: A copy of the current object with the
handlers, tags, and other attributes merged from the other object.
Example: Merging two callback managers.
.. code-block:: python
from langchain_core.callbacks.manager import CallbackManager, trace_as_chain_group
from langchain_core.callbacks.stdout import StdOutCallbackHandler
manager = CallbackManager(handlers=[StdOutCallbackHandler()], tags=["tag2"])
with trace_as_chain_group("My Group Name", tags=["tag1"]) as group_manager:
merged_manager = group_manager.merge(manager)
print(type(merged_manager))
# <class 'langchain_core.callbacks.manager.CallbackManagerForChainGroup'>
print(merged_manager.handlers)
# [
# <langchain_core.callbacks.stdout.LangChainTracer object at ...>,
# <langchain_core.callbacks.streaming_stdout.StdOutCallbackHandler object at ...>,
# ]
print(merged_manager.tags)
# ['tag2', 'tag1']
""" # noqa: E501
manager = self.__class__(
parent_run_id=self.parent_run_id or other.parent_run_id,
handlers=[],
inheritable_handlers=[],
tags=list(set(self.tags + other.tags)),
inheritable_tags=list(set(self.inheritable_tags + other.inheritable_tags)),
metadata={
**self.metadata,
**other.metadata,
},
parent_run_manager=self.parent_run_manager,
)
handlers = self.handlers + other.handlers
inheritable_handlers = self.inheritable_handlers + other.inheritable_handlers
for handler in handlers:
manager.add_handler(handler)
for handler in inheritable_handlers:
manager.add_handler(handler, inherit=True)
return manager
def on_chain_end(self, outputs: Union[Dict[str, Any], Any], **kwargs: Any) -> None:
"""Run when traced chain group ends.
@@ -2040,16 +2099,75 @@ class AsyncCallbackManagerForChainGroup(AsyncCallbackManager):
def copy(self) -> AsyncCallbackManagerForChainGroup:
"""Copy the async callback manager."""
return self.__class__(
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
handlers=self.handlers.copy(),
inheritable_handlers=self.inheritable_handlers.copy(),
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
metadata=self.metadata,
inheritable_metadata=self.inheritable_metadata,
tags=self.tags.copy(),
inheritable_tags=self.inheritable_tags.copy(),
metadata=self.metadata.copy(),
inheritable_metadata=self.inheritable_metadata.copy(),
parent_run_manager=self.parent_run_manager,
)
def merge(
self: AsyncCallbackManagerForChainGroup, other: BaseCallbackManager
) -> AsyncCallbackManagerForChainGroup:
"""Merge the group callback manager with another callback manager.
Overrides the merge method in the base class to ensure that the
parent run manager is preserved. Keeps the parent_run_manager
from the current object.
Returns:
AsyncCallbackManagerForChainGroup: A copy of the current AsyncCallbackManagerForChainGroup
with the handlers, tags, etc. of the other callback manager merged in.
Example: Merging two callback managers.
.. code-block:: python
from langchain_core.callbacks.manager import CallbackManager, atrace_as_chain_group
from langchain_core.callbacks.stdout import StdOutCallbackHandler
manager = CallbackManager(handlers=[StdOutCallbackHandler()], tags=["tag2"])
async with atrace_as_chain_group("My Group Name", tags=["tag1"]) as group_manager:
merged_manager = group_manager.merge(manager)
print(type(merged_manager))
# <class 'langchain_core.callbacks.manager.AsyncCallbackManagerForChainGroup'>
print(merged_manager.handlers)
# [
# <langchain_core.callbacks.stdout.LangChainTracer object at ...>,
# <langchain_core.callbacks.streaming_stdout.StdOutCallbackHandler object at ...>,
# ]
print(merged_manager.tags)
# ['tag2', 'tag1']
""" # noqa: E501
manager = self.__class__(
parent_run_id=self.parent_run_id or other.parent_run_id,
handlers=[],
inheritable_handlers=[],
tags=list(set(self.tags + other.tags)),
inheritable_tags=list(set(self.inheritable_tags + other.inheritable_tags)),
metadata={
**self.metadata,
**other.metadata,
},
parent_run_manager=self.parent_run_manager,
)
handlers = self.handlers + other.handlers
inheritable_handlers = self.inheritable_handlers + other.inheritable_handlers
for handler in handlers:
manager.add_handler(handler)
for handler in inheritable_handlers:
manager.add_handler(handler, inherit=True)
return manager
async def on_chain_end(
self, outputs: Union[Dict[str, Any], Any], **kwargs: Any
) -> None:

View File

@@ -20,13 +20,14 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from typing import List, Sequence, Union
from pydantic import BaseModel, Field
from langchain_core.messages import (
AIMessage,
BaseMessage,
HumanMessage,
get_buffer_string,
)
from langchain_core.pydantic_v1 import BaseModel, Field
class BaseChatMessageHistory(ABC):

View File

@@ -1,5 +1,6 @@
from langchain_core.document_loaders.base import BaseBlobParser, BaseLoader
from langchain_core.document_loaders.blob_loaders import Blob, BlobLoader, PathLike
from langchain_core.document_loaders.langsmith import LangSmithLoader
__all__ = [
"BaseBlobParser",
@@ -7,4 +8,5 @@ __all__ = [
"Blob",
"BlobLoader",
"PathLike",
"LangSmithLoader",
]

View File

@@ -14,7 +14,7 @@ if TYPE_CHECKING:
from langchain_core.documents.base import Blob
class BaseLoader(ABC):
class BaseLoader(ABC): # noqa: B024
"""Interface for Document Loader.
Implementations should implement the lazy-loading method using generators

View File

@@ -4,10 +4,12 @@ import contextlib
import mimetypes
from io import BufferedReader, BytesIO
from pathlib import PurePath
from typing import Any, Generator, List, Literal, Mapping, Optional, Union, cast
from typing import Any, Dict, Generator, List, Literal, Optional, Union, cast
from pydantic import ConfigDict, Field, model_validator
from langchain_core.load.serializable import Serializable
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.utils.pydantic import v1_repr
PathLike = Union[str, PurePath]
@@ -110,9 +112,10 @@ class Blob(BaseMedia):
path: Optional[PathLike] = None
"""Location where the original content was found."""
class Config:
arbitrary_types_allowed = True
frozen = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
frozen=True,
)
@property
def source(self) -> Optional[str]:
@@ -127,8 +130,9 @@ class Blob(BaseMedia):
return cast(Optional[str], self.metadata["source"])
return str(self.path) if self.path else None
@root_validator(pre=True)
def check_blob_is_valid(cls, values: Mapping[str, Any]) -> Mapping[str, Any]:
@model_validator(mode="before")
@classmethod
def check_blob_is_valid(cls, values: Dict[str, Any]) -> Any:
"""Verify that either data or path is provided."""
if "data" not in values and "path" not in values:
raise ValueError("Either data or path must be provided")
@@ -137,7 +141,7 @@ class Blob(BaseMedia):
def as_string(self) -> str:
"""Read data as a string."""
if self.data is None and self.path:
with open(str(self.path), "r", encoding=self.encoding) as f:
with open(str(self.path), encoding=self.encoding) as f:
return f.read()
elif isinstance(self.data, bytes):
return self.data.decode(self.encoding)
@@ -293,3 +297,7 @@ class Document(BaseMedia):
return f"page_content='{self.page_content}' metadata={self.metadata}"
else:
return f"page_content='{self.page_content}'"
def __repr__(self) -> str:
# TODO(0.3): Remove this override after confirming unit tests!
return v1_repr(self)
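This file also shows the other recurring migration in the PR: pydantic v1's ``@root_validator(pre=True)`` becomes ``@model_validator(mode="before")`` stacked on ``@classmethod``, still receiving the raw input mapping. A minimal sketch of the same pattern applied outside this diff (``BlobLike`` is hypothetical):

from typing import Any, Dict, Optional
from pydantic import BaseModel, model_validator

class BlobLike(BaseModel):
    data: Optional[bytes] = None
    path: Optional[str] = None

    @model_validator(mode="before")
    @classmethod
    def check_blob_is_valid(cls, values: Dict[str, Any]) -> Any:
        # Runs before field validation, on the raw input mapping.
        if "data" not in values and "path" not in values:
            raise ValueError("Either data or path must be provided")
        return values

BlobLike(path="example.txt")  # ok
# BlobLike() raises a ValidationError wrapping the ValueError above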

View File

@@ -3,9 +3,10 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Optional, Sequence
from pydantic import BaseModel
from langchain_core.callbacks import Callbacks
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import run_in_executor

View File

@@ -4,8 +4,9 @@
import hashlib
from typing import List
from pydantic import BaseModel
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel
class FakeEmbeddings(Embeddings, BaseModel):
@@ -15,14 +16,36 @@ class FakeEmbeddings(Embeddings, BaseModel):
Do not use this outside of testing, as it is not a real embedding model.
Example:
Instantiate:
.. code-block:: python
from langchain_core.embeddings import FakeEmbeddings
embed = FakeEmbeddings(size=100)
fake_embeddings = FakeEmbeddings(size=100)
fake_embeddings.embed_documents(["hello world", "foo bar"])
Embed single text:
.. code-block:: python
input_text = "The meaning of life is 42"
vector = embed.embed_query(input_text)
print(vector[:3])
.. code-block:: python
[-0.700234640213188, -0.581266257710429, -1.1328482266445354]
Embed multiple texts:
.. code-block:: python
input_texts = ["Document 1...", "Document 2..."]
vectors = embed.embed_documents(input_texts)
print(len(vectors))
# The first 3 coordinates for the first vector
print(vectors[0][:3])
.. code-block:: python
2
[-0.5670477847544458, -0.31403828652395727, -0.5840547508955257]
"""
size: int
@@ -48,14 +71,36 @@ class DeterministicFakeEmbedding(Embeddings, BaseModel):
Do not use this outside of testing, as it is not a real embedding model.
Example:
Instantiate:
.. code-block:: python
from langchain_core.embeddings import DeterministicFakeEmbedding
embed = DeterministicFakeEmbedding(size=100)
fake_embeddings = DeterministicFakeEmbedding(size=100)
fake_embeddings.embed_documents(["hello world", "foo bar"])
Embed single text:
.. code-block:: python
input_text = "The meaning of life is 42"
vector = embed.embed_query(input_text)
print(vector[:3])
.. code-block:: python
[-0.700234640213188, -0.581266257710429, -1.1328482266445354]
Embed multiple texts:
.. code-block:: python
input_texts = ["Document 1...", "Document 2..."]
vectors = embed.embed_documents(input_texts)
print(len(vectors))
# The first 3 coordinates for the first vector
print(vectors[0][:3])
.. code-block:: python
2
[-0.5670477847544458, -0.31403828652395727, -0.5840547508955257]
"""
size: int
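A compact, runnable check of the behaviour the expanded docstrings above describe; the exact vector values will differ from the illustrative numbers in the docstring, but the determinism holds (numpy is assumed to be installed, since the fake embeddings rely on it):

from langchain_core.embeddings import DeterministicFakeEmbedding

embed = DeterministicFakeEmbedding(size=8)
first = embed.embed_query("The meaning of life is 42")
second = embed.embed_query("The meaning of life is 42")
assert first == second  # same text always maps to the same vector
assert len(first) == 8   # vectors have the requested size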

View File

@@ -3,9 +3,10 @@
import re
from typing import Callable, Dict, List
from pydantic import BaseModel, validator
from langchain_core.example_selectors.base import BaseExampleSelector
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, validator
def _get_length_based(text: str) -> int:

View File

@@ -5,9 +5,10 @@ from __future__ import annotations
from abc import ABC
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
from pydantic import BaseModel, ConfigDict, Extra
from langchain_core.documents import Document
from langchain_core.example_selectors.base import BaseExampleSelector
from langchain_core.pydantic_v1 import BaseModel, Extra
from langchain_core.vectorstores import VectorStore
if TYPE_CHECKING:
@@ -42,9 +43,10 @@ class _VectorStoreExampleSelector(BaseExampleSelector, BaseModel, ABC):
vectorstore_kwargs: Optional[Dict[str, Any]] = None
"""Extra arguments passed to similarity_search function of the vectorstore."""
class Config:
arbitrary_types_allowed = True
extra = Extra.forbid
model_config = ConfigDict(
arbitrary_types_allowed=True,
extra=Extra.forbid,
)
@staticmethod
def _example_to_text(

View File

@@ -39,7 +39,7 @@ class OutputParserException(ValueError, LangChainException):
llm_output: Optional[str] = None,
send_to_llm: bool = False,
):
super(OutputParserException, self).__init__(error)
super().__init__(error)
if send_to_llm:
if observation is None or llm_output is None:
raise ValueError(

View File

@@ -12,6 +12,8 @@ from typing import (
Optional,
)
from pydantic import Field
from langchain_core._api import beta
from langchain_core.callbacks import (
AsyncCallbackManagerForRetrieverRun,
@@ -20,7 +22,6 @@ from langchain_core.callbacks import (
from langchain_core.documents import Document
from langchain_core.graph_vectorstores.links import METADATA_LINKS_KEY, Link
from langchain_core.load import Serializable
from langchain_core.pydantic_v1 import Field
from langchain_core.runnables import run_in_executor
from langchain_core.vectorstores import VectorStore, VectorStoreRetriever
@@ -32,15 +33,17 @@ def _has_next(iterator: Iterator) -> bool:
return next(iterator, sentinel) is not sentinel
@beta()
class Node(Serializable):
"""Node in the GraphVectorStore.
Edges exist from nodes with an outgoing link to nodes with a matching incoming link.
For instance two nodes `a` and `b` connected over a hyperlink `https://some-url`
For instance two nodes `a` and `b` connected over a hyperlink ``https://some-url``
would look like:
.. code-block:: python
[
Node(
id="a",
@@ -79,12 +82,12 @@ def _texts_to_nodes(
for text in texts:
try:
_metadata = next(metadatas_it).copy() if metadatas_it else {}
except StopIteration:
raise ValueError("texts iterable longer than metadatas")
except StopIteration as e:
raise ValueError("texts iterable longer than metadatas") from e
try:
_id = next(ids_it) if ids_it else None
except StopIteration:
raise ValueError("texts iterable longer than ids")
except StopIteration as e:
raise ValueError("texts iterable longer than ids") from e
links = _metadata.pop(METADATA_LINKS_KEY, [])
if not isinstance(links, list):
@@ -115,7 +118,15 @@ def _documents_to_nodes(documents: Iterable[Document]) -> Iterator[Node]:
)
@beta()
def nodes_to_documents(nodes: Iterable[Node]) -> Iterator[Document]:
"""Convert nodes to documents.
Args:
nodes: The nodes to convert to documents.
Returns:
The documents generated from the nodes.
"""
for node in nodes:
metadata = node.metadata.copy()
metadata[METADATA_LINKS_KEY] = [
@@ -588,23 +599,28 @@ class GraphVectorStore(VectorStore):
"'mmr' or 'traversal'."
)
def as_retriever(self, **kwargs: Any) -> "GraphVectorStoreRetriever":
def as_retriever(self, **kwargs: Any) -> GraphVectorStoreRetriever:
"""Return GraphVectorStoreRetriever initialized from this GraphVectorStore.
Args:
search_type (Optional[str]): Defines the type of search that
the Retriever should perform.
Can be "traversal" (default), "similarity", "mmr", or
"similarity_score_threshold".
search_kwargs (Optional[Dict]): Keyword arguments to pass to the
search function. Can include things like:
k: Amount of documents to return (Default: 4)
depth: The maximum depth of edges to traverse (Default: 1)
score_threshold: Minimum relevance threshold
for similarity_score_threshold
fetch_k: Amount of documents to pass to MMR algorithm (Default: 20)
lambda_mult: Diversity of results returned by MMR;
1 for minimum diversity and 0 for maximum. (Default: 0.5)
**kwargs: Keyword arguments to pass to the search function.
Can include:
- search_type (Optional[str]): Defines the type of search that
the Retriever should perform.
Can be ``traversal`` (default), ``similarity``, ``mmr``, or
``similarity_score_threshold``.
- search_kwargs (Optional[Dict]): Keyword arguments to pass to the
search function. Can include things like:
- k(int): Amount of documents to return (Default: 4).
- depth(int): The maximum depth of edges to traverse (Default: 1).
- score_threshold(float): Minimum relevance threshold
for similarity_score_threshold.
- fetch_k(int): Amount of documents to pass to MMR algorithm
(Default: 20).
- lambda_mult(float): Diversity of results returned by MMR;
1 for minimum diversity and 0 for maximum. (Default: 0.5).
Returns:
Retriever for this GraphVectorStore.

View File

@@ -1,9 +1,11 @@
from dataclasses import dataclass
from typing import Iterable, List, Literal, Union
from langchain_core._api import beta
from langchain_core.documents import Document
@beta()
@dataclass(frozen=True)
class Link:
"""A link to/from a tag of a given tag.
@@ -38,8 +40,10 @@ class Link:
METADATA_LINKS_KEY = "links"
@beta()
def get_links(doc: Document) -> List[Link]:
"""Get the links from a document.
Args:
doc: The document to get the link tags from.
Returns:
@@ -54,8 +58,10 @@ def get_links(doc: Document) -> List[Link]:
return links
@beta()
def add_links(doc: Document, *links: Union[Link, Iterable[Link]]) -> None:
"""Add links to the given metadata.
Args:
doc: The document to add the links to.
*links: The links to add to the document.
@@ -68,6 +74,7 @@ def add_links(doc: Document, *links: Union[Link, Iterable[Link]]) -> None:
links_in_metadata.append(link)
@beta()
def copy_with_links(doc: Document, *links: Union[Link, Iterable[Link]]) -> Document:
"""Return a document with the given links added.

View File

@@ -25,10 +25,11 @@ from typing import (
cast,
)
from pydantic import model_validator
from langchain_core.document_loaders.base import BaseLoader
from langchain_core.documents import Document
from langchain_core.indexing.base import DocumentIndex, RecordManager
from langchain_core.pydantic_v1 import root_validator
from langchain_core.vectorstores import VectorStore
# Magic UUID to use as a namespace for hashing.
@@ -68,8 +69,9 @@ class _HashedDocument(Document):
def is_lc_serializable(cls) -> bool:
return False
@root_validator(pre=True)
def calculate_hashes(cls, values: Dict[str, Any]) -> Dict[str, Any]:
@model_validator(mode="before")
@classmethod
def calculate_hashes(cls, values: Dict[str, Any]) -> Any:
"""Root validator to calculate content and metadata hash."""
content = values.get("page_content", "")
metadata = values.get("metadata", {})
@@ -91,7 +93,7 @@ class _HashedDocument(Document):
raise ValueError(
f"Failed to hash metadata: {e}. "
f"Please use a dict that can be serialized using json."
)
) from e
values["content_hash"] = content_hash
values["metadata_hash"] = metadata_hash
@@ -435,7 +437,7 @@ async def _to_async_iterator(iterator: Iterable[T]) -> AsyncIterator[T]:
async def aindex(
docs_source: Union[BaseLoader, Iterable[Document], AsyncIterator[Document]],
record_manager: RecordManager,
vectorstore: Union[VectorStore, DocumentIndex],
vector_store: Union[VectorStore, DocumentIndex],
*,
batch_size: int = 100,
cleanup: Literal["incremental", "full", None] = None,
@@ -506,7 +508,7 @@ async def aindex(
if cleanup == "incremental" and source_id_key is None:
raise ValueError("Source id key is required when cleanup mode is incremental.")
destination = vectorstore # Renaming internally for clarity
destination = vector_store # Renaming internally for clarity
# If it's a vectorstore, let's check if it has the required methods.
if isinstance(destination, VectorStore):
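The hunk above renames ``aindex``'s third parameter from ``vectorstore`` to ``vector_store``. A hedged sketch of an updated call site, assuming ``aindex`` stays exported from ``langchain_core.indexing``; ``docs``, ``record_manager`` and ``store`` stand in for real objects:

from langchain_core.indexing import aindex

async def reindex(docs, record_manager, store):
    # keyword now matches the renamed parameter
    return await aindex(
        docs,
        record_manager,
        vector_store=store,
        cleanup="incremental",
        source_id_key="source",
    )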

View File

@@ -1,12 +1,13 @@
import uuid
from typing import Any, Dict, List, Optional, Sequence, cast
from pydantic import Field
from langchain_core._api import beta
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.indexing import UpsertResponse
from langchain_core.indexing.base import DeleteResponse, DocumentIndex
from langchain_core.pydantic_v1 import Field
@beta(message="Introduced in version 0.2.29. Underlying abstraction subject to change.")

View File

@@ -39,6 +39,7 @@ https://python.langchain.com/v0.2/docs/how_to/custom_llm/
from langchain_core.language_models.base import (
BaseLanguageModel,
LangSmithParams,
LanguageModelInput,
LanguageModelLike,
LanguageModelOutput,
@@ -62,6 +63,7 @@ __all__ = [
"LLM",
"LanguageModelInput",
"get_tokenizer",
"LangSmithParams",
"LanguageModelOutput",
"LanguageModelLike",
"FakeListLLM",

View File

@@ -8,6 +8,7 @@ from typing import (
Callable,
Dict,
List,
Literal,
Mapping,
Optional,
Sequence,
@@ -17,7 +18,8 @@ from typing import (
Union,
)
from typing_extensions import TypeAlias
from pydantic import BaseModel, ConfigDict, Field, validator
from typing_extensions import TypeAlias, TypedDict
from langchain_core._api import deprecated
from langchain_core.messages import (
@@ -27,7 +29,6 @@ from langchain_core.messages import (
get_buffer_string,
)
from langchain_core.prompt_values import PromptValue
from langchain_core.pydantic_v1 import BaseModel, Field, validator
from langchain_core.runnables import Runnable, RunnableSerializable
from langchain_core.utils import get_pydantic_field_names
@@ -37,6 +38,23 @@ if TYPE_CHECKING:
from langchain_core.outputs import LLMResult
class LangSmithParams(TypedDict, total=False):
"""LangSmith parameters for tracing."""
ls_provider: str
"""Provider of the model."""
ls_model_name: str
"""Name of the model."""
ls_model_type: Literal["chat", "llm"]
"""Type of the model. Should be 'chat' or 'llm'."""
ls_temperature: Optional[float]
"""Temperature for generation."""
ls_max_tokens: Optional[int]
"""Max tokens for generation."""
ls_stop: Optional[List[str]]
"""Stop words for generation."""
@lru_cache(maxsize=None) # Cache the tokenizer
def get_tokenizer() -> Any:
"""Get a GPT-2 tokenizer instance.
@@ -46,12 +64,12 @@ def get_tokenizer() -> Any:
"""
try:
from transformers import GPT2TokenizerFast # type: ignore[import]
except ImportError:
except ImportError as e:
raise ImportError(
"Could not import transformers python package. "
"This is needed in order to calculate get_token_ids. "
"Please install it with `pip install transformers`."
)
) from e
# create a GPT-2 tokenizer instance
return GPT2TokenizerFast.from_pretrained("gpt2")
@@ -95,7 +113,11 @@ class BaseLanguageModel(
Caching is not currently supported for streaming methods of models.
"""
verbose: bool = Field(default_factory=_get_verbosity)
# Repr = False is consistent with pydantic 1 if verbose = False
# We can relax this for pydantic 2?
# TODO(Team): decide what to do here.
# Modified just to get unit tests to pass.
verbose: bool = Field(default_factory=_get_verbosity, exclude=True, repr=False)
"""Whether to print out response text."""
callbacks: Callbacks = Field(default=None, exclude=True)
"""Callbacks to add to the run trace."""
@@ -108,6 +130,10 @@ class BaseLanguageModel(
)
"""Optional encoder to use for counting tokens."""
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
@validator("verbose", pre=True, always=True, allow_reuse=True)
def set_verbose(cls, verbose: Optional[bool]) -> bool:
"""If verbose is None, set it.
@@ -220,7 +246,7 @@ class BaseLanguageModel(
# generate responses that match a given schema.
raise NotImplementedError()
@deprecated("0.1.7", alternative="invoke", removal="0.3.0")
@deprecated("0.1.7", alternative="invoke", removal="1.0")
@abstractmethod
def predict(
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
@@ -241,7 +267,7 @@ class BaseLanguageModel(
Top model prediction as a string.
"""
@deprecated("0.1.7", alternative="invoke", removal="0.3.0")
@deprecated("0.1.7", alternative="invoke", removal="1.0")
@abstractmethod
def predict_messages(
self,
@@ -266,7 +292,7 @@ class BaseLanguageModel(
Top model prediction as a message.
"""
@deprecated("0.1.7", alternative="ainvoke", removal="0.3.0")
@deprecated("0.1.7", alternative="ainvoke", removal="1.0")
@abstractmethod
async def apredict(
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
@@ -287,7 +313,7 @@ class BaseLanguageModel(
Top model prediction as a string.
"""
@deprecated("0.1.7", alternative="ainvoke", removal="0.3.0")
@deprecated("0.1.7", alternative="ainvoke", removal="1.0")
@abstractmethod
async def apredict_messages(
self,

View File

@@ -23,7 +23,12 @@ from typing import (
cast,
)
from typing_extensions import TypedDict
from pydantic import (
BaseModel,
ConfigDict,
Field,
model_validator,
)
from langchain_core._api import deprecated
from langchain_core.caches import BaseCache
@@ -36,7 +41,11 @@ from langchain_core.callbacks import (
Callbacks,
)
from langchain_core.globals import get_llm_cache
from langchain_core.language_models.base import BaseLanguageModel, LanguageModelInput
from langchain_core.language_models.base import (
BaseLanguageModel,
LangSmithParams,
LanguageModelInput,
)
from langchain_core.load import dumpd, dumps
from langchain_core.messages import (
AIMessage,
@@ -55,11 +64,6 @@ from langchain_core.outputs import (
RunInfo,
)
from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
from langchain_core.pydantic_v1 import (
BaseModel,
Field,
root_validator,
)
from langchain_core.rate_limiters import BaseRateLimiter
from langchain_core.runnables import RunnableMap, RunnablePassthrough
from langchain_core.runnables.config import ensure_config, run_in_executor
@@ -73,23 +77,6 @@ if TYPE_CHECKING:
from langchain_core.tools import BaseTool
class LangSmithParams(TypedDict, total=False):
"""LangSmith parameters for tracing."""
ls_provider: str
"""Provider of the model."""
ls_model_name: str
"""Name of the model."""
ls_model_type: Literal["chat"]
"""Type of the model. Should be 'chat'."""
ls_temperature: Optional[float]
"""Temperature for generation."""
ls_max_tokens: Optional[int]
"""Max tokens for generation."""
ls_stop: Optional[List[str]]
"""Stop words for generation."""
def generate_from_stream(stream: Iterator[ChatGenerationChunk]) -> ChatResult:
"""Generate from a stream.
@@ -208,14 +195,40 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
""" # noqa: E501
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
"""[DEPRECATED] Callback manager to add to the run trace."""
# TODO(0.3): Figure out how to re-apply deprecated decorator
# callback_manager: Optional[BaseCallbackManager] = deprecated(
# name="callback_manager", since="0.1.7", removal="1.0", alternative="callbacks"
# )(
# Field(
# default=None,
# exclude=True,
# description="Callback manager to add to the run trace.",
# )
# )
callback_manager: Optional[BaseCallbackManager] = Field(
default=None,
exclude=True,
description="Callback manager to add to the run trace.",
)
rate_limiter: Optional[BaseRateLimiter] = Field(default=None, exclude=True)
"""An optional rate limiter to use for limiting the number of requests."""
"An optional rate limiter to use for limiting the number of requests."
@root_validator(pre=True)
def raise_deprecation(cls, values: Dict) -> Dict:
disable_streaming: Union[bool, Literal["tool_calling"]] = False
"""Whether to disable streaming for this model.
If streaming is bypassed, then ``stream()/astream()`` will defer to
``invoke()/ainvoke()``.
- If True, will always bypass streaming case.
- If "tool_calling", will bypass streaming case only when the model is called
with a ``tools`` keyword argument.
- If False (default), will always use streaming case if available.
"""
@model_validator(mode="before")
@classmethod
def raise_deprecation(cls, values: Dict) -> Any:
"""Raise deprecation warning if callback_manager is used.
Args:
@@ -231,12 +244,14 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
warnings.warn(
"callback_manager is deprecated. Please use callbacks instead.",
DeprecationWarning,
stacklevel=5,
)
values["callbacks"] = values.pop("callback_manager", None)
return values
class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
# --- Runnable methods ---
@@ -302,6 +317,41 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
)
return cast(ChatGeneration, llm_result.generations[0][0]).message
def _should_stream(
self,
*,
async_api: bool,
run_manager: Optional[
Union[CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun]
] = None,
**kwargs: Any,
) -> bool:
"""Determine if a given model call should hit the streaming API."""
sync_not_implemented = type(self)._stream == BaseChatModel._stream
async_not_implemented = type(self)._astream == BaseChatModel._astream
# Check if streaming is implemented.
if (not async_api) and sync_not_implemented:
return False
# Note, since async falls back to sync we check both here.
if async_api and async_not_implemented and sync_not_implemented:
return False
# Check if streaming has been disabled on this instance.
if self.disable_streaming is True:
return False
# We assume tools are passed in via "tools" kwarg in all models.
if self.disable_streaming == "tool_calling" and kwargs.get("tools"):
return False
# Check if a runtime streaming flag has been passed in.
if "stream" in kwargs:
return kwargs["stream"]
# Check if any streaming callback handlers have been passed in.
handlers = run_manager.handlers if run_manager else []
return any(isinstance(h, _StreamingCallbackHandler) for h in handlers)
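A sketch of how the new ``disable_streaming`` field and ``_should_stream`` interact, using the in-repo fake chat model and assuming it remains exported from ``langchain_core.language_models``; the exact chunking of the fake model is incidental, the point is that ``stream()`` collapses to a single ``invoke()`` result when streaming is disabled:

from langchain_core.language_models import FakeListChatModel

streaming_model = FakeListChatModel(responses=["hello"])
print(len(list(streaming_model.stream("hi"))))  # several chunks (one per character)

non_streaming_model = FakeListChatModel(responses=["hello"], disable_streaming=True)
print(len(list(non_streaming_model.stream("hi"))))  # 1: falls back to invoke()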
def stream(
self,
input: LanguageModelInput,
@@ -310,7 +360,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Iterator[BaseMessageChunk]:
if type(self)._stream == BaseChatModel._stream:
if not self._should_stream(async_api=False, **{**kwargs, **{"stream": True}}):
# model doesn't implement streaming, so use default implementation
yield cast(
BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs)
@@ -380,10 +430,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> AsyncIterator[BaseMessageChunk]:
if (
type(self)._astream is BaseChatModel._astream
and type(self)._stream is BaseChatModel._stream
):
if not self._should_stream(async_api=True, **{**kwargs, **{"stream": True}}):
# No async or sync stream is implemented, so fall back to ainvoke
yield cast(
BaseMessageChunk,
@@ -471,9 +518,37 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
**kwargs: Any,
) -> LangSmithParams:
"""Get standard params for tracing."""
ls_params = LangSmithParams(ls_model_type="chat")
# get default provider from class name
default_provider = self.__class__.__name__
if default_provider.startswith("Chat"):
default_provider = default_provider[4:].lower()
elif default_provider.endswith("Chat"):
default_provider = default_provider[:-4]
default_provider = default_provider.lower()
ls_params = LangSmithParams(ls_provider=default_provider, ls_model_type="chat")
if stop:
ls_params["ls_stop"] = stop
# model
if hasattr(self, "model") and isinstance(self.model, str):
ls_params["ls_model_name"] = self.model
elif hasattr(self, "model_name") and isinstance(self.model_name, str):
ls_params["ls_model_name"] = self.model_name
# temperature
if "temperature" in kwargs and isinstance(kwargs["temperature"], float):
ls_params["ls_temperature"] = kwargs["temperature"]
elif hasattr(self, "temperature") and isinstance(self.temperature, float):
ls_params["ls_temperature"] = self.temperature
# max_tokens
if "max_tokens" in kwargs and isinstance(kwargs["max_tokens"], int):
ls_params["ls_max_tokens"] = kwargs["max_tokens"]
elif hasattr(self, "max_tokens") and isinstance(self.max_tokens, int):
ls_params["ls_max_tokens"] = self.max_tokens
return ls_params
def _get_llm_string(self, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
@@ -760,20 +835,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
# If stream is not explicitly set, check if implicitly requested by
# astream_events() or astream_log(). Bail out if _stream not implemented
if type(self)._stream != BaseChatModel._stream and kwargs.pop(
"stream",
(
next(
(
True
for h in run_manager.handlers
if isinstance(h, _StreamingCallbackHandler)
),
False,
)
if run_manager
else False
),
if self._should_stream(
async_api=False,
run_manager=run_manager,
**kwargs,
):
chunks: List[ChatGenerationChunk] = []
for chunk in self._stream(messages, stop=stop, **kwargs):
@@ -847,23 +912,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
# If stream is not explicitly set, check if implicitly requested by
# astream_events() or astream_log(). Bail out if _astream not implemented
if (
type(self)._astream != BaseChatModel._astream
or type(self)._stream != BaseChatModel._stream
) and kwargs.pop(
"stream",
(
next(
(
True
for h in run_manager.handlers
if isinstance(h, _StreamingCallbackHandler)
),
False,
)
if run_manager
else False
),
if self._should_stream(
async_api=True,
run_manager=run_manager,
**kwargs,
):
chunks: List[ChatGenerationChunk] = []
async for chunk in self._astream(messages, stop=stop, **kwargs):
@@ -963,7 +1015,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
break
yield item # type: ignore[misc]
@deprecated("0.1.7", alternative="invoke", removal="0.3.0")
@deprecated("0.1.7", alternative="invoke", removal="1.0")
def __call__(
self,
messages: List[BaseMessage],
@@ -995,13 +1047,13 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
else:
raise ValueError("Unexpected generation type")
@deprecated("0.1.7", alternative="invoke", removal="0.3.0")
@deprecated("0.1.7", alternative="invoke", removal="1.0")
def call_as_llm(
self, message: str, stop: Optional[List[str]] = None, **kwargs: Any
) -> str:
return self.predict(message, stop=stop, **kwargs)
@deprecated("0.1.7", alternative="invoke", removal="0.3.0")
@deprecated("0.1.7", alternative="invoke", removal="1.0")
def predict(
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
) -> str:
@@ -1015,7 +1067,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
else:
raise ValueError("Cannot use predict when output is not a string.")
@deprecated("0.1.7", alternative="invoke", removal="0.3.0")
@deprecated("0.1.7", alternative="invoke", removal="1.0")
def predict_messages(
self,
messages: List[BaseMessage],
@@ -1029,7 +1081,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
_stop = list(stop)
return self(messages, stop=_stop, **kwargs)
@deprecated("0.1.7", alternative="ainvoke", removal="0.3.0")
@deprecated("0.1.7", alternative="ainvoke", removal="1.0")
async def apredict(
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
) -> str:
@@ -1045,7 +1097,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
else:
raise ValueError("Cannot use predict when output is not a string.")
@deprecated("0.1.7", alternative="ainvoke", removal="0.3.0")
@deprecated("0.1.7", alternative="ainvoke", removal="1.0")
async def apredict_messages(
self,
messages: List[BaseMessage],

View File

@@ -69,6 +69,10 @@ class FakeListLLM(LLM):
return {"responses": self.responses}
class FakeListLLMError(Exception):
"""Fake error for testing purposes."""
class FakeStreamingListLLM(FakeListLLM):
"""Fake streaming list LLM for testing purposes.
@@ -98,7 +102,7 @@ class FakeStreamingListLLM(FakeListLLM):
self.error_on_chunk_number is not None
and i_c == self.error_on_chunk_number
):
raise Exception("Fake error")
raise FakeListLLMError
yield c
async def astream(
@@ -118,5 +122,5 @@ class FakeStreamingListLLM(FakeListLLM):
self.error_on_chunk_number is not None
and i_c == self.error_on_chunk_number
):
raise Exception("Fake error")
raise FakeListLLMError
yield c
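The change above replaces a bare ``Exception("Fake error")`` with a dedicated ``FakeListLLMError``, which lets tests assert on the specific failure. A hedged test-style sketch, assuming the fakes stay importable from ``langchain_core.language_models.fake``:

import pytest
from langchain_core.language_models.fake import FakeListLLMError, FakeStreamingListLLM

llm = FakeStreamingListLLM(responses=["abcd"], error_on_chunk_number=2)
with pytest.raises(FakeListLLMError):
    list(llm.stream("ignored"))  # raises on the third streamed character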

View File

@@ -44,6 +44,10 @@ class FakeMessagesListChatModel(BaseChatModel):
return "fake-messages-list-chat-model"
class FakeListChatModelError(Exception):
pass
class FakeListChatModel(SimpleChatModel):
"""Fake ChatModel for testing purposes."""
@@ -93,7 +97,7 @@ class FakeListChatModel(SimpleChatModel):
self.error_on_chunk_number is not None
and i_c == self.error_on_chunk_number
):
raise Exception("Fake error")
raise FakeListChatModelError
yield ChatGenerationChunk(message=AIMessageChunk(content=c))
@@ -116,7 +120,7 @@ class FakeListChatModel(SimpleChatModel):
self.error_on_chunk_number is not None
and i_c == self.error_on_chunk_number
):
raise Exception("Fake error")
raise FakeListChatModelError
yield ChatGenerationChunk(message=AIMessageChunk(content=c))
@property

View File

@@ -27,6 +27,7 @@ from typing import (
)
import yaml
from pydantic import ConfigDict, Field, model_validator
from tenacity import (
RetryCallState,
before_sleep_log,
@@ -48,7 +49,11 @@ from langchain_core.callbacks import (
Callbacks,
)
from langchain_core.globals import get_llm_cache
from langchain_core.language_models.base import BaseLanguageModel, LanguageModelInput
from langchain_core.language_models.base import (
BaseLanguageModel,
LangSmithParams,
LanguageModelInput,
)
from langchain_core.load import dumpd
from langchain_core.messages import (
AIMessage,
@@ -58,7 +63,6 @@ from langchain_core.messages import (
)
from langchain_core.outputs import Generation, GenerationChunk, LLMResult, RunInfo
from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.runnables import RunnableConfig, ensure_config, get_config_list
from langchain_core.runnables.config import run_in_executor
@@ -110,13 +114,12 @@ def create_base_retry_decorator(
_log_error_once(f"Error in on_retry: {e}")
else:
run_manager.on_retry(retry_state)
return None
min_seconds = 4
max_seconds = 10
# Wait 2^x * 1 second between each retry starting with
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
retry_instance: "retry_base" = retry_if_exception_type(error_types[0])
retry_instance: retry_base = retry_if_exception_type(error_types[0])
for error in error_types[1:]:
retry_instance = retry_instance | retry_if_exception_type(error)
return retry(
@@ -297,16 +300,19 @@ class BaseLLM(BaseLanguageModel[str], ABC):
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
"""[DEPRECATED]"""
class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
@root_validator(pre=True)
def raise_deprecation(cls, values: Dict) -> Dict:
@model_validator(mode="before")
@classmethod
def raise_deprecation(cls, values: Dict) -> Any:
"""Raise deprecation warning if callback_manager is used."""
if values.get("callback_manager") is not None:
warnings.warn(
"callback_manager is deprecated. Please use callbacks instead.",
DeprecationWarning,
stacklevel=5,
)
values["callbacks"] = values.pop("callback_manager", None)
return values
@@ -331,6 +337,43 @@ class BaseLLM(BaseLanguageModel[str], ABC):
"Must be a PromptValue, str, or list of BaseMessages."
)
def _get_ls_params(
self,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> LangSmithParams:
"""Get standard params for tracing."""
# get default provider from class name
default_provider = self.__class__.__name__
if default_provider.endswith("LLM"):
default_provider = default_provider[:-3]
default_provider = default_provider.lower()
ls_params = LangSmithParams(ls_provider=default_provider, ls_model_type="llm")
if stop:
ls_params["ls_stop"] = stop
# model
if hasattr(self, "model") and isinstance(self.model, str):
ls_params["ls_model_name"] = self.model
elif hasattr(self, "model_name") and isinstance(self.model_name, str):
ls_params["ls_model_name"] = self.model_name
# temperature
if "temperature" in kwargs and isinstance(kwargs["temperature"], float):
ls_params["ls_temperature"] = kwargs["temperature"]
elif hasattr(self, "temperature") and isinstance(self.temperature, float):
ls_params["ls_temperature"] = self.temperature
# max_tokens
if "max_tokens" in kwargs and isinstance(kwargs["max_tokens"], int):
ls_params["ls_max_tokens"] = kwargs["max_tokens"]
elif hasattr(self, "max_tokens") and isinstance(self.max_tokens, int):
ls_params["ls_max_tokens"] = self.max_tokens
return ls_params
def invoke(
self,
input: LanguageModelInput,
@@ -487,13 +530,17 @@ class BaseLLM(BaseLanguageModel[str], ABC):
params["stop"] = stop
params = {**params, **kwargs}
options = {"stop": stop}
inheritable_metadata = {
**(config.get("metadata") or {}),
**self._get_ls_params(stop=stop, **kwargs),
}
callback_manager = CallbackManager.configure(
config.get("callbacks"),
self.callbacks,
self.verbose,
config.get("tags"),
self.tags,
config.get("metadata"),
inheritable_metadata,
self.metadata,
)
(run_manager,) = callback_manager.on_llm_start(
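Editor's note: this hunk (and the matching async hunk that follows) folds the LangSmith params into the inheritable metadata handed to the callback manager. The merge is a plain dict spread, so keys coming from `_get_ls_params` win over caller-supplied metadata on collision; a tiny sketch with made-up values:

    config_metadata = {"request_id": "abc", "ls_model_name": "stale-name"}
    ls_params = {"ls_provider": "fakelist", "ls_model_type": "llm", "ls_model_name": "fresh-name"}

    # Later entries win in a dict-literal spread, so ls_params takes precedence.
    inheritable_metadata = {**(config_metadata or {}), **ls_params}
    assert inheritable_metadata["ls_model_name"] == "fresh-name"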
@@ -548,13 +595,17 @@ class BaseLLM(BaseLanguageModel[str], ABC):
params["stop"] = stop
params = {**params, **kwargs}
options = {"stop": stop}
inheritable_metadata = {
**(config.get("metadata") or {}),
**self._get_ls_params(stop=stop, **kwargs),
}
callback_manager = AsyncCallbackManager.configure(
config.get("callbacks"),
self.callbacks,
self.verbose,
config.get("tags"),
self.tags,
config.get("metadata"),
inheritable_metadata,
self.metadata,
)
(run_manager,) = await callback_manager.on_llm_start(
@@ -796,6 +847,21 @@ class BaseLLM(BaseLanguageModel[str], ABC):
f" argument of type {type(prompts)}."
)
# Create callback managers
if isinstance(metadata, list):
metadata = [
{
**(meta or {}),
**self._get_ls_params(stop=stop, **kwargs),
}
for meta in metadata
]
elif isinstance(metadata, dict):
metadata = {
**(metadata or {}),
**self._get_ls_params(stop=stop, **kwargs),
}
else:
pass
if (
isinstance(callbacks, list)
and callbacks
@@ -1017,6 +1083,21 @@ class BaseLLM(BaseLanguageModel[str], ABC):
An LLMResult, which contains a list of candidate Generations for each input
prompt and additional model provider-specific output.
"""
if isinstance(metadata, list):
metadata = [
{
**(meta or {}),
**self._get_ls_params(stop=stop, **kwargs),
}
for meta in metadata
]
elif isinstance(metadata, dict):
metadata = {
**(metadata or {}),
**self._get_ls_params(stop=stop, **kwargs),
}
else:
pass
# Create callback managers
if isinstance(callbacks, list) and (
isinstance(callbacks[0], (list, BaseCallbackManager))
@@ -1150,7 +1231,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
generations = [existing_prompts[i] for i in range(len(prompts))]
return LLMResult(generations=generations, llm_output=llm_output, run=run_info)
@deprecated("0.1.7", alternative="invoke", removal="0.3.0")
@deprecated("0.1.7", alternative="invoke", removal="1.0")
def __call__(
self,
prompt: str,
@@ -1220,7 +1301,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
)
return result.generations[0][0].text
@deprecated("0.1.7", alternative="invoke", removal="0.3.0")
@deprecated("0.1.7", alternative="invoke", removal="1.0")
def predict(
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
) -> str:
@@ -1230,7 +1311,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
_stop = list(stop)
return self(text, stop=_stop, **kwargs)
@deprecated("0.1.7", alternative="invoke", removal="0.3.0")
@deprecated("0.1.7", alternative="invoke", removal="1.0")
def predict_messages(
self,
messages: List[BaseMessage],
@@ -1246,7 +1327,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
content = self(text, stop=_stop, **kwargs)
return AIMessage(content=content)
@deprecated("0.1.7", alternative="ainvoke", removal="0.3.0")
@deprecated("0.1.7", alternative="ainvoke", removal="1.0")
async def apredict(
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
) -> str:
@@ -1256,7 +1337,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
_stop = list(stop)
return await self._call_async(text, stop=_stop, **kwargs)
@deprecated("0.1.7", alternative="ainvoke", removal="0.3.0")
@deprecated("0.1.7", alternative="ainvoke", removal="1.0")
async def apredict_messages(
self,
messages: List[BaseMessage],

View File

@@ -10,9 +10,10 @@ from typing import (
cast,
)
from pydantic import BaseModel, ConfigDict # pydantic: ignore
from typing_extensions import NotRequired
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.utils.pydantic import v1_repr
class BaseSerialized(TypedDict):
@@ -80,7 +81,7 @@ def try_neq_default(value: Any, key: str, model: BaseModel) -> bool:
Exception: If the key is not in the model.
"""
try:
return model.__fields__[key].get_default() != value
return model.model_fields[key].get_default() != value
except Exception:
return True
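Editor's note: `try_neq_default` swaps v1's `model.__fields__` for v2's `model.model_fields`; both map field names to field-info objects, and `FieldInfo.get_default()` still returns the declared default. A minimal check of the behaviour it relies on (the `Settings` model is illustrative):

    from pydantic import BaseModel


    class Settings(BaseModel):
        retries: int = 3


    s = Settings(retries=5)
    field = Settings.model_fields["retries"]   # v2 spelling of __fields__
    assert field.get_default() == 3            # the declared default
    assert field.get_default() != s.retries    # so this field is "not default"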
@@ -161,16 +162,25 @@ class Serializable(BaseModel, ABC):
For example, for the class `langchain.llms.openai.OpenAI`, the id is
["langchain", "llms", "openai", "OpenAI"].
"""
return [*cls.get_lc_namespace(), cls.__name__]
# Pydantic generics change the class name. So we need to do the following
if (
"origin" in cls.__pydantic_generic_metadata__
and cls.__pydantic_generic_metadata__["origin"] is not None
):
original_name = cls.__pydantic_generic_metadata__["origin"].__name__
else:
original_name = cls.__name__
return [*cls.get_lc_namespace(), original_name]
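Editor's note: the lc_id change works around pydantic v2 generics rewriting `__name__`: a parametrized model's name becomes `"Model[arg]"`, while the unparametrized class is still recoverable from `__pydantic_generic_metadata__["origin"]`. A small sketch (the `Box` model is illustrative):

    from typing import Generic, TypeVar

    from pydantic import BaseModel

    T = TypeVar("T")


    class Box(BaseModel, Generic[T]):
        value: T


    Parametrized = Box[int]
    assert Parametrized.__name__ == "Box[int]"                           # name is rewritten
    assert Parametrized.__pydantic_generic_metadata__["origin"] is Box   # original class
    assert Box.__pydantic_generic_metadata__["origin"] is None           # unparametrized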
class Config:
extra = "ignore"
model_config = ConfigDict(
extra="ignore",
)
def __repr_args__(self) -> Any:
return [
(k, v)
for k, v in super().__repr_args__()
if (k not in self.__fields__ or try_neq_default(v, k, self))
if (k not in self.model_fields or try_neq_default(v, k, self))
]
def to_json(self) -> Union[SerializedConstructor, SerializedNotImplemented]:
@@ -184,12 +194,15 @@ class Serializable(BaseModel, ABC):
secrets = dict()
# Get latest values for kwargs if there is an attribute with same name
lc_kwargs = {
k: getattr(self, k, v)
for k, v in self
if not (self.__exclude_fields__ or {}).get(k, False) # type: ignore
and _is_field_useful(self, k, v)
}
lc_kwargs = {}
for k, v in self:
if not _is_field_useful(self, k, v):
continue
# Do nothing if the field is excluded
if k in self.model_fields and self.model_fields[k].exclude:
continue
lc_kwargs[k] = getattr(self, k, v)
# Merge the lc_secrets and lc_attributes from every class in the MRO
for cls in [None, *self.__class__.mro()]:
@@ -221,8 +234,10 @@ class Serializable(BaseModel, ABC):
# that are not present in the fields.
for key in list(secrets):
value = secrets[key]
if key in this.__fields__:
secrets[this.__fields__[key].alias] = value
if key in this.model_fields:
alias = this.model_fields[key].alias
if alias is not None:
secrets[alias] = value
lc_kwargs.update(this.lc_attributes)
# include all secrets, even if not specified in kwargs
@@ -244,6 +259,10 @@ class Serializable(BaseModel, ABC):
def to_json_not_implemented(self) -> SerializedNotImplemented:
return to_json_not_implemented(self)
def __repr__(self) -> str:
# TODO(0.3): Remove this override after confirming unit tests!
return v1_repr(self)
def _is_field_useful(inst: Serializable, key: str, value: Any) -> bool:
"""Check if a field is useful as a constructor argument.
@@ -259,10 +278,46 @@ def _is_field_useful(inst: Serializable, key: str, value: Any) -> bool:
If the field is not required and the value is None, it is useful if the
default value is different from the value.
"""
field = inst.__fields__.get(key)
field = inst.model_fields.get(key)
if not field:
return False
return field.required is True or value or field.get_default() != value
if field.is_required():
return True
# Handle edge case: a value cannot be converted to a boolean (e.g. a
# Pandas DataFrame).
try:
value_is_truthy = bool(value)
except Exception as _:
value_is_truthy = False
if value_is_truthy:
return True
# Value is still falsy here!
if field.default_factory is dict and isinstance(value, dict):
return False
# Value is still falsy here!
if field.default_factory is list and isinstance(value, list):
return False
# Handle edge case: inequality of two objects does not evaluate to a bool (e.g. two
# Pandas DataFrames).
try:
value_neq_default = bool(field.get_default() != value)
except Exception as _:
try:
value_neq_default = all(field.get_default() != value)
except Exception as _:
try:
value_neq_default = value is not field.default
except Exception as _:
value_neq_default = False
# If value is falsy and does not match the default
return value_is_truthy or value_neq_default
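Editor's note: the rewritten `_is_field_useful` wraps both `bool(value)` and `field.get_default() != value` in try/except because some objects (pandas DataFrames being the motivating case) raise on truthiness checks and return element-wise results from `!=`. A dependency-free sketch of why that layering is needed; the `Frame` class below only mimics that behaviour and is not from the diff:

    class Frame:
        """Mimics a pandas DataFrame: ambiguous truth value, element-wise __ne__."""

        def __bool__(self) -> bool:
            raise ValueError("The truth value of a Frame is ambiguous.")

        def __ne__(self, other: object) -> "Frame":  # type: ignore[override]
            return Frame()  # element-wise comparison, itself ambiguous


    value = Frame()
    default = None

    # bool(value) raises, so the field is treated as falsy ...
    try:
        value_is_truthy = bool(value)
    except Exception:
        value_is_truthy = False

    # ... and bool(default != value) raises as well, so the code falls through to
    # the next strategies (all(...), then an identity check), exactly as in the diff.
    try:
        value_neq_default = bool(default != value)
    except Exception:
        value_neq_default = False

    assert value_is_truthy is False and value_neq_default is False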
def _replace_secrets(

View File

@@ -13,6 +13,8 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Dict, List
from pydantic import ConfigDict
from langchain_core.load.serializable import Serializable
from langchain_core.runnables import run_in_executor
@@ -47,8 +49,9 @@ class BaseMemory(Serializable, ABC):
pass
""" # noqa: E501
class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
@property
@abstractmethod

View File

@@ -1,7 +1,8 @@
import json
from typing import Any, Dict, List, Literal, Optional, Union
from typing_extensions import TypedDict
from pydantic import model_validator
from typing_extensions import Self, TypedDict
from langchain_core.messages.base import (
BaseMessage,
@@ -24,7 +25,6 @@ from langchain_core.messages.tool import (
from langchain_core.messages.tool import (
tool_call_chunk as create_tool_call_chunk,
)
from langchain_core.pydantic_v1 import root_validator
from langchain_core.utils._merge import merge_dicts, merge_lists
from langchain_core.utils.json import parse_partial_json
@@ -111,8 +111,9 @@ class AIMessage(BaseMessage):
"invalid_tool_calls": self.invalid_tool_calls,
}
@root_validator(pre=True)
def _backwards_compat_tool_calls(cls, values: dict) -> dict:
@model_validator(mode="before")
@classmethod
def _backwards_compat_tool_calls(cls, values: dict) -> Any:
check_additional_kwargs = not any(
values.get(k)
for k in ("tool_calls", "invalid_tool_calls", "tool_call_chunks")
@@ -204,7 +205,7 @@ class AIMessage(BaseMessage):
return (base.strip() + "\n" + "\n".join(lines)).strip()
AIMessage.update_forward_refs()
AIMessage.model_rebuild()
class AIMessageChunk(AIMessage, BaseMessageChunk):
@@ -238,8 +239,8 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
"invalid_tool_calls": self.invalid_tool_calls,
}
@root_validator(pre=False, skip_on_failure=True)
def init_tool_calls(cls, values: dict) -> dict:
@model_validator(mode="after")
def init_tool_calls(self) -> Self:
"""Initialize tool calls from tool call chunks.
Args:
@@ -251,35 +252,35 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
Raises:
ValueError: If the tool call chunks are malformed.
"""
if not values["tool_call_chunks"]:
if values["tool_calls"]:
values["tool_call_chunks"] = [
if not self.tool_call_chunks:
if self.tool_calls:
self.tool_call_chunks = [
create_tool_call_chunk(
name=tc["name"],
args=json.dumps(tc["args"]),
id=tc["id"],
index=None,
)
for tc in values["tool_calls"]
for tc in self.tool_calls
]
if values["invalid_tool_calls"]:
tool_call_chunks = values.get("tool_call_chunks", [])
if self.invalid_tool_calls:
tool_call_chunks = self.tool_call_chunks
tool_call_chunks.extend(
[
create_tool_call_chunk(
name=tc["name"], args=tc["args"], id=tc["id"], index=None
)
for tc in values["invalid_tool_calls"]
for tc in self.invalid_tool_calls
]
)
values["tool_call_chunks"] = tool_call_chunks
self.tool_call_chunks = tool_call_chunks
return values
return self
tool_calls = []
invalid_tool_calls = []
for chunk in values["tool_call_chunks"]:
for chunk in self.tool_call_chunks:
try:
args_ = parse_partial_json(chunk["args"]) if chunk["args"] != "" else {}
args_ = parse_partial_json(chunk["args"]) if chunk["args"] != "" else {} # type: ignore[arg-type]
if isinstance(args_, dict):
tool_calls.append(
create_tool_call(
@@ -299,9 +300,9 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
error=None,
)
)
values["tool_calls"] = tool_calls
values["invalid_tool_calls"] = invalid_tool_calls
return values
self.tool_calls = tool_calls
self.invalid_tool_calls = invalid_tool_calls
return self
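Editor's note: `init_tool_calls` moves from a v1 `@root_validator(pre=False, skip_on_failure=True)` that mutated a `values` dict to a v2 `@model_validator(mode="after")` that mutates and returns the instance itself. The shape of that pattern, reduced to a toy model (names are illustrative, not langchain classes):

    import json
    from typing import Optional

    from pydantic import BaseModel, model_validator
    from typing_extensions import Self


    class ChunkedCall(BaseModel):
        args_json: str = ""
        args: Optional[dict] = None

        # "after" validators run on the fully constructed model and return it,
        # replacing v1 root validators that received and returned a values dict.
        @model_validator(mode="after")
        def parse_args(self) -> Self:
            if self.args is None and self.args_json:
                self.args = json.loads(self.args_json)
            return self


    call = ChunkedCall(args_json='{"query": "hi"}')
    assert call.args == {"query": "hi"}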
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
if isinstance(other, AIMessageChunk):

View File

@@ -2,11 +2,13 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union, cast
from pydantic import ConfigDict, Extra, Field
from langchain_core.load.serializable import Serializable
from langchain_core.pydantic_v1 import Extra, Field
from langchain_core.utils import get_bolded_text
from langchain_core.utils._merge import merge_dicts, merge_lists
from langchain_core.utils.interactive_env import is_interactive_env
from langchain_core.utils.pydantic import v1_repr
if TYPE_CHECKING:
from langchain_core.prompts.chat import ChatPromptTemplate
@@ -51,8 +53,9 @@ class BaseMessage(Serializable):
"""An optional unique identifier for the message. This should ideally be
provided by the provider/model which created the message."""
class Config:
extra = Extra.allow
model_config = ConfigDict(
extra=Extra.allow,
)
def __init__(
self, content: Union[str, List[Union[str, Dict]]], **kwargs: Any
@@ -108,6 +111,10 @@ class BaseMessage(Serializable):
def pretty_print(self) -> None:
print(self.pretty_repr(html=is_interactive_env())) # noqa: T201
def __repr__(self) -> str:
# TODO(0.3): Remove this override after confirming unit tests!
return v1_repr(self)
def merge_content(
first_content: Union[str, List[Union[str, Dict]]],

View File

@@ -25,7 +25,7 @@ class ChatMessage(BaseMessage):
return ["langchain", "schema", "messages"]
ChatMessage.update_forward_refs()
ChatMessage.model_rebuild()
class ChatMessageChunk(ChatMessage, BaseMessageChunk):

View File

@@ -32,7 +32,7 @@ class FunctionMessage(BaseMessage):
return ["langchain", "schema", "messages"]
FunctionMessage.update_forward_refs()
FunctionMessage.model_rebuild()
class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):

View File

@@ -56,7 +56,7 @@ class HumanMessage(BaseMessage):
super().__init__(content=content, **kwargs)
HumanMessage.update_forward_refs()
HumanMessage.model_rebuild()
class HumanMessageChunk(HumanMessage, BaseMessageChunk):

View File

@@ -33,4 +33,4 @@ class RemoveMessage(BaseMessage):
return ["langchain", "schema", "messages"]
RemoveMessage.update_forward_refs()
RemoveMessage.model_rebuild()
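Editor's note: the message modules repeat a one-line migration: v1's `Model.update_forward_refs()` becomes v2's `Model.model_rebuild()`, which re-resolves postponed (string) annotations once the referenced types exist. A minimal sketch of the situation it covers (the `Node` model is illustrative):

    from __future__ import annotations

    from typing import Optional

    from pydantic import BaseModel


    class Node(BaseModel):
        value: int
        parent: Optional[Node] = None  # forward reference at class-creation time


    # In v1 this was Node.update_forward_refs(); the v2 equivalent is:
    Node.model_rebuild()

    assert Node(value=1, parent=Node(value=0)).parent.value == 0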

View File

@@ -50,7 +50,7 @@ class SystemMessage(BaseMessage):
super().__init__(content=content, **kwargs)
SystemMessage.update_forward_refs()
SystemMessage.model_rebuild()
class SystemMessageChunk(SystemMessage, BaseMessageChunk):

View File

@@ -1,6 +1,7 @@
import json
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from pydantic import Field
from typing_extensions import NotRequired, TypedDict
from langchain_core.messages.base import BaseMessage, BaseMessageChunk, merge_content
@@ -70,6 +71,11 @@ class ToolMessage(BaseMessage):
.. versionadded:: 0.2.24
"""
additional_kwargs: dict = Field(default_factory=dict, repr=False)
"""Currently inherited from BaseMessage, but not used."""
response_metadata: dict = Field(default_factory=dict, repr=False)
"""Currently inherited from BaseMessage, but not used."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object.
@@ -88,7 +94,7 @@ class ToolMessage(BaseMessage):
super().__init__(content=content, **kwargs)
ToolMessage.update_forward_refs()
ToolMessage.model_rebuild()
class ToolMessageChunk(ToolMessage, BaseMessageChunk):

View File

@@ -10,6 +10,7 @@ Some examples of what you can do with these functions include:
from __future__ import annotations
import inspect
import json
from functools import partial
from typing import (
TYPE_CHECKING,
@@ -213,7 +214,23 @@ def _create_message_from_message_type(
if id is not None:
kwargs["id"] = id
if tool_calls is not None:
kwargs["tool_calls"] = tool_calls
kwargs["tool_calls"] = []
for tool_call in tool_calls:
# Convert OpenAI-format tool call to LangChain format.
if "function" in tool_call:
args = tool_call["function"]["arguments"]
if isinstance(args, str):
args = json.loads(args, strict=False)
kwargs["tool_calls"].append(
{
"name": tool_call["function"]["name"],
"args": args,
"id": tool_call["id"],
"type": "tool_call",
}
)
else:
kwargs["tool_calls"].append(tool_call)
if message_type in ("human", "user"):
message: BaseMessage = HumanMessage(content=content, **kwargs)
elif message_type in ("ai", "assistant"):
@@ -271,11 +288,12 @@ def _convert_to_message(message: MessageLikeRepresentation) -> BaseMessage:
msg_type = msg_kwargs.pop("role")
except KeyError:
msg_type = msg_kwargs.pop("type")
msg_content = msg_kwargs.pop("content")
except KeyError:
# None msg content is not allowed
msg_content = msg_kwargs.pop("content") or ""
except KeyError as e:
raise ValueError(
f"Message dict must contain 'role' and 'content' keys, got {message}"
)
) from e
_message = _create_message_from_message_type(
msg_type, msg_content, **msg_kwargs
)
@@ -326,9 +344,7 @@ def _runnable_support(func: Callable) -> Callable:
if messages is not None:
return func(messages, **kwargs)
else:
return RunnableLambda(
partial(func, **kwargs), name=getattr(func, "__name__")
)
return RunnableLambda(partial(func, **kwargs), name=func.__name__)
wrapped.__doc__ = func.__doc__
return wrapped
@@ -425,6 +441,8 @@ def filter_messages(
@_runnable_support
def merge_message_runs(
messages: Union[Iterable[MessageLikeRepresentation], PromptValue],
*,
chunk_separator: str = "\n",
) -> List[BaseMessage]:
"""Merge consecutive Messages of the same type.
@@ -433,13 +451,16 @@ def merge_message_runs(
Args:
messages: Sequence Message-like objects to merge.
chunk_separator: Specify the string to be inserted between message chunks.
Default is "\n".
Returns:
List of BaseMessages with consecutive runs of message types merged into single
messages. If two messages being merged both have string contents, the merged
content is a concatenation of the two strings with a new-line separator. If at
least one of the messages has a list of content blocks, the merged content is a
list of content blocks.
messages. By default, if two messages being merged both have string contents,
the merged content is a concatenation of the two strings with a new-line separator.
The separator inserted between message chunks can be controlled by specifying
any string with ``chunk_separator``. If at least one of the messages has a list of
content blocks, the merged content is a list of content blocks.
Example:
.. code-block:: python
@@ -509,11 +530,13 @@ def merge_message_runs(
and last_chunk.content
and curr_chunk.content
):
last_chunk.content += "\n"
last_chunk.content += chunk_separator
merged.append(_chunk_to_msg(last_chunk + curr_chunk))
return merged
# TODO: Update so validation errors (for token_counter, for example) are raised on
# init not at runtime.
@_runnable_support
def trim_messages(
messages: Union[Iterable[MessageLikeRepresentation], PromptValue],
@@ -759,24 +782,30 @@ def trim_messages(
AIMessage("This is a 4 token text. The full message is 10 tokens.", id="fourth"),
]
""" # noqa: E501
from langchain_core.language_models import BaseLanguageModel
if start_on and strategy == "first":
raise ValueError
if include_system and strategy == "first":
raise ValueError
messages = convert_to_messages(messages)
if isinstance(token_counter, BaseLanguageModel):
if hasattr(token_counter, "get_num_tokens_from_messages"):
list_token_counter = token_counter.get_num_tokens_from_messages
elif (
list(inspect.signature(token_counter).parameters.values())[0].annotation
is BaseMessage
):
elif callable(token_counter):
if (
list(inspect.signature(token_counter).parameters.values())[0].annotation
is BaseMessage
):
def list_token_counter(messages: Sequence[BaseMessage]) -> int:
return sum(token_counter(msg) for msg in messages) # type: ignore[arg-type, misc]
def list_token_counter(messages: Sequence[BaseMessage]) -> int:
return sum(token_counter(msg) for msg in messages) # type: ignore[arg-type, misc]
else:
list_token_counter = token_counter # type: ignore[assignment]
else:
list_token_counter = token_counter # type: ignore[assignment]
raise ValueError(
f"'token_counter' expected to be a model that implements "
f"'get_num_tokens_from_messages()' or a function. Received object of type "
f"{type(token_counter)}."
)
try:
from langchain_text_splitters import TextSplitter
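Editor's note: the `trim_messages` change tightens `token_counter` handling: a language model uses `get_num_tokens_from_messages`, a callable whose first parameter is annotated as `BaseMessage` is wrapped to sum over a list, any other callable is assumed to already count a list, and anything else now raises. A dependency-free sketch of the wrapping rule; the `Message` class and counter below are stand-ins, not a real tokenizer:

    import inspect
    from typing import Sequence


    class Message:  # stand-in for BaseMessage
        def __init__(self, content: str) -> None:
            self.content = content


    def per_message_counter(msg: Message) -> int:
        # Toy rule: roughly one token per four characters.
        return max(1, len(msg.content) // 4)


    first_param = list(inspect.signature(per_message_counter).parameters.values())[0]
    if first_param.annotation is Message:
        # Wrap a per-message counter into a per-list counter, as trim_messages does.
        def list_token_counter(messages: Sequence[Message]) -> int:
            return sum(per_message_counter(m) for m in messages)
    else:
        list_token_counter = per_message_counter  # already counts a whole list

    assert list_token_counter([Message("hello world"), Message("hi")]) == 3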

View File

@@ -13,8 +13,6 @@ from typing import (
Union,
)
from typing_extensions import get_args
from langchain_core.language_models import LanguageModelOutput
from langchain_core.messages import AnyMessage, BaseMessage
from langchain_core.outputs import ChatGeneration, Generation
@@ -166,10 +164,11 @@ class BaseOutputParser(
Raises:
TypeError: If the class doesn't have an inferable OutputType.
"""
for cls in self.__class__.__orig_bases__: # type: ignore[attr-defined]
type_args = get_args(cls)
if type_args and len(type_args) == 1:
return type_args[0]
for base in self.__class__.mro():
if hasattr(base, "__pydantic_generic_metadata__"):
metadata = base.__pydantic_generic_metadata__
if "args" in metadata and len(metadata["args"]) > 0:
return metadata["args"][0]
raise TypeError(
f"Runnable {self.__class__.__name__} doesn't have an inferable OutputType. "

View File

@@ -2,10 +2,11 @@ from __future__ import annotations
import json
from json import JSONDecodeError
from typing import Any, List, Optional, Type, TypeVar, Union
from typing import Annotated, Any, List, Optional, Type, TypeVar, Union
import jsonpatch # type: ignore[import]
import pydantic # pydantic: ignore
from pydantic import SkipValidation # pydantic: ignore
from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers.format_instructions import JSON_FORMAT_INSTRUCTIONS
@@ -40,7 +41,7 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
describing the difference between the previous and the current object.
"""
pydantic_object: Optional[Type[TBaseModel]] = None # type: ignore
pydantic_object: Annotated[Optional[Type[TBaseModel]], SkipValidation()] = None # type: ignore
"""The Pydantic object to use for validation.
If None, no validation is performed."""
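Editor's note: several parsers now wrap class-typed fields in `Annotated[..., SkipValidation()]`; without it, pydantic v2 would try to validate the stored class object (and the TypeVar-parametrized fields used here do not survive that cleanly). A small sketch of the pattern, assuming only pydantic 2 (the `ParserConfig` model is illustrative):

    from typing import Annotated, Optional, Type

    from pydantic import BaseModel, SkipValidation


    class Person(BaseModel):
        name: str


    class ParserConfig(BaseModel):
        # SkipValidation stores the value as-is instead of validating it.
        pydantic_object: Annotated[Optional[Type[BaseModel]], SkipValidation()] = None


    cfg = ParserConfig(pydantic_object=Person)
    assert cfg.pydantic_object is Person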

View File

@@ -4,6 +4,7 @@ import re
from abc import abstractmethod
from collections import deque
from typing import AsyncIterator, Deque, Iterator, List, TypeVar, Union
from typing import Optional as Optional
from langchain_core.messages import BaseMessage
from langchain_core.output_parsers.transform import BaseTransformOutputParser
@@ -122,6 +123,9 @@ class ListOutputParser(BaseTransformOutputParser[List[str]]):
yield [part]
ListOutputParser.model_rebuild()
class CommaSeparatedListOutputParser(ListOutputParser):
"""Parse the output of an LLM call to a comma-separated list."""

View File

@@ -3,6 +3,7 @@ import json
from typing import Any, Dict, List, Optional, Type, Union
import jsonpatch # type: ignore[import]
from pydantic import BaseModel, model_validator
from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import (
@@ -11,7 +12,6 @@ from langchain_core.output_parsers import (
)
from langchain_core.output_parsers.json import parse_partial_json
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.pydantic_v1 import BaseModel, root_validator
class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
@@ -42,7 +42,9 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
try:
func_call = copy.deepcopy(message.additional_kwargs["function_call"])
except KeyError as exc:
raise OutputParserException(f"Could not parse function call: {exc}")
raise OutputParserException(
f"Could not parse function call: {exc}"
) from exc
if self.args_only:
return func_call["arguments"]
@@ -100,7 +102,9 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
if partial:
return None
else:
raise OutputParserException(f"Could not parse function call: {exc}")
raise OutputParserException(
f"Could not parse function call: {exc}"
) from exc
try:
if partial:
try:
@@ -126,7 +130,7 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
except (json.JSONDecodeError, TypeError) as exc:
raise OutputParserException(
f"Could not parse function call data: {exc}"
)
) from exc
else:
try:
return {
@@ -138,7 +142,7 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
except (json.JSONDecodeError, TypeError) as exc:
raise OutputParserException(
f"Could not parse function call data: {exc}"
)
) from exc
except KeyError:
return None
@@ -226,8 +230,9 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
determine which schema to use.
"""
@root_validator(pre=True)
def validate_schema(cls, values: Dict) -> Dict:
@model_validator(mode="before")
@classmethod
def validate_schema(cls, values: Dict) -> Any:
"""Validate the pydantic schema.
Args:
@@ -263,11 +268,17 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
"""
_result = super().parse_result(result)
if self.args_only:
pydantic_args = self.pydantic_schema.parse_raw(_result) # type: ignore
if hasattr(self.pydantic_schema, "model_validate_json"):
pydantic_args = self.pydantic_schema.model_validate_json(_result)
else:
pydantic_args = self.pydantic_schema.parse_raw(_result) # type: ignore
else:
fn_name = _result["name"]
_args = _result["arguments"]
pydantic_args = self.pydantic_schema[fn_name].parse_raw(_args) # type: ignore
if hasattr(self.pydantic_schema, "model_validate_json"):
pydantic_args = self.pydantic_schema[fn_name].model_validate_json(_args) # type: ignore
else:
pydantic_args = self.pydantic_schema[fn_name].parse_raw(_args) # type: ignore
return pydantic_args
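Editor's note: `parse_result` now branches on whether the configured schema is a v2 model (`model_validate_json`) or a v1-style model (`parse_raw`), so both keep working during the transition. A compact sketch of that duck-typed dispatch; the helper and model names below are made up:

    from pydantic import BaseModel
    from pydantic.v1 import BaseModel as BaseModelV1  # v1 shim bundled with pydantic 2


    class WeatherV2(BaseModel):
        city: str
        temp_c: float


    class WeatherV1(BaseModelV1):
        city: str
        temp_c: float


    def validate_json(schema: type, raw: str):
        # Same duck-typing the diff uses: v2 models expose model_validate_json,
        # v1 models fall back to parse_raw.
        if hasattr(schema, "model_validate_json"):
            return schema.model_validate_json(raw)
        return schema.parse_raw(raw)


    raw = '{"city": "Paris", "temp_c": 21.5}'
    assert validate_json(WeatherV2, raw).city == "Paris"
    assert validate_json(WeatherV1, raw).city == "Paris"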

View File

@@ -1,19 +1,16 @@
import copy
import json
from json import JSONDecodeError
from typing import Any, Dict, List, Optional
from typing import Annotated, Any, Dict, List, Optional
from pydantic import SkipValidation, ValidationError # pydantic: ignore
from langchain_core.exceptions import OutputParserException
from langchain_core.messages import AIMessage, InvalidToolCall
from langchain_core.messages.tool import (
invalid_tool_call,
)
from langchain_core.messages.tool import (
tool_call as create_tool_call,
)
from langchain_core.messages.tool import invalid_tool_call
from langchain_core.messages.tool import tool_call as create_tool_call
from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.pydantic_v1 import ValidationError
from langchain_core.utils.json import parse_partial_json
from langchain_core.utils.pydantic import TypeBaseModel
@@ -59,7 +56,7 @@ def parse_tool_call(
f"Function {raw_tool_call['function']['name']} arguments:\n\n"
f"{raw_tool_call['function']['arguments']}\n\nare not valid JSON. "
f"Received JSONDecodeError {e}"
)
) from e
parsed = {
"name": raw_tool_call["function"]["name"] or "",
"args": function_args or {},
@@ -256,7 +253,7 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
class PydanticToolsParser(JsonOutputToolsParser):
"""Parse tools from OpenAI response."""
tools: List[TypeBaseModel]
tools: Annotated[List[TypeBaseModel], SkipValidation()]
"""The tools to parse."""
# TODO: Support more granular streaming of objects. Currently only streams once all

View File

@@ -1,7 +1,8 @@
import json
from typing import Generic, List, Type
from typing import Annotated, Generic, List, Optional, Type
import pydantic # pydantic: ignore
from pydantic import SkipValidation # pydantic: ignore
from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import JsonOutputParser
@@ -16,7 +17,7 @@ from langchain_core.utils.pydantic import (
class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
"""Parse an output using a pydantic model."""
pydantic_object: Type[TBaseModel] # type: ignore
pydantic_object: Annotated[Type[TBaseModel], SkipValidation()] # type: ignore
"""The pydantic model to parse."""
def _parse_obj(self, obj: dict) -> TBaseModel:
@@ -32,12 +33,12 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
{self.pydantic_object.__class__}"
)
except (pydantic.ValidationError, pydantic.v1.ValidationError) as e:
raise self._parser_exception(e, obj)
raise self._parser_exception(e, obj) from e
else: # pydantic v1
try:
return self.pydantic_object.parse_obj(obj)
except pydantic.ValidationError as e:
raise self._parser_exception(e, obj)
raise self._parser_exception(e, obj) from e
def _parser_exception(
self, e: Exception, json_object: dict
@@ -49,7 +50,7 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
def parse_result(
self, result: List[Generation], *, partial: bool = False
) -> TBaseModel:
) -> Optional[TBaseModel]:
"""Parse the result of an LLM call to a pydantic object.
Args:
@@ -62,8 +63,13 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
Returns:
The parsed pydantic object.
"""
json_object = super().parse_result(result)
return self._parse_obj(json_object)
try:
json_object = super().parse_result(result)
return self._parse_obj(json_object)
except OutputParserException as e:
if partial:
return None
raise e
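Editor's note: `parse_result` now returns `Optional[TBaseModel]`: when `partial=True` and the accumulated JSON does not yet validate, the exception is swallowed and `None` is yielded instead of failing mid-stream. The control flow, reduced to a generic sketch (names are placeholders, not the langchain API):

    from typing import Optional

    from pydantic import BaseModel, ValidationError


    class Answer(BaseModel):
        value: int


    def parse_partial(raw: dict, *, partial: bool) -> Optional[Answer]:
        # Mirrors the new control flow: incomplete output is tolerated while streaming.
        try:
            return Answer.model_validate(raw)
        except ValidationError:
            if partial:
                return None
            raise


    assert parse_partial({}, partial=True) is None        # still streaming, not an error
    assert parse_partial({"value": 3}, partial=False).value == 3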
def parse(self, text: str) -> TBaseModel:
"""Parse the output of an LLM call to a pydantic object.
@@ -92,7 +98,7 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
if "type" in reduced_schema:
del reduced_schema["type"]
# Ensure json in context is well-formed with double quotes.
schema_str = json.dumps(reduced_schema)
schema_str = json.dumps(reduced_schema, ensure_ascii=False)
return _PYDANTIC_FORMAT_INSTRUCTIONS.format(schema=schema_str)
@@ -106,6 +112,9 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
return self.pydantic_object
PydanticOutputParser.model_rebuild()
_PYDANTIC_FORMAT_INSTRUCTIONS = """The output should be formatted as a JSON instance that conforms to the JSON schema below.
As an example, for the schema {{"properties": {{"foo": {{"title": "Foo", "description": "a list of strings", "type": "array", "items": {{"type": "string"}}}}}}, "required": ["foo"]}}

View File

@@ -1,4 +1,5 @@
from typing import List
from typing import Optional as Optional
from langchain_core.output_parsers.transform import BaseTransformOutputParser
@@ -24,3 +25,6 @@ class StrOutputParser(BaseTransformOutputParser[str]):
def parse(self, text: str) -> str:
"""Returns the input text with no changes."""
return text
StrOutputParser.model_rebuild()

View File

@@ -46,12 +46,12 @@ class _StreamingParser:
if parser == "defusedxml":
try:
from defusedxml import ElementTree as DET # type: ignore
except ImportError:
except ImportError as e:
raise ImportError(
"defusedxml is not installed. "
"Please install it to use the defusedxml parser."
"You can install it with `pip install defusedxml` "
)
) from e
_parser = DET.DefusedXMLParser(target=TreeBuilder())
else:
_parser = None
@@ -189,13 +189,13 @@ class XMLOutputParser(BaseTransformOutputParser):
if self.parser == "defusedxml":
try:
from defusedxml import ElementTree as DET # type: ignore
except ImportError:
except ImportError as e:
raise ImportError(
"defusedxml is not installed. "
"Please install it to use the defusedxml parser."
"You can install it with `pip install defusedxml`"
"See https://github.com/tiran/defusedxml for more details"
)
) from e
_ET = DET # Use the defusedxml parser
else:
_ET = ET # Use the standard library parser

View File

@@ -1,10 +1,12 @@
from __future__ import annotations
from typing import Any, Dict, List, Literal, Union
from typing import List, Literal, Union
from pydantic import model_validator
from typing_extensions import Self
from langchain_core.messages import BaseMessage, BaseMessageChunk
from langchain_core.outputs.generation import Generation
from langchain_core.pydantic_v1 import root_validator
from langchain_core.utils._merge import merge_dicts
@@ -30,8 +32,8 @@ class ChatGeneration(Generation):
type: Literal["ChatGeneration"] = "ChatGeneration" # type: ignore[assignment]
"""Type is used exclusively for serialization purposes."""
@root_validator(pre=False, skip_on_failure=True)
def set_text(cls, values: Dict[str, Any]) -> Dict[str, Any]:
@model_validator(mode="after")
def set_text(self) -> Self:
"""Set the text attribute to be the contents of the message.
Args:
@@ -45,12 +47,12 @@ class ChatGeneration(Generation):
"""
try:
text = ""
if isinstance(values["message"].content, str):
text = values["message"].content
if isinstance(self.message.content, str):
text = self.message.content
# HACK: Assumes text in content blocks in OpenAI format.
# Uses first text block.
elif isinstance(values["message"].content, list):
for block in values["message"].content:
elif isinstance(self.message.content, list):
for block in self.message.content:
if isinstance(block, str):
text = block
break
@@ -61,10 +63,10 @@ class ChatGeneration(Generation):
pass
else:
pass
values["text"] = text
self.text = text
except (KeyError, AttributeError) as e:
raise ValueError("Error while initializing ChatGeneration") from e
return values
return self
@classmethod
def get_lc_namespace(cls) -> List[str]:

View File

@@ -1,7 +1,8 @@
from typing import List, Optional
from pydantic import BaseModel
from langchain_core.outputs.chat_generation import ChatGeneration
from langchain_core.pydantic_v1 import BaseModel
class ChatResult(BaseModel):

View File

@@ -1,11 +1,13 @@
from __future__ import annotations
from copy import deepcopy
from typing import List, Optional
from typing import List, Optional, Union
from langchain_core.outputs.generation import Generation
from pydantic import BaseModel
from langchain_core.outputs.chat_generation import ChatGeneration, ChatGenerationChunk
from langchain_core.outputs.generation import Generation, GenerationChunk
from langchain_core.outputs.run_info import RunInfo
from langchain_core.pydantic_v1 import BaseModel
class LLMResult(BaseModel):
@@ -16,7 +18,9 @@ class LLMResult(BaseModel):
wants to return.
"""
generations: List[List[Generation]]
generations: List[
List[Union[Generation, ChatGeneration, GenerationChunk, ChatGenerationChunk]]
]
"""Generated outputs.
The first dimension of the list represents completions for different input

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
from uuid import UUID
from langchain_core.pydantic_v1 import BaseModel
from pydantic import BaseModel
class RunInfo(BaseModel):

View File

@@ -18,6 +18,8 @@ from typing import (
)
import yaml
from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Self
from langchain_core.output_parsers.base import BaseOutputParser
from langchain_core.prompt_values import (
@@ -25,7 +27,6 @@ from langchain_core.prompt_values import (
PromptValue,
StringPromptValue,
)
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain_core.runnables import RunnableConfig, RunnableSerializable
from langchain_core.runnables.config import ensure_config
from langchain_core.runnables.utils import create_model
@@ -64,28 +65,26 @@ class BasePromptTemplate(
tags: Optional[List[str]] = None
"""Tags to be used for tracing."""
@root_validator(pre=False, skip_on_failure=True)
def validate_variable_names(cls, values: Dict) -> Dict:
@model_validator(mode="after")
def validate_variable_names(self) -> Self:
"""Validate variable names do not include restricted names."""
if "stop" in values["input_variables"]:
if "stop" in self.input_variables:
raise ValueError(
"Cannot have an input variable named 'stop', as it is used internally,"
" please rename."
)
if "stop" in values["partial_variables"]:
if "stop" in self.partial_variables:
raise ValueError(
"Cannot have an partial variable named 'stop', as it is used "
"internally, please rename."
)
overall = set(values["input_variables"]).intersection(
values["partial_variables"]
)
overall = set(self.input_variables).intersection(self.partial_variables)
if overall:
raise ValueError(
f"Found overlapping input and partial variables: {overall}"
)
return values
return self
@classmethod
def get_lc_namespace(cls) -> List[str]:
@@ -99,8 +98,9 @@ class BasePromptTemplate(
Returns True."""
return True
class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
@property
def OutputType(self) -> Any:
@@ -129,7 +129,7 @@ class BasePromptTemplate(
"PromptInput", **{**required_input_variables, **optional_input_variables}
)
def _validate_input(self, inner_input: Dict) -> Dict:
def _validate_input(self, inner_input: Any) -> Dict:
if not isinstance(inner_input, dict):
if len(self.input_variables) == 1:
var_name = self.input_variables[0]
@@ -142,11 +142,18 @@ class BasePromptTemplate(
)
missing = set(self.input_variables).difference(inner_input)
if missing:
raise KeyError(
msg = (
f"Input to {self.__class__.__name__} is missing variables {missing}. "
f" Expected: {self.input_variables}"
f" Received: {list(inner_input.keys())}"
)
example_key = missing.pop()
msg += (
f"\nNote: if you intended {{{example_key}}} to be part of the string"
" and not a variable, please escape it with double curly braces like: "
f"'{{{{{example_key}}}}}'."
)
raise KeyError(msg)
return inner_input
def _format_prompt_with_error_handling(self, inner_input: Dict) -> PromptValue:

View File

@@ -5,6 +5,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from pathlib import Path
from typing import (
Annotated,
Any,
Dict,
List,
@@ -21,6 +22,13 @@ from typing import (
overload,
)
from pydantic import (
Field,
PositiveInt,
SkipValidation,
model_validator,
)
from langchain_core._api import deprecated
from langchain_core.load import Serializable
from langchain_core.messages import (
@@ -38,7 +46,6 @@ from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.image import ImagePromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.string import StringPromptTemplate, get_template_variables
from langchain_core.pydantic_v1 import Field, PositiveInt, root_validator
from langchain_core.utils import get_colored_text
from langchain_core.utils.interactive_env import is_interactive_env
@@ -207,8 +214,14 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
"""Get the namespace of the langchain object."""
return ["langchain", "prompts", "chat"]
def __init__(self, variable_name: str, *, optional: bool = False, **kwargs: Any):
super().__init__(variable_name=variable_name, optional=optional, **kwargs)
def __init__(
self, variable_name: str, *, optional: bool = False, **kwargs: Any
) -> None:
# mypy can't detect the init which is defined in the parent class
# b/c these are BaseModel classes.
super().__init__( # type: ignore
variable_name=variable_name, optional=optional, **kwargs
)
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format messages from kwargs.
@@ -551,13 +564,13 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
input_variables=input_variables, template=img_template
)
else:
raise ValueError()
raise ValueError(f"Invalid image template: {tmpl}")
prompt.append(img_template_obj)
else:
raise ValueError()
raise ValueError(f"Invalid template: {tmpl}")
return cls(prompt=prompt, **kwargs)
else:
raise ValueError()
raise ValueError(f"Invalid template: {template}")
@classmethod
def from_template_file(
@@ -576,7 +589,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
Returns:
A new instance of this class.
"""
with open(str(template_file), "r") as f:
with open(str(template_file)) as f:
template = f.read()
return cls.from_template(template, input_variables=input_variables, **kwargs)
@@ -922,7 +935,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
""" # noqa: E501
messages: List[MessageLike]
messages: Annotated[List[MessageLike], SkipValidation]
"""List of messages consisting of either message prompt templates or messages."""
validate_template: bool = False
"""Whether or not to try validating the template."""
@@ -1038,8 +1051,9 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
else:
raise NotImplementedError(f"Unsupported operand type for +: {type(other)}")
@root_validator(pre=True)
def validate_input_variables(cls, values: dict) -> dict:
@model_validator(mode="before")
@classmethod
def validate_input_variables(cls, values: dict) -> Any:
"""Validate input variables.
If input_variables is not set, it will be set to the union of
@@ -1177,7 +1191,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
A message can be represented using the following formats:
(1) BaseMessagePromptTemplate, (2) BaseMessage, (3) 2-tuple of
(message type, template); e.g., ("human", "{user_input}"),
(4) 2-tuple of (message class, template), (4) a string which is
(4) 2-tuple of (message class, template), (5) a string which is
shorthand for ("human", template); e.g., "{user_input}".
template_format: format of the template. Defaults to "f-string".

View File

@@ -5,6 +5,15 @@ from __future__ import annotations
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Union
from pydantic import (
BaseModel,
ConfigDict,
Extra,
Field,
model_validator,
)
from typing_extensions import Self
from langchain_core.example_selectors import BaseExampleSelector
from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_core.prompts.chat import (
@@ -18,7 +27,6 @@ from langchain_core.prompts.string import (
check_valid_template,
get_template_variables,
)
from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
class _FewShotPromptTemplateMixin(BaseModel):
@@ -32,12 +40,14 @@ class _FewShotPromptTemplateMixin(BaseModel):
"""ExampleSelector to choose the examples to format into the prompt.
Either this or examples should be provided."""
class Config:
arbitrary_types_allowed = True
extra = Extra.forbid
model_config = ConfigDict(
arbitrary_types_allowed=True,
extra=Extra.forbid,
)
@root_validator(pre=True)
def check_examples_and_selector(cls, values: Dict) -> Dict:
@model_validator(mode="before")
@classmethod
def check_examples_and_selector(cls, values: Dict) -> Any:
"""Check that one and only one of examples/example_selector are provided.
Args:
@@ -139,28 +149,29 @@ class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate):
kwargs["input_variables"] = kwargs["example_prompt"].input_variables
super().__init__(**kwargs)
@root_validator(pre=False, skip_on_failure=True)
def template_is_valid(cls, values: Dict) -> Dict:
@model_validator(mode="after")
def template_is_valid(self) -> Self:
"""Check that prefix, suffix, and input variables are consistent."""
if values["validate_template"]:
if self.validate_template:
check_valid_template(
values["prefix"] + values["suffix"],
values["template_format"],
values["input_variables"] + list(values["partial_variables"]),
self.prefix + self.suffix,
self.template_format,
self.input_variables + list(self.partial_variables),
)
elif values.get("template_format"):
values["input_variables"] = [
elif self.template_format or None:
self.input_variables = [
var
for var in get_template_variables(
values["prefix"] + values["suffix"], values["template_format"]
self.prefix + self.suffix, self.template_format
)
if var not in values["partial_variables"]
if var not in self.partial_variables
]
return values
return self
class Config:
arbitrary_types_allowed = True
extra = Extra.forbid
model_config = ConfigDict(
arbitrary_types_allowed=True,
extra=Extra.forbid,
)
def format(self, **kwargs: Any) -> str:
"""Format the prompt with inputs generating a string.
@@ -365,9 +376,10 @@ class FewShotChatMessagePromptTemplate(
"""Return whether or not the class is serializable."""
return False
class Config:
arbitrary_types_allowed = True
extra = Extra.forbid
model_config = ConfigDict(
arbitrary_types_allowed=True,
extra=Extra.forbid,
)
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format kwargs into a list of messages.

View File

@@ -3,12 +3,14 @@
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from pydantic import ConfigDict, Extra, model_validator
from typing_extensions import Self
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.string import (
DEFAULT_FORMATTER_MAPPING,
StringPromptTemplate,
)
from langchain_core.pydantic_v1 import Extra, root_validator
class FewShotPromptWithTemplates(StringPromptTemplate):
@@ -45,8 +47,9 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
"""Get the namespace of the langchain object."""
return ["langchain", "prompts", "few_shot_with_templates"]
@root_validator(pre=True)
def check_examples_and_selector(cls, values: Dict) -> Dict:
@model_validator(mode="before")
@classmethod
def check_examples_and_selector(cls, values: Dict) -> Any:
"""Check that one and only one of examples/example_selector are provided."""
examples = values.get("examples", None)
example_selector = values.get("example_selector", None)
@@ -62,15 +65,15 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
return values
@root_validator(pre=False, skip_on_failure=True)
def template_is_valid(cls, values: Dict) -> Dict:
@model_validator(mode="after")
def template_is_valid(self) -> Self:
"""Check that prefix, suffix, and input variables are consistent."""
if values["validate_template"]:
input_variables = values["input_variables"]
expected_input_variables = set(values["suffix"].input_variables)
expected_input_variables |= set(values["partial_variables"])
if values["prefix"] is not None:
expected_input_variables |= set(values["prefix"].input_variables)
if self.validate_template:
input_variables = self.input_variables
expected_input_variables = set(self.suffix.input_variables)
expected_input_variables |= set(self.partial_variables)
if self.prefix is not None:
expected_input_variables |= set(self.prefix.input_variables)
missing_vars = expected_input_variables.difference(input_variables)
if missing_vars:
raise ValueError(
@@ -78,16 +81,17 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
f"prefix/suffix expected {expected_input_variables}"
)
else:
values["input_variables"] = sorted(
set(values["suffix"].input_variables)
| set(values["prefix"].input_variables if values["prefix"] else [])
- set(values["partial_variables"])
self.input_variables = sorted(
set(self.suffix.input_variables)
| set(self.prefix.input_variables if self.prefix else [])
- set(self.partial_variables)
)
return values
return self
class Config:
arbitrary_types_allowed = True
extra = Extra.forbid
model_config = ConfigDict(
arbitrary_types_allowed=True,
extra=Extra.forbid,
)
def _get_examples(self, **kwargs: Any) -> List[dict]:
if self.examples is not None:

View File

@@ -1,8 +1,9 @@
from typing import Any, List
from pydantic import Field
from langchain_core.prompt_values import ImagePromptValue, ImageURL, PromptValue
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain_core.runnables import run_in_executor
from langchain_core.utils import image as image_utils

View File

@@ -173,7 +173,7 @@ def _load_prompt_from_file(
with open(file_path, encoding=encoding) as f:
config = json.load(f)
elif file_path.suffix.endswith((".yaml", ".yml")):
with open(file_path, mode="r", encoding=encoding) as f:
with open(file_path, encoding=encoding) as f:
config = yaml.safe_load(f)
else:
raise ValueError(f"Got unsupported file type {file_path.suffix}")

View File

@@ -1,9 +1,11 @@
from typing import Any, Dict, List, Tuple
from typing import Optional as Optional
from pydantic import model_validator
from langchain_core.prompt_values import PromptValue
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.chat import BaseChatPromptTemplate
from langchain_core.pydantic_v1 import root_validator
def _get_inputs(inputs: dict, input_variables: List[str]) -> dict:
@@ -34,8 +36,9 @@ class PipelinePromptTemplate(BasePromptTemplate):
"""Get the namespace of the langchain object."""
return ["langchain", "prompts", "pipeline"]
@root_validator(pre=True)
def get_input_variables(cls, values: Dict) -> Dict:
@model_validator(mode="before")
@classmethod
def get_input_variables(cls, values: Dict) -> Any:
"""Get input variables."""
created_variables = set()
all_variables = set()
@@ -106,3 +109,6 @@ class PipelinePromptTemplate(BasePromptTemplate):
@property
def _prompt_type(self) -> str:
raise ValueError
PipelinePromptTemplate.model_rebuild()

View File

@@ -6,6 +6,8 @@ import warnings
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Union
from pydantic import BaseModel, model_validator
from langchain_core.prompts.string import (
DEFAULT_FORMATTER_MAPPING,
StringPromptTemplate,
@@ -13,7 +15,6 @@ from langchain_core.prompts.string import (
get_template_variables,
mustache_schema,
)
from langchain_core.pydantic_v1 import BaseModel, root_validator
from langchain_core.runnables.config import RunnableConfig
@@ -73,8 +74,9 @@ class PromptTemplate(StringPromptTemplate):
validate_template: bool = False
"""Whether or not to try validating the template."""
@root_validator(pre=True)
def pre_init_validation(cls, values: Dict) -> Dict:
@model_validator(mode="before")
@classmethod
def pre_init_validation(cls, values: Dict) -> Any:
"""Check that template and input variables are consistent."""
if values.get("template") is None:
# Will let pydantic fail with a ValidationError if template
@@ -231,11 +233,13 @@ class PromptTemplate(StringPromptTemplate):
Returns:
The prompt loaded from the file.
"""
with open(str(template_file), "r", encoding=encoding) as f:
with open(str(template_file), encoding=encoding) as f:
template = f.read()
if input_variables:
warnings.warn(
"`input_variables' is deprecated and ignored.", DeprecationWarning
"`input_variables' is deprecated and ignored.",
DeprecationWarning,
stacklevel=2,
)
return cls.from_template(template=template, **kwargs)

View File

@@ -7,10 +7,11 @@ from abc import ABC
from string import Formatter
from typing import Any, Callable, Dict, List, Set, Tuple, Type
from pydantic import BaseModel, create_model
import langchain_core.utils.mustache as mustache
from langchain_core.prompt_values import PromptValue, StringPromptValue
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.pydantic_v1 import BaseModel, create_model
from langchain_core.utils import get_colored_text
from langchain_core.utils.formatting import formatter
from langchain_core.utils.interactive_env import is_interactive_env
@@ -40,14 +41,14 @@ def jinja2_formatter(template: str, **kwargs: Any) -> str:
"""
try:
from jinja2.sandbox import SandboxedEnvironment
except ImportError:
except ImportError as e:
raise ImportError(
"jinja2 not installed, which is needed to use the jinja2_formatter. "
"Please install it with `pip install jinja2`."
"Please be cautious when using jinja2 templates. "
"Do not expand jinja2 templates using unverified or user-controlled "
"inputs as that can result in arbitrary Python code execution."
)
) from e
# This uses a sandboxed environment to prevent arbitrary code execution.
# Jinja2 uses an opt-out rather than opt-in approach for sand-boxing.
@@ -81,17 +82,17 @@ def validate_jinja2(template: str, input_variables: List[str]) -> None:
warning_message += f"Extra variables: {extra_variables}"
if warning_message:
warnings.warn(warning_message.strip())
warnings.warn(warning_message.strip(), stacklevel=7)
def _get_jinja2_variables_from_template(template: str) -> Set[str]:
try:
from jinja2 import Environment, meta
except ImportError:
except ImportError as e:
raise ImportError(
"jinja2 not installed, which is needed to use the jinja2_formatter. "
"Please install it with `pip install jinja2`."
)
) from e
env = Environment()
ast = env.parse(template)
variables = meta.find_undeclared_variables(ast)

View File

@@ -7,28 +7,25 @@ from typing import (
Mapping,
Optional,
Sequence,
Set,
Type,
Union,
)
from pydantic import BaseModel, Field
from langchain_core._api.beta_decorator import beta
from langchain_core.language_models.base import BaseLanguageModel
from langchain_core.prompts.chat import (
BaseChatPromptTemplate,
BaseMessagePromptTemplate,
ChatPromptTemplate,
MessageLikeRepresentation,
MessagesPlaceholder,
_convert_to_message,
)
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.base import (
Other,
Runnable,
RunnableSequence,
RunnableSerializable,
)
from langchain_core.utils import get_pydantic_field_names
@beta()
@@ -37,6 +34,26 @@ class StructuredPrompt(ChatPromptTemplate):
schema_: Union[Dict, Type[BaseModel]]
"""Schema for the structured prompt."""
structured_output_kwargs: Dict[str, Any] = Field(default_factory=dict)
def __init__(
self,
messages: Sequence[MessageLikeRepresentation],
schema_: Optional[Union[Dict, Type[BaseModel]]] = None,
*,
structured_output_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
schema_ = schema_ or kwargs.pop("schema")
structured_output_kwargs = structured_output_kwargs or {}
for k in set(kwargs).difference(get_pydantic_field_names(self.__class__)):
structured_output_kwargs[k] = kwargs.pop(k)
super().__init__(
messages=messages,
schema_=schema_,
structured_output_kwargs=structured_output_kwargs,
**kwargs,
)
@classmethod
def get_lc_namespace(cls) -> List[str]:
@@ -52,6 +69,7 @@ class StructuredPrompt(ChatPromptTemplate):
cls,
messages: Sequence[MessageLikeRepresentation],
schema: Union[Dict, Type[BaseModel]],
**kwargs: Any,
) -> ChatPromptTemplate:
"""Create a chat prompt template from a variety of message formats.
@@ -61,11 +79,13 @@ class StructuredPrompt(ChatPromptTemplate):
.. code-block:: python
from langchain_core.prompts import StructuredPrompt
class OutputSchema(BaseModel):
name: str
value: int
template = ChatPromptTemplate.from_messages(
template = StructuredPrompt(
[
("human", "Hello, how are you?"),
("ai", "I'm doing well, thanks!"),
@@ -82,29 +102,13 @@ class StructuredPrompt(ChatPromptTemplate):
(4) 2-tuple of (message class, template), (5) a string which is
shorthand for ("human", template); e.g., "{user_input}"
schema: a dictionary representation of function call, or a Pydantic model.
kwargs: Any additional kwargs to pass through to
``ChatModel.with_structured_output(schema, **kwargs)``.
Returns:
a structured prompt template
"""
_messages = [_convert_to_message(message) for message in messages]
# Automatically infer input variables from messages
input_vars: Set[str] = set()
partial_vars: Dict[str, Any] = {}
for _message in _messages:
if isinstance(_message, MessagesPlaceholder) and _message.optional:
partial_vars[_message.variable_name] = []
elif isinstance(
_message, (BaseChatPromptTemplate, BaseMessagePromptTemplate)
):
input_vars.update(_message.input_variables)
return cls(
input_variables=sorted(input_vars),
messages=_messages,
partial_variables=partial_vars,
schema_=schema,
)
return cls(messages, schema, **kwargs)
def __or__(
self,
@@ -115,27 +119,16 @@ class StructuredPrompt(ChatPromptTemplate):
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
],
) -> RunnableSerializable[Dict, Other]:
if isinstance(other, BaseLanguageModel) or hasattr(
other, "with_structured_output"
):
try:
return RunnableSequence(
self, other.with_structured_output(self.schema_)
)
except NotImplementedError as e:
raise NotImplementedError(
"Structured prompts must be piped to a language model that "
"implements with_structured_output."
) from e
else:
raise NotImplementedError(
"Structured prompts must be piped to a language model that "
"implements with_structured_output."
)
return self.pipe(other)
def pipe(
self,
*others: Union[Runnable[Any, Other], Callable[[Any], Other]],
*others: Union[
Runnable[Any, Other],
Callable[[Any], Other],
Callable[[Iterator[Any]], Iterator[Other]],
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
],
name: Optional[str] = None,
) -> RunnableSerializable[Dict, Other]:
"""Pipe the structured prompt to a language model.
@@ -158,7 +151,9 @@ class StructuredPrompt(ChatPromptTemplate):
):
return RunnableSequence(
self,
others[0].with_structured_output(self.schema_),
others[0].with_structured_output(
self.schema_, **self.structured_output_kwargs
),
*others[1:],
name=name,
)

View File

@@ -181,7 +181,7 @@ class InMemoryRateLimiter(BaseRateLimiter):
the caller should try again later.
"""
with self._consume_lock:
now = time.time()
now = time.monotonic()
# initialize on first call to avoid a burst
if self.last is None:
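Editor's note: switching the rate limiter from `time.time()` to `time.monotonic()` makes the elapsed-time math immune to wall-clock adjustments (NTP corrections, DST, manual changes), which could otherwise make a token-bucket refill negative or jump. A tiny sketch of measuring an interval with a monotonic clock:

    import time

    start = time.monotonic()
    time.sleep(0.01)
    elapsed = time.monotonic() - start

    # monotonic() never goes backwards, so elapsed is always >= 0,
    # even if the system clock is adjusted while we sleep.
    assert elapsed >= 0.0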

View File

@@ -26,6 +26,9 @@ from abc import ABC, abstractmethod
from inspect import signature
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from pydantic import ConfigDict
from typing_extensions import TypedDict
from langchain_core._api import deprecated
from langchain_core.documents import Document
from langchain_core.load.dump import dumpd
@@ -50,6 +53,19 @@ RetrieverLike = Runnable[RetrieverInput, RetrieverOutput]
RetrieverOutputLike = Runnable[Any, RetrieverOutput]
class LangSmithRetrieverParams(TypedDict, total=False):
"""LangSmith parameters for tracing."""
ls_retriever_name: str
"""Retriever name."""
ls_vector_store_provider: Optional[str]
"""Vector store provider."""
ls_embedding_provider: Optional[str]
"""Embedding provider."""
ls_embedding_model: Optional[str]
"""Embedding model."""
class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
"""Abstract base class for a Document retrieval system.
@@ -111,8 +127,9 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
return [self.docs[i] for i in results.argsort()[-self.k :][::-1]]
""" # noqa: E501
class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
_new_arg_supported: bool = False
_expects_other_args: bool = False
@@ -140,6 +157,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
"Retrievers must implement abstract `_get_relevant_documents` method"
" instead of `get_relevant_documents`",
DeprecationWarning,
stacklevel=4,
)
swap = cls.get_relevant_documents
cls.get_relevant_documents = ( # type: ignore[assignment]
@@ -154,6 +172,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
"Retrievers must implement abstract `_aget_relevant_documents` method"
" instead of `aget_relevant_documents`",
DeprecationWarning,
stacklevel=4,
)
aswap = cls.aget_relevant_documents
cls.aget_relevant_documents = ( # type: ignore[assignment]
@@ -167,6 +186,19 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
len(set(parameters.keys()) - {"self", "query", "run_manager"}) > 0
)
def _get_ls_params(self, **kwargs: Any) -> LangSmithRetrieverParams:
"""Get standard params for tracing."""
default_retriever_name = self.get_name()
if default_retriever_name.startswith("Retriever"):
default_retriever_name = default_retriever_name[9:]
elif default_retriever_name.endswith("Retriever"):
default_retriever_name = default_retriever_name[:-9]
default_retriever_name = default_retriever_name.lower()
ls_params = LangSmithRetrieverParams(ls_retriever_name=default_retriever_name)
return ls_params
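A sketch of what the new tracing params look like for a concrete subclass (the retriever below is a toy example, not part of this change); the ls_* keys are merged into the inheritable metadata that invoke/ainvoke hand to the callback manager:

from typing import List
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever

class ToyRetriever(BaseRetriever):
    def _get_relevant_documents(self, query: str, *, run_manager) -> List[Document]:
        return [Document(page_content=query)]

retriever = ToyRetriever()
print(retriever._get_ls_params())  # {'ls_retriever_name': 'toy'}
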
def invoke(
self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> List[Document]:
@@ -191,13 +223,17 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
from langchain_core.callbacks.manager import CallbackManager
config = ensure_config(config)
inheritable_metadata = {
**(config.get("metadata") or {}),
**self._get_ls_params(**kwargs),
}
callback_manager = CallbackManager.configure(
config.get("callbacks"),
None,
verbose=kwargs.get("verbose", False),
inheritable_tags=config.get("tags"),
local_tags=self.tags,
inheritable_metadata=config.get("metadata"),
inheritable_metadata=inheritable_metadata,
local_metadata=self.metadata,
)
run_manager = callback_manager.on_retriever_start(
@@ -250,13 +286,17 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
from langchain_core.callbacks.manager import AsyncCallbackManager
config = ensure_config(config)
inheritable_metadata = {
**(config.get("metadata") or {}),
**self._get_ls_params(**kwargs),
}
callback_manager = AsyncCallbackManager.configure(
config.get("callbacks"),
None,
verbose=kwargs.get("verbose", False),
inheritable_tags=config.get("tags"),
local_tags=self.tags,
inheritable_metadata=config.get("metadata"),
inheritable_metadata=inheritable_metadata,
local_metadata=self.metadata,
)
run_manager = await callback_manager.on_retriever_start(
@@ -313,7 +353,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
run_manager=run_manager.get_sync(),
)
@deprecated(since="0.1.46", alternative="invoke", removal="0.3.0")
@deprecated(since="0.1.46", alternative="invoke", removal="1.0")
def get_relevant_documents(
self,
query: str,
@@ -357,7 +397,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
config["run_name"] = run_name
return self.invoke(query, config, **kwargs)
@deprecated(since="0.1.46", alternative="ainvoke", removal="0.3.0")
@deprecated(since="0.1.46", alternative="ainvoke", removal="1.0")
async def aget_relevant_documents(
self,
query: str,

View File

@@ -35,7 +35,8 @@ from typing import (
overload,
)
from typing_extensions import Literal, get_args
from pydantic import BaseModel, ConfigDict, Field, RootModel
from typing_extensions import Literal, get_args, get_type_hints
from langchain_core._api import beta_decorator
from langchain_core.load.dump import dumpd
@@ -44,7 +45,6 @@ from langchain_core.load.serializable import (
SerializedConstructor,
SerializedNotImplemented,
)
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables.config import (
RunnableConfig,
_set_config_context,
@@ -83,7 +83,6 @@ from langchain_core.runnables.utils import (
)
from langchain_core.utils.aiter import aclosing, atee, py_anext
from langchain_core.utils.iter import safetee
from langchain_core.utils.pydantic import is_basemodel_subclass
if TYPE_CHECKING:
from langchain_core.callbacks.manager import (
@@ -236,25 +235,56 @@ class Runnable(Generic[Input, Output], ABC):
For a UI (and much more) check out LangSmith: https://docs.smith.langchain.com/
""" # noqa: E501
name: Optional[str] = None
name: Optional[str]
"""The name of the Runnable. Used for debugging and tracing."""
def get_name(
self, suffix: Optional[str] = None, *, name: Optional[str] = None
) -> str:
"""Get the name of the Runnable."""
name = name or self.name or self.__class__.__name__
if suffix:
if name[0].isupper():
return name + suffix.title()
else:
return name + "_" + suffix.lower()
if name:
name_ = name
elif hasattr(self, "name") and self.name:
name_ = self.name
else:
return name
# Here we handle a case where the runnable subclass is also a pydantic
# model.
cls = self.__class__
# In that case it's a pydantic sub-class, and we have to check
# whether it's a generic, and if so recover the original name.
if (
hasattr(
cls,
"__pydantic_generic_metadata__",
)
and "origin" in cls.__pydantic_generic_metadata__
and cls.__pydantic_generic_metadata__["origin"] is not None
):
name_ = cls.__pydantic_generic_metadata__["origin"].__name__
else:
name_ = cls.__name__
if suffix:
if name_[0].isupper():
return name_ + suffix.title()
else:
return name_ + "_" + suffix.lower()
else:
return name_
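The extra branch is needed because under pydantic 2 a parameterized generic model reports a parameterized class name; a rough sketch of the behavior being compensated for (the model is illustrative):

from typing import Generic, TypeVar
from pydantic import BaseModel

T = TypeVar("T")

class Wrapper(BaseModel, Generic[T]):
    value: T

Parameterized = Wrapper[int]
print(Parameterized.__name__)                                          # "Wrapper[int]"
print(Parameterized.__pydantic_generic_metadata__["origin"].__name__)  # "Wrapper"
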
@property
def InputType(self) -> Type[Input]:
"""The type of input this Runnable accepts specified as a type annotation."""
# First loop through bases -- this will help resolve generic
# pydantic models.
for base in self.__class__.mro():
if hasattr(base, "__pydantic_generic_metadata__"):
metadata = base.__pydantic_generic_metadata__
if "args" in metadata and len(metadata["args"]) == 2:
return metadata["args"][0]
# then loop through __orig_bases__ -- this will help Runnables that do not inherit
# from pydantic
for cls in self.__class__.__orig_bases__: # type: ignore[attr-defined]
type_args = get_args(cls)
if type_args and len(type_args) == 2:
@@ -268,6 +298,14 @@ class Runnable(Generic[Input, Output], ABC):
@property
def OutputType(self) -> Type[Output]:
"""The type of output this Runnable produces specified as a type annotation."""
# First loop through bases -- this will help resolve generic
# pydantic models.
for base in self.__class__.mro():
if hasattr(base, "__pydantic_generic_metadata__"):
metadata = base.__pydantic_generic_metadata__
if "args" in metadata and len(metadata["args"]) == 2:
return metadata["args"][1]
for cls in self.__class__.__orig_bases__: # type: ignore[attr-defined]
type_args = get_args(cls)
if type_args and len(type_args) == 2:
@@ -302,12 +340,12 @@ class Runnable(Generic[Input, Output], ABC):
"""
root_type = self.InputType
if inspect.isclass(root_type) and is_basemodel_subclass(root_type):
if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
return root_type
return create_model(
self.get_name("Input"),
__root__=(root_type, None),
__root__=root_type,
)
@property
@@ -334,12 +372,12 @@ class Runnable(Generic[Input, Output], ABC):
"""
root_type = self.OutputType
if inspect.isclass(root_type) and is_basemodel_subclass(root_type):
if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
return root_type
return create_model(
self.get_name("Output"),
__root__=(root_type, None),
__root__=root_type,
)
@property
@@ -381,15 +419,19 @@ class Runnable(Generic[Input, Output], ABC):
else None
)
return create_model( # type: ignore[call-overload]
self.get_name("Config"),
# May need to create a typed dict instead to implement NotRequired!
all_fields = {
**({"configurable": (configurable, None)} if configurable else {}),
**{
field_name: (field_type, None)
for field_name, field_type in RunnableConfig.__annotations__.items()
for field_name, field_type in get_type_hints(RunnableConfig).items()
if field_name in [i for i in include if i != "configurable"]
},
}
model = create_model( # type: ignore[call-overload]
self.get_name("Config"), **all_fields
)
return model
def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph:
"""Return a graph representation of this Runnable."""
@@ -579,7 +621,7 @@ class Runnable(Generic[Input, Output], ABC):
"""
from langchain_core.runnables.passthrough import RunnableAssign
return self | RunnableAssign(RunnableParallel(kwargs))
return self | RunnableAssign(RunnableParallel[Dict[str, Any]](kwargs))
""" --- Public API --- """
@@ -2129,7 +2171,6 @@ class Runnable(Generic[Input, Output], ABC):
name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None),
)
iterator_ = None
try:
child_config = patch_config(config, callbacks=run_manager.get_child())
if accepts_config(transformer):
@@ -2314,7 +2355,6 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]):
"""Runnable that can be serialized to JSON."""
name: Optional[str] = None
"""The name of the Runnable. Used for debugging and tracing."""
def to_json(self) -> Union[SerializedConstructor, SerializedNotImplemented]:
"""Serialize the Runnable to JSON.
@@ -2369,10 +2409,10 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]):
from langchain_core.runnables.configurable import RunnableConfigurableFields
for key in kwargs:
if key not in self.__fields__:
if key not in self.model_fields:
raise ValueError(
f"Configuration key {key} not found in {self}: "
f"available keys are {self.__fields__.keys()}"
f"available keys are {self.model_fields.keys()}"
)
return RunnableConfigurableFields(default=self, fields=kwargs)
@@ -2447,13 +2487,13 @@ def _seq_input_schema(
return first.get_input_schema(config)
elif isinstance(first, RunnableAssign):
next_input_schema = _seq_input_schema(steps[1:], config)
if not next_input_schema.__custom_root_type__:
if not issubclass(next_input_schema, RootModel):
# it's a dict as expected
return create_model( # type: ignore[call-overload]
"RunnableSequenceInput",
**{
k: (v.annotation, v.default)
for k, v in next_input_schema.__fields__.items()
for k, v in next_input_schema.model_fields.items()
if k not in first.mapper.steps__
},
)
@@ -2474,36 +2514,36 @@ def _seq_output_schema(
elif isinstance(last, RunnableAssign):
mapper_output_schema = last.mapper.get_output_schema(config)
prev_output_schema = _seq_output_schema(steps[:-1], config)
if not prev_output_schema.__custom_root_type__:
if not issubclass(prev_output_schema, RootModel):
# it's a dict as expected
return create_model( # type: ignore[call-overload]
"RunnableSequenceOutput",
**{
**{
k: (v.annotation, v.default)
for k, v in prev_output_schema.__fields__.items()
for k, v in prev_output_schema.model_fields.items()
},
**{
k: (v.annotation, v.default)
for k, v in mapper_output_schema.__fields__.items()
for k, v in mapper_output_schema.model_fields.items()
},
},
)
elif isinstance(last, RunnablePick):
prev_output_schema = _seq_output_schema(steps[:-1], config)
if not prev_output_schema.__custom_root_type__:
if not issubclass(prev_output_schema, RootModel):
# it's a dict as expected
if isinstance(last.keys, list):
return create_model( # type: ignore[call-overload]
"RunnableSequenceOutput",
**{
k: (v.annotation, v.default)
for k, v in prev_output_schema.__fields__.items()
for k, v in prev_output_schema.model_fields.items()
if k in last.keys
},
)
else:
field = prev_output_schema.__fields__[last.keys]
field = prev_output_schema.model_fields[last.keys]
return create_model( # type: ignore[call-overload]
"RunnableSequenceOutput",
__root__=(field.annotation, field.default),
@@ -2665,8 +2705,9 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
"""
return True
class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
@property
def InputType(self) -> Type[Input]:
@@ -3208,8 +3249,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
else:
final_pipeline = step.transform(final_pipeline, config)
for output in final_pipeline:
yield output
yield from final_pipeline
async def _atransform(
self,
@@ -3403,8 +3443,9 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"]
class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
def get_name(
self, suffix: Optional[str] = None, *, name: Optional[str] = None
@@ -3451,7 +3492,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
**{
k: (v.annotation, v.default)
for step in self.steps__.values()
for k, v in step.get_input_schema(config).__fields__.items()
for k, v in step.get_input_schema(config).model_fields.items()
if k != "__root__"
},
)
@@ -3469,11 +3510,8 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
Returns:
The output schema of the Runnable.
"""
# This is correct, but pydantic typings/mypy don't think so.
return create_model( # type: ignore[call-overload]
self.get_name("Output"),
**{k: (v.OutputType, None) for k, v in self.steps__.items()},
)
fields = {k: (v.OutputType, ...) for k, v in self.steps__.items()}
return create_model(self.get_name("Output"), **fields)
@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
@@ -3883,6 +3921,8 @@ class RunnableGenerator(Runnable[Input, Output]):
atransform: Optional[
Callable[[AsyncIterator[Input]], AsyncIterator[Output]]
] = None,
*,
name: Optional[str] = None,
) -> None:
"""Initialize a RunnableGenerator.
@@ -3910,13 +3950,13 @@ class RunnableGenerator(Runnable[Input, Output]):
)
try:
self.name = func_for_name.__name__
self.name = name or func_for_name.__name__
except AttributeError:
pass
self.name = "RunnableGenerator"
@property
def InputType(self) -> Any:
func = getattr(self, "_transform", None) or getattr(self, "_atransform")
func = getattr(self, "_transform", None) or self._atransform
try:
params = inspect.signature(func).parameters
first_param = next(iter(params.values()), None)
@@ -3929,7 +3969,7 @@ class RunnableGenerator(Runnable[Input, Output]):
@property
def OutputType(self) -> Any:
func = getattr(self, "_transform", None) or getattr(self, "_atransform")
func = getattr(self, "_transform", None) or self._atransform
try:
sig = inspect.signature(func)
return (
@@ -4153,7 +4193,7 @@ class RunnableLambda(Runnable[Input, Output]):
@property
def InputType(self) -> Any:
"""The type of the input to this Runnable."""
func = getattr(self, "func", None) or getattr(self, "afunc")
func = getattr(self, "func", None) or self.afunc
try:
params = inspect.signature(func).parameters
first_param = next(iter(params.values()), None)
@@ -4175,7 +4215,7 @@ class RunnableLambda(Runnable[Input, Output]):
Returns:
The input schema for this Runnable.
"""
func = getattr(self, "func", None) or getattr(self, "afunc")
func = getattr(self, "func", None) or self.afunc
if isinstance(func, itemgetter):
# This is terrible, but afaict it's not possible to access _items
@@ -4184,15 +4224,13 @@ class RunnableLambda(Runnable[Input, Output]):
if all(
item[0] == "'" and item[-1] == "'" and len(item) > 2 for item in items
):
fields = {item[1:-1]: (Any, ...) for item in items}
# It's a dict, lol
return create_model(
self.get_name("Input"),
**{item[1:-1]: (Any, None) for item in items}, # type: ignore
)
return create_model(self.get_name("Input"), **fields)
else:
return create_model(
self.get_name("Input"),
__root__=(List[Any], None),
__root__=List[Any],
)
if self.InputType != Any:
@@ -4201,7 +4239,7 @@ class RunnableLambda(Runnable[Input, Output]):
if dict_keys := get_function_first_arg_dict_keys(func):
return create_model(
self.get_name("Input"),
**{key: (Any, None) for key in dict_keys}, # type: ignore
**{key: (Any, ...) for key in dict_keys}, # type: ignore
)
return super().get_input_schema(config)
@@ -4213,7 +4251,7 @@ class RunnableLambda(Runnable[Input, Output]):
Returns:
The type of the output of this Runnable.
"""
func = getattr(self, "func", None) or getattr(self, "afunc")
func = getattr(self, "func", None) or self.afunc
try:
sig = inspect.signature(func)
if sig.return_annotation != inspect.Signature.empty:
@@ -4577,13 +4615,12 @@ class RunnableLambda(Runnable[Input, Output]):
**kwargs: Optional[Any],
) -> Iterator[Output]:
if hasattr(self, "func"):
for output in self._transform_stream_with_config(
yield from self._transform_stream_with_config(
input,
self._transform,
self._config(config, self.func),
**kwargs,
):
yield output
)
else:
raise TypeError(
"Cannot stream a coroutine function synchronously."
@@ -4730,8 +4767,9 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
bound: Runnable[Input, Output]
class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
@property
def InputType(self) -> Any:
@@ -4758,10 +4796,7 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
schema = self.bound.get_output_schema(config)
return create_model(
self.get_name("Output"),
__root__=(
List[schema], # type: ignore
None,
),
__root__=List[schema], # type: ignore[valid-type]
)
@property
@@ -4981,8 +5016,9 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
The type can be a pydantic model, or a type annotation (e.g., `List[str]`).
"""
class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
def __init__(
self,
@@ -5318,7 +5354,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
yield item
RunnableBindingBase.update_forward_refs(RunnableConfig=RunnableConfig)
RunnableBindingBase.model_rebuild()
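The same handful of pydantic 1 -> 2 substitutions recur throughout this file; a compact sketch of the mapping on a generic model:

from pydantic import BaseModel, ConfigDict

class Example(BaseModel):
    # pydantic 1: `class Config: arbitrary_types_allowed = True`
    model_config = ConfigDict(arbitrary_types_allowed=True)
    value: int = 0

Example.model_fields           # pydantic 1: Example.__fields__
Example.model_json_schema()    # pydantic 1: Example.schema()
Example.model_rebuild()        # pydantic 1: Example.update_forward_refs()
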
class RunnableBinding(RunnableBindingBase[Input, Output]):
@@ -5334,12 +5370,13 @@ class RunnableBinding(RunnableBindingBase[Input, Output]):
`RunnableWithFallbacks`) that add additional functionality.
These methods include:
- `bind`: Bind kwargs to pass to the underlying Runnable when running it.
- `with_config`: Bind config to pass to the underlying Runnable when running it.
- `with_listeners`: Bind lifecycle listeners to the underlying Runnable.
- `with_types`: Override the input and output types of the underlying Runnable.
- `with_retry`: Bind a retry policy to the underlying Runnable.
- `with_fallbacks`: Bind a fallback policy to the underlying Runnable.
- ``bind``: Bind kwargs to pass to the underlying Runnable when running it.
- ``with_config``: Bind config to pass to the underlying Runnable when running it.
- ``with_listeners``: Bind lifecycle listeners to the underlying Runnable.
- ``with_types``: Override the input and output types of the underlying Runnable.
- ``with_retry``: Bind a retry policy to the underlying Runnable.
- ``with_fallbacks``: Bind a fallback policy to the underlying Runnable.
Example:

View File

@@ -14,8 +14,9 @@ from typing import (
cast,
)
from pydantic import BaseModel, ConfigDict
from langchain_core.load.dump import dumpd
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.base import (
Runnable,
RunnableLike,
@@ -134,10 +135,21 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
runnable = coerce_to_runnable(runnable)
_branches.append((condition, runnable))
super().__init__(branches=_branches, default=default_) # type: ignore[call-arg]
super().__init__(
branches=_branches,
default=default_,
# Hard-coding a name here because RunnableBranch is a generic
# and with pydantic 2, the class name will include the parameterized
# type, which is not what we want.
# e.g., we'd get RunnableBranch[Input, Output] instead of RunnableBranch
# for the name. This information is already captured in the
# input and output types.
name="RunnableBranch",
) # type: ignore[call-arg]
class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
@classmethod
def is_lc_serializable(cls) -> bool:

View File

@@ -236,6 +236,7 @@ def get_config_list(
warnings.warn(
"Provided run_id be used only for the first element of the batch.",
category=RuntimeWarning,
stacklevel=3,
)
subsequent = cast(
RunnableConfig, {k: v for k, v in config.items() if k != "run_id"}
@@ -348,37 +349,7 @@ def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig:
base["callbacks"] = mngr
else:
# base_callbacks is also a manager
manager = base_callbacks.__class__(
parent_run_id=base_callbacks.parent_run_id
or these_callbacks.parent_run_id,
handlers=[],
inheritable_handlers=[],
tags=list(set(base_callbacks.tags + these_callbacks.tags)),
inheritable_tags=list(
set(
base_callbacks.inheritable_tags
+ these_callbacks.inheritable_tags
)
),
metadata={
**base_callbacks.metadata,
**these_callbacks.metadata,
},
)
handlers = base_callbacks.handlers + these_callbacks.handlers
inheritable_handlers = (
base_callbacks.inheritable_handlers
+ these_callbacks.inheritable_handlers
)
for handler in handlers:
manager.add_handler(handler)
for handler in inheritable_handlers:
manager.add_handler(handler, inherit=True)
base["callbacks"] = manager
base["callbacks"] = base_callbacks.merge(these_callbacks)
elif key == "recursion_limit":
if config["recursion_limit"] != DEFAULT_RECURSION_LIMIT:
base["recursion_limit"] = config["recursion_limit"]

View File

@@ -20,7 +20,8 @@ from typing import (
)
from weakref import WeakValueDictionary
from langchain_core.pydantic_v1 import BaseModel
from pydantic import BaseModel, ConfigDict
from langchain_core.runnables.base import Runnable, RunnableSerializable
from langchain_core.runnables.config import (
RunnableConfig,
@@ -58,8 +59,9 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
config: Optional[RunnableConfig] = None
class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
@classmethod
def is_lc_serializable(cls) -> bool:
@@ -373,28 +375,33 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
Returns:
List[ConfigurableFieldSpec]: The configuration specs.
"""
return get_unique_config_specs(
[
(
# TODO(0.3): This change removes field_info which isn't needed in pydantic 2
config_specs = []
for field_name, spec in self.fields.items():
if isinstance(spec, ConfigurableField):
config_specs.append(
ConfigurableFieldSpec(
id=spec.id,
name=spec.name,
description=spec.description
or self.default.__fields__[field_name].field_info.description,
or self.default.model_fields[field_name].description,
annotation=spec.annotation
or self.default.__fields__[field_name].annotation,
or self.default.model_fields[field_name].annotation,
default=getattr(self.default, field_name),
is_shared=spec.is_shared,
)
if isinstance(spec, ConfigurableField)
else make_options_spec(
spec, self.default.__fields__[field_name].field_info.description
)
else:
config_specs.append(
make_options_spec(
spec, self.default.model_fields[field_name].description
)
)
for field_name, spec in self.fields.items()
]
+ list(self.default.config_specs)
)
config_specs.extend(self.default.config_specs)
return get_unique_config_specs(config_specs)
def configurable_fields(
self, **kwargs: AnyConfigurableField
@@ -436,7 +443,7 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
init_params = {
k: v
for k, v in self.default.__dict__.items()
if k in self.default.__fields__
if k in self.default.model_fields
}
return (
self.default.__class__(**{**init_params, **configurable}),

View File

@@ -18,8 +18,9 @@ from typing import (
cast,
)
from pydantic import BaseModel, ConfigDict
from langchain_core.load.dump import dumpd
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.base import Runnable, RunnableSerializable
from langchain_core.runnables.config import (
RunnableConfig,
@@ -107,8 +108,9 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
will not be passed to fallbacks. If used, the base Runnable and its fallbacks
must accept a dictionary as input."""
class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
@property
def InputType(self) -> Type[Input]:
@@ -180,6 +182,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
output = context.run(
runnable.invoke,
input,
config,
**kwargs,
)
except self.exceptions_to_handle as e:

View File

@@ -22,8 +22,9 @@ from typing import (
)
from uuid import UUID, uuid4
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.utils.pydantic import is_basemodel_subclass
from pydantic import BaseModel
from langchain_core.utils.pydantic import _IgnoreUnserializable, is_basemodel_subclass
if TYPE_CHECKING:
from langchain_core.runnables.base import Runnable as RunnableType
@@ -235,7 +236,9 @@ def node_data_json(
json = (
{
"type": "schema",
"data": node.data.schema(),
"data": node.data.model_json_schema(
schema_generator=_IgnoreUnserializable
),
}
if with_schemas
else {
@@ -537,7 +540,7 @@ class Graph:
*,
with_styles: bool = True,
curve_style: CurveStyle = CurveStyle.LINEAR,
node_colors: NodeStyles = NodeStyles(),
node_colors: Optional[NodeStyles] = None,
wrap_label_n_words: int = 9,
) -> str:
"""Draw the graph as a Mermaid syntax string.
@@ -573,7 +576,7 @@ class Graph:
self,
*,
curve_style: CurveStyle = CurveStyle.LINEAR,
node_colors: NodeStyles = NodeStyles(),
node_colors: Optional[NodeStyles] = None,
wrap_label_n_words: int = 9,
output_file_path: Optional[str] = None,
draw_method: MermaidDrawMethod = MermaidDrawMethod.API,

View File

@@ -20,7 +20,7 @@ def draw_mermaid(
last_node: Optional[str] = None,
with_styles: bool = True,
curve_style: CurveStyle = CurveStyle.LINEAR,
node_styles: NodeStyles = NodeStyles(),
node_styles: Optional[NodeStyles] = None,
wrap_label_n_words: int = 9,
) -> str:
"""Draws a Mermaid graph using the provided graph data.
@@ -78,51 +78,82 @@ def draw_mermaid(
)
mermaid_graph += f"\t{node_label}\n"
subgraph = ""
# Add edges to the graph
# Group edges by their common prefixes
edge_groups: Dict[str, List[Edge]] = {}
for edge in edges:
src_prefix = edge.source.split(":")[0] if ":" in edge.source else None
tgt_prefix = edge.target.split(":")[0] if ":" in edge.target else None
# exit subgraph if source or target is not in the same subgraph
if subgraph and (subgraph != src_prefix or subgraph != tgt_prefix):
mermaid_graph += "\tend\n"
subgraph = ""
# enter subgraph if source and target are in the same subgraph
if not subgraph and src_prefix and src_prefix == tgt_prefix:
mermaid_graph += f"\tsubgraph {src_prefix}\n"
subgraph = src_prefix
source, target = edge.source, edge.target
# Add BR every wrap_label_n_words words
if edge.data is not None:
edge_data = edge.data
words = str(edge_data).split() # Split the string into words
# Group words into chunks of wrap_label_n_words size
if len(words) > wrap_label_n_words:
edge_data = "&nbsp<br>&nbsp".join(
" ".join(words[i : i + wrap_label_n_words])
for i in range(0, len(words), wrap_label_n_words)
)
if edge.conditional:
edge_label = f" -. &nbsp{edge_data}&nbsp .-> "
else:
edge_label = f" -- &nbsp{edge_data}&nbsp --> "
else:
if edge.conditional:
edge_label = " -.-> "
else:
edge_label = " --> "
mermaid_graph += (
f"\t{_escape_node_label(source)}{edge_label}"
f"{_escape_node_label(target)};\n"
src_parts = edge.source.split(":")
tgt_parts = edge.target.split(":")
common_prefix = ":".join(
src for src, tgt in zip(src_parts, tgt_parts) if src == tgt
)
if subgraph:
mermaid_graph += "end\n"
edge_groups.setdefault(common_prefix, []).append(edge)
seen_subgraphs = set()
def add_subgraph(edges: List[Edge], prefix: str) -> None:
nonlocal mermaid_graph
self_loop = len(edges) == 1 and edges[0].source == edges[0].target
if prefix and not self_loop:
subgraph = prefix.split(":")[-1]
if subgraph in seen_subgraphs:
raise ValueError(
f"Found duplicate subgraph '{subgraph}' -- this likely means that "
"you're reusing a subgraph node with the same name. "
"Please adjust your graph to have subgraph nodes with unique names."
)
seen_subgraphs.add(subgraph)
mermaid_graph += f"\tsubgraph {subgraph}\n"
for edge in edges:
source, target = edge.source, edge.target
# Add BR every wrap_label_n_words words
if edge.data is not None:
edge_data = edge.data
words = str(edge_data).split() # Split the string into words
# Group words into chunks of wrap_label_n_words size
if len(words) > wrap_label_n_words:
edge_data = "&nbsp<br>&nbsp".join(
" ".join(words[i : i + wrap_label_n_words])
for i in range(0, len(words), wrap_label_n_words)
)
if edge.conditional:
edge_label = f" -. &nbsp{edge_data}&nbsp .-> "
else:
edge_label = f" -- &nbsp{edge_data}&nbsp --> "
else:
if edge.conditional:
edge_label = " -.-> "
else:
edge_label = " --> "
mermaid_graph += (
f"\t{_escape_node_label(source)}{edge_label}"
f"{_escape_node_label(target)};\n"
)
# Recursively add nested subgraphs
for nested_prefix in edge_groups.keys():
if not nested_prefix.startswith(prefix + ":") or nested_prefix == prefix:
continue
add_subgraph(edge_groups[nested_prefix], nested_prefix)
if prefix and not self_loop:
mermaid_graph += "\tend\n"
# Start with the top-level edges (no common prefix)
add_subgraph(edge_groups.get("", []), "")
# Add remaining subgraphs
for prefix in edge_groups.keys():
if ":" in prefix or prefix == "":
continue
add_subgraph(edge_groups[prefix], prefix)
# Add custom styles for nodes
if with_styles:
mermaid_graph += _generate_mermaid_graph_styles(node_styles)
mermaid_graph += _generate_mermaid_graph_styles(node_styles or NodeStyles())
return mermaid_graph
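The bucketing step can be illustrated in isolation; a sketch of how edges are grouped by the id segments their source and target share (edge ids are made up):

edges = [
    ("outer:inner:a", "outer:inner:b"),
    ("outer:x", "outer:y"),
    ("start", "outer:x"),
]
groups: dict = {}
for source, target in edges:
    src_parts, tgt_parts = source.split(":"), target.split(":")
    common = ":".join(s for s, t in zip(src_parts, tgt_parts) if s == t)
    groups.setdefault(common, []).append((source, target))
# groups == {"outer:inner": [...], "outer": [...], "": [("start", "outer:x")]}
# Each non-empty key is rendered as a `subgraph ... end` block; the "" group holds
# the top-level edges.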

View File

@@ -13,10 +13,10 @@ from typing import (
Union,
)
from pydantic import BaseModel
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.load.load import load
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import RunnableBranch
from langchain_core.runnables.base import Runnable, RunnableBindingBase, RunnableLambda
from langchain_core.runnables.passthrough import RunnablePassthrough
from langchain_core.runnables.utils import (
@@ -320,17 +320,22 @@ class RunnableWithMessageHistory(RunnableBindingBase):
history_chain = RunnablePassthrough.assign(
**{messages_key: history_chain}
).with_config(run_name="insert_history")
runnable_sync: Runnable = runnable.with_listeners(on_end=self._exit_history)
runnable_async: Runnable = runnable.with_alisteners(on_end=self._aexit_history)
def _call_runnable_sync(_input: Any) -> Runnable:
return runnable_sync
async def _call_runnable_async(_input: Any) -> Runnable:
return runnable_async
bound: Runnable = (
history_chain
| RunnableBranch(
(
RunnableLambda(
self._is_not_async, afunc=self._is_async
).with_config(run_name="RunnableWithMessageHistoryInAsyncMode"),
runnable.with_alisteners(on_end=self._aexit_history),
),
runnable.with_listeners(on_end=self._exit_history),
)
| RunnableLambda(
_call_runnable_sync,
_call_runnable_async,
).with_config(run_name="check_sync_or_async")
).with_config(run_name="RunnableWithMessageHistory")
if history_factory_config:
@@ -368,28 +373,25 @@ class RunnableWithMessageHistory(RunnableBindingBase):
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
super_schema = super().get_input_schema(config)
if super_schema.__custom_root_type__ or not super_schema.schema().get(
"properties"
):
from langchain_core.messages import BaseMessage
# TODO(0.3): Verify that this change was correct
# Not enough tests and unclear on why the previous implementation was
# necessary.
from langchain_core.messages import BaseMessage
fields: Dict = {}
if self.input_messages_key and self.history_messages_key:
fields[self.input_messages_key] = (
Union[str, BaseMessage, Sequence[BaseMessage]],
...,
)
elif self.input_messages_key:
fields[self.input_messages_key] = (Sequence[BaseMessage], ...)
else:
fields["__root__"] = (Sequence[BaseMessage], ...)
return create_model( # type: ignore[call-overload]
"RunnableWithChatHistoryInput",
**fields,
fields: Dict = {}
if self.input_messages_key and self.history_messages_key:
fields[self.input_messages_key] = (
Union[str, BaseMessage, Sequence[BaseMessage]],
...,
)
elif self.input_messages_key:
fields[self.input_messages_key] = (Sequence[BaseMessage], ...)
else:
return super_schema
fields["__root__"] = (Sequence[BaseMessage], ...)
return create_model( # type: ignore[call-overload]
"RunnableWithChatHistoryInput",
**fields,
)
def _is_not_async(self, *args: Sequence[Any], **kwargs: Dict[str, Any]) -> bool:
return False
@@ -429,7 +431,9 @@ class RunnableWithMessageHistory(RunnableBindingBase):
# This occurs for chat models - since we batch inputs
if isinstance(input_val[0], list):
if len(input_val) != 1:
raise ValueError()
raise ValueError(
f"Expected a single list of messages. Got {input_val}."
)
return input_val[0]
return list(input_val)
else:
@@ -468,7 +472,10 @@ class RunnableWithMessageHistory(RunnableBindingBase):
elif isinstance(output_val, (list, tuple)):
return list(output_val)
else:
raise ValueError()
raise ValueError(
f"Expected str, BaseMessage, List[BaseMessage], or Tuple[BaseMessage]. "
f"Got {output_val}."
)
def _enter_history(self, input: Any, config: RunnableConfig) -> List[BaseMessage]:
hist: BaseChatMessageHistory = config["configurable"]["message_history"]

View File

@@ -21,7 +21,8 @@ from typing import (
cast,
)
from langchain_core.pydantic_v1 import BaseModel
from pydantic import BaseModel, RootModel
from langchain_core.runnables.base import (
Other,
Runnable,
@@ -216,7 +217,7 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
Union[Runnable[Dict[str, Any], Any], Callable[[Dict[str, Any]], Any]],
],
],
) -> "RunnableAssign":
) -> RunnableAssign:
"""Merge the Dict input with the output produced by the mapping argument.
Args:
@@ -227,7 +228,7 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
A Runnable that merges the Dict input with the output produced by the
mapping argument.
"""
return RunnableAssign(RunnableParallel(kwargs))
return RunnableAssign(RunnableParallel[Dict[str, Any]](kwargs))
def invoke(
self, input: Other, config: Optional[RunnableConfig] = None, **kwargs: Any
@@ -419,7 +420,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
map_input_schema = self.mapper.get_input_schema(config)
if not map_input_schema.__custom_root_type__:
if not issubclass(map_input_schema, RootModel):
# ie. it's a dict
return map_input_schema
@@ -430,20 +431,22 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
) -> Type[BaseModel]:
map_input_schema = self.mapper.get_input_schema(config)
map_output_schema = self.mapper.get_output_schema(config)
if (
not map_input_schema.__custom_root_type__
and not map_output_schema.__custom_root_type__
if not issubclass(map_input_schema, RootModel) and not issubclass(
map_output_schema, RootModel
):
# ie. both are dicts
fields = {}
for name, field_info in map_input_schema.model_fields.items():
fields[name] = (field_info.annotation, field_info.default)
for name, field_info in map_output_schema.model_fields.items():
fields[name] = (field_info.annotation, field_info.default)
return create_model( # type: ignore[call-overload]
"RunnableAssignOutput",
**{
k: (v.type_, v.default)
for s in (map_input_schema, map_output_schema)
for k, v in s.__fields__.items()
},
**fields,
)
elif not map_output_schema.__custom_root_type__:
elif not issubclass(map_output_schema, RootModel):
# ie. only map output is a dict
# ie. input type is either unknown or inferred incorrectly
return map_output_schema

View File

@@ -147,7 +147,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
retry_state: RetryCallState,
) -> RunnableConfig:
attempt = retry_state.attempt_number
tag = "retry:attempt:{}".format(attempt) if attempt > 1 else None
tag = f"retry:attempt:{attempt}" if attempt > 1 else None
return patch_config(config, callbacks=run_manager.get_child(tag))
def _patch_config_list(
@@ -218,6 +218,8 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
def pending(iterable: List[U]) -> List[U]:
return [item for idx, item in enumerate(iterable) if idx not in results_map]
not_set: List[Output] = []
result = not_set
try:
for attempt in self._sync_retrying():
with attempt:
@@ -247,9 +249,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
):
attempt.retry_state.set_result(result)
except RetryError as e:
try:
result
except UnboundLocalError:
if result is not_set:
result = cast(List[Output], [e] * len(inputs))
outputs: List[Union[Output, Exception]] = []
@@ -284,6 +284,8 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
def pending(iterable: List[U]) -> List[U]:
return [item for idx, item in enumerate(iterable) if idx not in results_map]
not_set: List[Output] = []
result = not_set
try:
async for attempt in self._async_retrying():
with attempt:
@@ -313,9 +315,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
):
attempt.retry_state.set_result(result)
except RetryError as e:
try:
result
except UnboundLocalError:
if result is not_set:
result = cast(List[Output], [e] * len(inputs))
outputs: List[Union[Output, Exception]] = []

View File

@@ -12,6 +12,7 @@ from typing import (
cast,
)
from pydantic import ConfigDict
from typing_extensions import TypedDict
from langchain_core.runnables.base import (
@@ -83,8 +84,9 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]):
runnables={key: coerce_to_runnable(r) for key, r in runnables.items()}
)
class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
@classmethod
def is_lc_serializable(cls) -> bool:

View File

@@ -28,12 +28,18 @@ from typing import (
Type,
TypeVar,
Union,
cast,
)
from pydantic import BaseModel, ConfigDict, RootModel # pydantic: ignore
from pydantic import create_model as _create_model_base # pydantic: ignore
from pydantic.json_schema import (
DEFAULT_REF_TEMPLATE,
GenerateJsonSchema,
JsonSchemaMode,
)
from typing_extensions import TypeGuard
from langchain_core.pydantic_v1 import BaseConfig, BaseModel
from langchain_core.pydantic_v1 import create_model as _create_model_base
from langchain_core.runnables.schema import StreamEvent
Input = TypeVar("Input", contravariant=True)
@@ -350,7 +356,7 @@ def get_function_first_arg_dict_keys(func: Callable) -> Optional[List[str]]:
tree = ast.parse(textwrap.dedent(code))
visitor = IsFunctionArgDict()
visitor.visit(tree)
return list(visitor.keys) if visitor.keys else None
return sorted(visitor.keys) if visitor.keys else None
except (SyntaxError, TypeError, OSError, SystemError):
return None
@@ -393,7 +399,9 @@ def get_function_nonlocals(func: Callable) -> List[Any]:
visitor = FunctionNonLocals()
visitor.visit(tree)
values: List[Any] = []
for k, v in inspect.getclosurevars(func).nonlocals.items():
closure = inspect.getclosurevars(func)
candidates = {**closure.globals, **closure.nonlocals}
for k, v in candidates.items():
if k in visitor.nonlocals:
values.append(v)
for kk in visitor.nonlocals:
@@ -697,9 +705,57 @@ class _RootEventFilter:
return include
class _SchemaConfig(BaseConfig):
arbitrary_types_allowed = True
frozen = True
_SchemaConfig = ConfigDict(arbitrary_types_allowed=True, frozen=True)
NO_DEFAULT = object()
def create_base_class(
name: str, type_: Any, default_: object = NO_DEFAULT
) -> Type[BaseModel]:
"""Create a base class."""
def schema(
cls: Type[BaseModel],
by_alias: bool = True,
ref_template: str = DEFAULT_REF_TEMPLATE,
) -> Dict[str, Any]:
# Complains about schema not being defined in superclass
schema_ = super(cls, cls).schema( # type: ignore[misc]
by_alias=by_alias, ref_template=ref_template
)
schema_["title"] = name
return schema_
def model_json_schema(
cls: Type[BaseModel],
by_alias: bool = True,
ref_template: str = DEFAULT_REF_TEMPLATE,
schema_generator: type[GenerateJsonSchema] = GenerateJsonSchema,
mode: JsonSchemaMode = "validation",
) -> Dict[str, Any]:
# Complains about model_json_schema not being defined in superclass
schema_ = super(cls, cls).model_json_schema( # type: ignore[misc]
by_alias=by_alias,
ref_template=ref_template,
schema_generator=schema_generator,
mode=mode,
)
schema_["title"] = name
return schema_
base_class_attributes = {
"__annotations__": {"root": type_},
"model_config": ConfigDict(arbitrary_types_allowed=True),
"schema": classmethod(schema),
"model_json_schema": classmethod(model_json_schema),
"__module__": "langchain_core.runnables.utils",
}
if default_ is not NO_DEFAULT:
base_class_attributes["root"] = default_
custom_root_type = type(name, (RootModel,), base_class_attributes)
return cast(Type[BaseModel], custom_root_type)
def create_model(
@@ -715,6 +771,21 @@ def create_model(
Returns:
Type[BaseModel]: The created model.
"""
# Move this to caching path
if "__root__" in field_definitions:
if len(field_definitions) > 1:
raise NotImplementedError(
"When specifying __root__ no other "
f"fields should be provided. Got {field_definitions}"
)
arg = field_definitions["__root__"]
if isinstance(arg, tuple):
named_root_model = create_base_class(__model_name, arg[0], arg[1])
else:
named_root_model = create_base_class(__model_name, arg)
return named_root_model
try:
return _create_model_cached(__model_name, **field_definitions)
except TypeError:
@@ -748,7 +819,7 @@ def is_async_generator(
"""
return (
inspect.isasyncgenfunction(func)
or hasattr(func, "__call__")
or hasattr(func, "__call__") # noqa: B004
and inspect.isasyncgenfunction(func.__call__)
)
@@ -767,6 +838,6 @@ def is_async_callable(
"""
return (
asyncio.iscoroutinefunction(func)
or hasattr(func, "__call__")
or hasattr(func, "__call__") # noqa: B004
and asyncio.iscoroutinefunction(func.__call__)
)
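A sketch of what the __root__ branch of create_model produces, assuming it is imported from langchain_core.runnables.utils: a RootModel subclass whose schema title is forced to the requested name.

from typing import List

from langchain_core.runnables.utils import create_model

Output = create_model("SequenceOutput", __root__=List[str])
print(Output.model_json_schema()["title"])  # "SequenceOutput"
print(Output(root=["a", "b"]).root)         # ['a', 'b']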

View File

@@ -6,7 +6,7 @@ from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, List, Optional, Sequence, Union
from langchain_core.pydantic_v1 import BaseModel
from pydantic import BaseModel
class Visitor(ABC):
@@ -127,7 +127,8 @@ class Comparison(FilterDirective):
def __init__(
self, comparator: Comparator, attribute: str, value: Any, **kwargs: Any
) -> None:
super().__init__(
# super exists from BaseModel
super().__init__( # type: ignore[call-arg]
comparator=comparator, attribute=attribute, value=value, **kwargs
)
@@ -145,8 +146,11 @@ class Operation(FilterDirective):
def __init__(
self, operator: Operator, arguments: List[FilterDirective], **kwargs: Any
):
super().__init__(operator=operator, arguments=arguments, **kwargs)
) -> None:
# super exists from BaseModel
super().__init__( # type: ignore[call-arg]
operator=operator, arguments=arguments, **kwargs
)
class StructuredQuery(Expr):
@@ -165,5 +169,8 @@ class StructuredQuery(Expr):
filter: Optional[FilterDirective],
limit: Optional[int] = None,
**kwargs: Any,
):
super().__init__(query=query, filter=filter, limit=limit, **kwargs)
) -> None:
# super exists from BaseModel
super().__init__( # type: ignore[call-arg]
query=query, filter=filter, limit=limit, **kwargs
)

File diff suppressed because it is too large

View File

@@ -516,15 +516,19 @@ class _TracerCore(ABC):
def _end_trace(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
"""End a trace for a run."""
return None
def _on_run_create(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
"""Process a run upon creation."""
return None
def _on_run_update(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
"""Process a run upon update."""
return None
def _on_llm_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
"""Process the LLM Run upon start."""
return None
def _on_llm_new_token(
self,
@@ -533,39 +537,52 @@ class _TracerCore(ABC):
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]],
) -> Union[None, Coroutine[Any, Any, None]]:
"""Process new LLM token."""
return None
def _on_llm_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
"""Process the LLM Run."""
return None
def _on_llm_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
"""Process the LLM Run upon error."""
return None
def _on_chain_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
"""Process the Chain Run upon start."""
return None
def _on_chain_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
"""Process the Chain Run."""
return None
def _on_chain_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
"""Process the Chain Run upon error."""
return None
def _on_tool_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
"""Process the Tool Run upon start."""
return None
def _on_tool_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
"""Process the Tool Run."""
return None
def _on_tool_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
"""Process the Tool Run upon error."""
return None
def _on_chat_model_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
"""Process the Chat Model Run upon start."""
return None
def _on_retriever_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
"""Process the Retriever Run upon start."""
return None
def _on_retriever_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
"""Process the Retriever Run."""
return None
def _on_retriever_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
"""Process the Retriever Run upon error."""
return None

View File

@@ -144,11 +144,7 @@ class EvaluatorCallbackHandler(BaseTracer):
example_id = str(run.reference_example_id)
with self.lock:
for res in eval_results:
run_id = (
str(getattr(res, "target_run_id"))
if hasattr(res, "target_run_id")
else str(run.id)
)
run_id = str(getattr(res, "target_run_id", run.id))
self.logged_eval_results.setdefault((run_id, example_id), []).append(
res
)
@@ -179,11 +175,9 @@ class EvaluatorCallbackHandler(BaseTracer):
source_info_: Dict[str, Any] = {}
if res.evaluator_info:
source_info_ = {**res.evaluator_info, **source_info_}
run_id_ = (
getattr(res, "target_run_id")
if hasattr(res, "target_run_id") and res.target_run_id is not None
else run.id
)
run_id_ = getattr(res, "target_run_id", None)
if run_id_ is None:
run_id_ = run.id
self.client.create_feedback(
run_id_,
res.key,

View File

@@ -11,22 +11,22 @@ from langsmith.schemas import RunBase as BaseRunV2
from langsmith.schemas import RunTypeEnum as RunTypeEnumDep
from langchain_core._api import deprecated
from langchain_core.outputs import LLMResult
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
@deprecated("0.1.0", alternative="Use string instead.", removal="0.3.0")
@deprecated("0.1.0", alternative="Use string instead.", removal="1.0")
def RunTypeEnum() -> Type[RunTypeEnumDep]:
"""RunTypeEnum."""
warnings.warn(
"RunTypeEnum is deprecated. Please directly use a string instead"
" (e.g. 'llm', 'chain', 'tool').",
DeprecationWarning,
stacklevel=2,
)
return RunTypeEnumDep
@deprecated("0.1.0", removal="0.3.0")
@deprecated("0.1.0", removal="1.0")
class TracerSessionV1Base(BaseModel):
"""Base class for TracerSessionV1."""
@@ -35,33 +35,33 @@ class TracerSessionV1Base(BaseModel):
extra: Optional[Dict[str, Any]] = None
@deprecated("0.1.0", removal="0.3.0")
@deprecated("0.1.0", removal="1.0")
class TracerSessionV1Create(TracerSessionV1Base):
"""Create class for TracerSessionV1."""
@deprecated("0.1.0", removal="0.3.0")
@deprecated("0.1.0", removal="1.0")
class TracerSessionV1(TracerSessionV1Base):
"""TracerSessionV1 schema."""
id: int
@deprecated("0.1.0", removal="0.3.0")
@deprecated("0.1.0", removal="1.0")
class TracerSessionBase(TracerSessionV1Base):
"""Base class for TracerSession."""
tenant_id: UUID
@deprecated("0.1.0", removal="0.3.0")
@deprecated("0.1.0", removal="1.0")
class TracerSession(TracerSessionBase):
"""TracerSessionV1 schema for the V2 API."""
id: UUID
@deprecated("0.1.0", alternative="Run", removal="0.3.0")
@deprecated("0.1.0", alternative="Run", removal="1.0")
class BaseRun(BaseModel):
"""Base class for Run."""
@@ -77,15 +77,16 @@ class BaseRun(BaseModel):
error: Optional[str] = None
@deprecated("0.1.0", alternative="Run", removal="0.3.0")
@deprecated("0.1.0", alternative="Run", removal="1.0")
class LLMRun(BaseRun):
"""Class for LLMRun."""
prompts: List[str]
response: Optional[LLMResult] = None
# Temporarily remove this, but we will completely remove LLMRun later
# response: Optional[LLMResult] = None
@deprecated("0.1.0", alternative="Run", removal="0.3.0")
@deprecated("0.1.0", alternative="Run", removal="1.0")
class ChainRun(BaseRun):
"""Class for ChainRun."""
@@ -96,7 +97,7 @@ class ChainRun(BaseRun):
child_tool_runs: List[ToolRun] = Field(default_factory=list)
@deprecated("0.1.0", alternative="Run", removal="0.3.0")
@deprecated("0.1.0", alternative="Run", removal="1.0")
class ToolRun(BaseRun):
"""Class for ToolRun."""

View File

@@ -22,10 +22,12 @@ from langchain_core.utils.utils import (
build_extra_kwargs,
check_package_version,
convert_to_secret_str,
from_env,
get_pydantic_field_names,
guard_import,
mock_now,
raise_for_status_with_text,
secret_from_env,
xor_args,
)
@@ -54,4 +56,6 @@ __all__ = [
"pre_init",
"batch_iterate",
"abatch_iterate",
"from_env",
"secret_from_env",
]

View File

@@ -41,6 +41,19 @@ def merge_dicts(left: Dict[str, Any], *others: Dict[str, Any]) -> Dict[str, Any]
" but with a different type."
)
elif isinstance(merged[right_k], str):
# TODO: Add below special handling for 'type' key in 0.3 and remove
# merge_lists 'type' logic.
#
# if right_k == "type":
# if merged[right_k] == right_v:
# continue
# else:
# raise ValueError(
# "Unable to merge. Two different values seen for special "
# f"key 'type': {merged[right_k]} and {right_v}. 'type' "
# "should either occur once or have the same value across "
# "all dicts."
# )
merged[right_k] += right_v
elif isinstance(merged[right_k], dict):
merged[right_k] = merge_dicts(merged[right_k], right_v)
@@ -81,10 +94,10 @@ def merge_lists(left: Optional[List], *others: Optional[List]) -> Optional[List]
if e_left["index"] == e["index"]
]
if to_merge:
# If a top-level "type" has been set for a chunk, it should no
# longer be overridden by the "type" field in future chunks.
if "type" in merged[to_merge[0]] and "type" in e:
e.pop("type")
# TODO: Remove this once merge_dict is updated with special
# handling for 'type'.
if "type" in e:
e = {k: v for k, v in e.items() if k != "type"}
merged[to_merge[0]] = merge_dicts(merged[to_merge[0]], e)
else:
merged.append(e)
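For reference, a sketch of the behavior these helpers implement, assuming the private module path langchain_core.utils._merge is unchanged: string values are concatenated, nested dicts are merged recursively, and list elements sharing an "index" are merged into a single chunk.

from langchain_core.utils._merge import merge_dicts, merge_lists

merge_dicts({"content": "Hello, "}, {"content": "world"})
# -> {"content": "Hello, world"}

merge_lists(
    [{"index": 0, "args": '{"a"', "type": "tool_call_chunk"}],
    [{"index": 0, "args": ': 1}', "type": "tool_call_chunk"}],
)
# -> [{"index": 0, "args": '{"a": 1}', "type": "tool_call_chunk"}]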

View File

@@ -62,8 +62,8 @@ def py_anext(
__anext__ = cast(
Callable[[AsyncIterator[T]], Awaitable[T]], type(iterator).__anext__
)
except AttributeError:
raise TypeError(f"{iterator!r} is not an async iterator")
except AttributeError as e:
raise TypeError(f"{iterator!r} is not an async iterator") from e
if default is _no_default:
return __anext__(iterator)

View File

@@ -5,6 +5,7 @@ from __future__ import annotations
import collections
import inspect
import logging
import types
import typing
import uuid
from typing import (
@@ -22,11 +23,11 @@ from typing import (
cast,
)
from pydantic import BaseModel
from typing_extensions import Annotated, TypedDict, get_args, get_origin, is_typeddict
from langchain_core._api import deprecated
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
from langchain_core.pydantic_v1 import BaseModel, Field, create_model
from langchain_core.utils.json_schema import dereference_refs
from langchain_core.utils.pydantic import is_basemodel_subclass
@@ -81,10 +82,10 @@ def _rm_titles(kv: dict, prev_key: str = "") -> dict:
@deprecated(
"0.1.16",
alternative="langchain_core.utils.function_calling.convert_to_openai_function()",
removal="0.3.0",
removal="1.0",
)
def convert_pydantic_to_openai_function(
model: Type[BaseModel],
model: Type,
*,
name: Optional[str] = None,
description: Optional[str] = None,
@@ -108,7 +109,10 @@ def convert_pydantic_to_openai_function(
else:
schema = model.schema() # Pydantic 1
schema = dereference_refs(schema)
schema.pop("definitions", None)
if "definitions" in schema: # pydantic 1
schema.pop("definitions", None)
if "$defs" in schema: # pydantic 2
schema.pop("$defs", None)
title = schema.pop("title", "")
default_description = schema.pop("description", "")
return {
@@ -121,7 +125,7 @@ def convert_pydantic_to_openai_function(
@deprecated(
"0.1.16",
alternative="langchain_core.utils.function_calling.convert_to_openai_tool()",
removal="0.3.0",
removal="1.0",
)
def convert_pydantic_to_openai_tool(
model: Type[BaseModel],
@@ -155,7 +159,7 @@ def _get_python_function_name(function: Callable) -> str:
@deprecated(
"0.1.16",
alternative="langchain_core.utils.function_calling.convert_to_openai_function()",
removal="0.3.0",
removal="1.0",
)
def convert_python_function_to_openai_function(
function: Callable,
@@ -172,10 +176,10 @@ def convert_python_function_to_openai_function(
Returns:
The OpenAI function description.
"""
from langchain_core import tools
from langchain_core.tools.base import create_schema_from_function
func_name = _get_python_function_name(function)
model = tools.create_schema_from_function(
model = create_schema_from_function(
func_name,
function,
filter_args=(),
@@ -192,11 +196,13 @@ def convert_python_function_to_openai_function(
def _convert_typed_dict_to_openai_function(typed_dict: Type) -> FunctionDescription:
visited: Dict = {}
from pydantic.v1 import BaseModel # pydantic: ignore
model = cast(
Type[BaseModel],
_convert_any_typed_dicts_to_pydantic(typed_dict, visited=visited),
)
return convert_pydantic_to_openai_function(model)
return convert_pydantic_to_openai_function(model) # type: ignore
_MAX_TYPED_DICT_RECURSION = 25
@@ -208,6 +214,9 @@ def _convert_any_typed_dicts_to_pydantic(
visited: Dict,
depth: int = 0,
) -> Type:
from pydantic.v1 import Field as Field_v1 # pydantic: ignore
from pydantic.v1 import create_model as create_model_v1 # pydantic: ignore
if type_ in visited:
return visited[type_]
elif depth >= _MAX_TYPED_DICT_RECURSION:
@@ -241,7 +250,7 @@ def _convert_any_typed_dicts_to_pydantic(
field_kwargs["description"] = arg_desc
else:
pass
fields[arg] = (new_arg_type, Field(**field_kwargs))
fields[arg] = (new_arg_type, Field_v1(**field_kwargs))
else:
new_arg_type = _convert_any_typed_dicts_to_pydantic(
arg_type, depth=depth + 1, visited=visited
@@ -249,8 +258,8 @@ def _convert_any_typed_dicts_to_pydantic(
field_kwargs = {"default": ...}
if arg_desc := arg_descriptions.get(arg):
field_kwargs["description"] = arg_desc
fields[arg] = (new_arg_type, Field(**field_kwargs))
model = create_model(typed_dict.__name__, **fields)
fields[arg] = (new_arg_type, Field_v1(**field_kwargs))
model = create_model_v1(typed_dict.__name__, **fields)
model.__doc__ = description
visited[typed_dict] = model
return model
@@ -268,7 +277,7 @@ def _convert_any_typed_dicts_to_pydantic(
@deprecated(
"0.1.16",
alternative="langchain_core.utils.function_calling.convert_to_openai_function()",
removal="0.3.0",
removal="1.0",
)
def format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
"""Format tool into the OpenAI function API.
@@ -305,7 +314,7 @@ def format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
@deprecated(
"0.1.16",
alternative="langchain_core.utils.function_calling.convert_to_openai_tool()",
removal="0.3.0",
removal="1.0",
)
def format_tool_to_openai_tool(tool: BaseTool) -> ToolDescription:
"""Format tool into the OpenAI function API.
@@ -327,15 +336,23 @@ def convert_to_openai_function(
) -> Dict[str, Any]:
"""Convert a raw function/class to an OpenAI function.
.. versionchanged:: 0.2.29
``strict`` arg added.
Args:
function: A dictionary, Pydantic BaseModel class, TypedDict class, a LangChain
function:
A dictionary, Pydantic BaseModel class, TypedDict class, a LangChain
Tool object, or a Python function. If a dictionary is passed in, it is
assumed to already be a valid OpenAI function or a JSON schema with
top-level 'title' and 'description' keys specified.
strict: If True, model output is guaranteed to exactly match the JSON Schema
strict:
If True, model output is guaranteed to exactly match the JSON Schema
provided in the function definition. If None, ``strict`` argument will not
be included in function definition.
.. versionadded:: 0.2.29
Returns:
A dict version of the passed in function which is compatible with the OpenAI
function-calling API.
@@ -380,9 +397,13 @@ def convert_to_openai_function(
if strict is not None:
oai_function["strict"] = strict
# As of 08/06/24, OpenAI requires that additionalProperties be supplied and set
# to False if strict is True.
oai_function["parameters"]["additionalProperties"] = False
if strict:
# As of 08/06/24, OpenAI requires that additionalProperties be supplied and
# set to False if strict is True.
# Every 'properties' layer needs 'additionalProperties=False'
oai_function["parameters"] = _recursive_set_additional_properties_false(
oai_function["parameters"]
)
return oai_function
@@ -393,18 +414,26 @@ def convert_to_openai_tool(
) -> Dict[str, Any]:
"""Convert a raw function/class to an OpenAI tool.
.. versionchanged:: 0.2.29
``strict`` arg added.
Args:
tool: Either a dictionary, a pydantic.BaseModel class, Python function, or
tool:
Either a dictionary, a pydantic.BaseModel class, Python function, or
BaseTool. If a dictionary is passed in, it is assumed to already be a valid
OpenAI tool, OpenAI function, or a JSON schema with top-level 'title' and
'description' keys specified.
strict: If True, model output is guaranteed to exactly match the JSON Schema
strict:
If True, model output is guaranteed to exactly match the JSON Schema
provided in the function definition. If None, ``strict`` argument will not
be included in tool definition.
.. versionadded:: 0.2.29
Returns:
A dict version of the passed in tool which is compatible with the
OpenAI tool-calling API.
OpenAI tool-calling API.
"""
if isinstance(tool, dict) and tool.get("type") == "function" and "function" in tool:
return tool
@@ -559,6 +588,10 @@ def _parse_google_docstring(
def _py_38_safe_origin(origin: Type) -> Type:
origin_union_type_map: Dict[Type, Any] = (
{types.UnionType: Union} if hasattr(types, "UnionType") else {}
)
origin_map: Dict[Type, Any] = {
dict: Dict,
list: List,
@@ -568,5 +601,23 @@ def _py_38_safe_origin(origin: Type) -> Type:
collections.abc.Mapping: typing.Mapping,
collections.abc.Sequence: typing.Sequence,
collections.abc.MutableMapping: typing.MutableMapping,
**origin_union_type_map,
}
return cast(Type, origin_map.get(origin, origin))
def _recursive_set_additional_properties_false(
schema: Dict[str, Any],
) -> Dict[str, Any]:
if isinstance(schema, dict):
# Check if 'required' is a key at the current level
if "required" in schema:
schema["additionalProperties"] = False
# Recursively check 'properties' and 'items' if they exist
if "properties" in schema:
for value in schema["properties"].values():
_recursive_set_additional_properties_false(value)
if "items" in schema:
_recursive_set_additional_properties_false(schema["items"])
return schema
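The net effect of the two changes above is that ``strict`` now closes off every object level of the generated schema, not only the top one. A minimal sketch of the expected behaviour, under the assumption that nothing beyond what the diff shows is needed (the ``Address``/``Person`` models are invented for illustration):

.. code-block:: python

    from pydantic import BaseModel

    from langchain_core.utils.function_calling import convert_to_openai_function

    class Address(BaseModel):
        """An address."""
        city: str

    class Person(BaseModel):
        """A person with a nested address."""
        name: str
        address: Address

    fn = convert_to_openai_function(Person, strict=True)
    assert fn["strict"] is True
    # The top-level parameters object is closed off...
    assert fn["parameters"]["additionalProperties"] is False
    # ...and _recursive_set_additional_properties_false walks 'properties' and
    # 'items', so nested object schemas that declare 'required' are treated
    # the same way.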

View File

@@ -182,7 +182,7 @@ def parse_and_check_json_markdown(text: str, expected_keys: List[str]) -> dict:
try:
json_obj = parse_json_markdown(text)
except json.JSONDecodeError as e:
raise OutputParserException(f"Got invalid JSON object. Error: {e}")
raise OutputParserException(f"Got invalid JSON object. Error: {e}") from e
for key in expected_keys:
if key not in json_obj:
raise OutputParserException(

View File

@@ -8,7 +8,7 @@ from langchain_core._api.deprecation import deprecated
@deprecated(
since="0.1.30",
removal="0.3",
removal="1.0",
message=(
"Using the hwchase17/langchain-hub "
"repo for prompts is deprecated. Please use "
@@ -21,7 +21,9 @@ def try_load_from_hub(
) -> Any:
warnings.warn(
"Loading from the deprecated github-based Hub is no longer supported. "
"Please use the new LangChain Hub at https://smith.langchain.com/hub instead."
"Please use the new LangChain Hub at https://smith.langchain.com/hub instead.",
DeprecationWarning,
stacklevel=2,
)
# return None, which indicates that we shouldn't load from old hub
# and might just be a filepath for e.g. load_chain

View File

@@ -3,13 +3,17 @@ Adapted from https://github.com/noahmorrison/chevron
MIT License
"""
from __future__ import annotations
import logging
from types import MappingProxyType
from typing import (
Any,
Dict,
Iterator,
List,
Literal,
Mapping,
Optional,
Sequence,
Tuple,
@@ -22,7 +26,7 @@ from typing_extensions import TypeAlias
logger = logging.getLogger(__name__)
Scopes: TypeAlias = List[Union[Literal[False, 0], Dict[str, Any]]]
Scopes: TypeAlias = List[Union[Literal[False, 0], Mapping[str, Any]]]
# Globals
@@ -152,8 +156,8 @@ def parse_tag(template: str, l_del: str, r_del: str) -> Tuple[Tuple[str, str], s
# Get the tag
try:
tag, template = template.split(r_del, 1)
except ValueError:
raise ChevronError("unclosed tag " "at line {0}".format(_CURRENT_LINE))
except ValueError as e:
raise ChevronError("unclosed tag " f"at line {_CURRENT_LINE}") from e
# Find the type meaning of the first character
tag_type = tag_types.get(tag[0], "variable")
@@ -174,7 +178,7 @@ def parse_tag(template: str, l_del: str, r_del: str) -> Tuple[Tuple[str, str], s
# Otherwise we should complain
else:
raise ChevronError(
"unclosed set delimiter tag\n" "at line {0}".format(_CURRENT_LINE)
"unclosed set delimiter tag\n" f"at line {_CURRENT_LINE}"
)
# If we might be a no html escape tag
@@ -279,18 +283,18 @@ def tokenize(
# is the same as us
try:
last_section = open_sections.pop()
except IndexError:
except IndexError as e:
raise ChevronError(
'Trying to close tag "{0}"\n'
f'Trying to close tag "{tag_key}"\n'
"Looks like it was not opened.\n"
"line {1}".format(tag_key, _CURRENT_LINE + 1)
)
f"line {_CURRENT_LINE + 1}"
) from e
if tag_key != last_section:
# Otherwise we need to complain
raise ChevronError(
'Trying to close tag "{0}"\n'
'last open tag is "{1}"\n'
"line {2}".format(tag_key, last_section, _CURRENT_LINE + 1)
f'Trying to close tag "{tag_key}"\n'
f'last open tag is "{last_section}"\n'
f"line {_CURRENT_LINE + 1}"
)
# Do the second check to see if we're a standalone
@@ -320,8 +324,8 @@ def tokenize(
# Then we need to complain
raise ChevronError(
"Unexpected EOF\n"
'the tag "{0}" was never closed\n'
"was opened at line {1}".format(open_sections[-1], _LAST_TAG_LINE)
f'the tag "{open_sections[-1]}" was never closed\n'
f"was opened at line {_LAST_TAG_LINE}"
)
@@ -403,15 +407,15 @@ def _get_key(
# We couldn't find the key in any of the scopes
if warn:
logger.warn("Could not find key '%s'" % (key))
logger.warn(f"Could not find key '{key}'")
if keep:
return "%s %s %s" % (def_ldel, key, def_rdel)
return f"{def_ldel} {key} {def_rdel}"
return ""
def _get_partial(name: str, partials_dict: Dict[str, str]) -> str:
def _get_partial(name: str, partials_dict: Mapping[str, str]) -> str:
"""Load a partial"""
try:
# Maybe the partial is in the dictionary
@@ -425,11 +429,13 @@ def _get_partial(name: str, partials_dict: Dict[str, str]) -> str:
#
g_token_cache: Dict[str, List[Tuple[str, str]]] = {}
EMPTY_DICT: MappingProxyType[str, str] = MappingProxyType({})
def render(
template: Union[str, List[Tuple[str, str]]] = "",
data: Dict[str, Any] = {},
partials_dict: Dict[str, str] = {},
data: Mapping[str, Any] = EMPTY_DICT,
partials_dict: Mapping[str, str] = EMPTY_DICT,
padding: str = "",
def_ldel: str = "{{",
def_rdel: str = "}}",
@@ -565,9 +571,9 @@ def render(
if tag_type == "literal":
text += tag_key
elif tag_type == "no escape":
text += "%s& %s %s" % (def_ldel, tag_key, def_rdel)
text += f"{def_ldel}& {tag_key} {def_rdel}"
else:
text += "%s%s %s%s" % (
text += "{}{} {}{}".format(
def_ldel,
{
"comment": "!",

View File

@@ -5,11 +5,12 @@ from __future__ import annotations
import inspect
import textwrap
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union, overload
import pydantic # pydantic: ignore
from langchain_core.pydantic_v1 import BaseModel, root_validator
from pydantic import BaseModel, root_validator # pydantic: ignore
from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue # pydantic: ignore
from pydantic_core import core_schema # pydantic: ignore
def get_pydantic_major_version() -> int:
@@ -26,9 +27,13 @@ PYDANTIC_MAJOR_VERSION = get_pydantic_major_version()
if PYDANTIC_MAJOR_VERSION == 1:
from pydantic.fields import FieldInfo as FieldInfoV1
PydanticBaseModel = pydantic.BaseModel
TypeBaseModel = Type[BaseModel]
elif PYDANTIC_MAJOR_VERSION == 2:
from pydantic.v1.fields import FieldInfo as FieldInfoV1 # type: ignore[assignment]
# Union type needs to be last assignment to PydanticBaseModel to make mypy happy.
PydanticBaseModel = Union[BaseModel, pydantic.BaseModel] # type: ignore
TypeBaseModel = Union[Type[BaseModel], Type[pydantic.BaseModel]] # type: ignore
@@ -142,7 +147,7 @@ def pre_init(func: Callable) -> Any:
Dict[str, Any]: The values to initialize the model with.
"""
# Insert default values
fields = cls.__fields__
fields = cls.model_fields
for name, field_info in fields.items():
# Check if allow_population_by_field_name is enabled
# If yes, then set the field name to the alias
@@ -151,9 +156,13 @@ def pre_init(func: Callable) -> Any:
if cls.Config.allow_population_by_field_name:
if field_info.alias in values:
values[name] = values.pop(field_info.alias)
if hasattr(cls, "model_config"):
if cls.model_config.get("populate_by_name"):
if field_info.alias in values:
values[name] = values.pop(field_info.alias)
if name not in values or values[name] is None:
if not field_info.required:
if not field_info.is_required():
if field_info.default_factory is not None:
values[name] = field_info.default_factory()
else:
@@ -165,6 +174,44 @@ def pre_init(func: Callable) -> Any:
return wrapper
class _IgnoreUnserializable(GenerateJsonSchema):
"""A JSON schema generator that ignores unknown types.
https://docs.pydantic.dev/latest/concepts/json_schema/#customizing-the-json-schema-generation-process
"""
def handle_invalid_for_json_schema(
self, schema: core_schema.CoreSchema, error_info: str
) -> JsonSchemaValue:
return {}
def v1_repr(obj: BaseModel) -> str:
"""Return the schema of the object as a string.
Get a repr for the pydantic object which is consistent with pydantic.v1.
"""
if not is_basemodel_instance(obj):
raise TypeError(f"Expected a pydantic BaseModel, got {type(obj)}")
repr_ = []
for name, field in get_fields(obj).items():
value = getattr(obj, name)
if isinstance(value, BaseModel):
repr_.append(f"{name}={v1_repr(value)}")
else:
if not field.is_required():
if not value:
continue
if field.default == value:
continue
repr_.append(f"{name}={repr(value)}")
args = ", ".join(repr_)
return f"{obj.__class__.__name__}({args})"
def _create_subset_model_v1(
name: str,
model: Type[BaseModel],
@@ -174,12 +221,20 @@ def _create_subset_model_v1(
fn_description: Optional[str] = None,
) -> Type[BaseModel]:
"""Create a pydantic model with only a subset of model's fields."""
from langchain_core.pydantic_v1 import create_model
if PYDANTIC_MAJOR_VERSION == 1:
from pydantic import create_model # pydantic: ignore
elif PYDANTIC_MAJOR_VERSION == 2:
from pydantic.v1 import create_model # type: ignore # pydantic: ignore
else:
raise NotImplementedError(
f"Unsupported pydantic version: {PYDANTIC_MAJOR_VERSION}"
)
fields = {}
for field_name in field_names:
field = model.__fields__[field_name]
# Using pydantic v1 so can access __fields__ as a dict.
field = model.__fields__[field_name] # type: ignore
t = (
# this isn't perfect but should work for most functions
field.outer_type_
@@ -266,3 +321,48 @@ def _create_subset_model(
raise NotImplementedError(
f"Unsupported pydantic version: {PYDANTIC_MAJOR_VERSION}"
)
if PYDANTIC_MAJOR_VERSION == 2:
from pydantic import BaseModel as BaseModelV2
from pydantic.fields import FieldInfo as FieldInfoV2
from pydantic.v1 import BaseModel as BaseModelV1
@overload
def get_fields(model: Type[BaseModelV2]) -> Dict[str, FieldInfoV2]: ...
@overload
def get_fields(model: BaseModelV2) -> Dict[str, FieldInfoV2]: ...
@overload
def get_fields(model: Type[BaseModelV1]) -> Dict[str, FieldInfoV1]: ...
@overload
def get_fields(model: BaseModelV1) -> Dict[str, FieldInfoV1]: ...
def get_fields(
model: Union[
BaseModelV2,
BaseModelV1,
Type[BaseModelV2],
Type[BaseModelV1],
],
) -> Union[Dict[str, FieldInfoV2], Dict[str, FieldInfoV1]]:
"""Get the field names of a Pydantic model."""
if hasattr(model, "model_fields"):
return model.model_fields # type: ignore
elif hasattr(model, "__fields__"):
return model.__fields__ # type: ignore
else:
raise TypeError(f"Expected a Pydantic model. Got {type(model)}")
elif PYDANTIC_MAJOR_VERSION == 1:
from pydantic import BaseModel as BaseModelV1_
def get_fields( # type: ignore[no-redef]
model: Union[Type[BaseModelV1_], BaseModelV1_],
) -> Dict[str, FieldInfoV1]:
"""Get the field names of a Pydantic model."""
return model.__fields__ # type: ignore
else:
raise ValueError(f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}")

View File

@@ -4,14 +4,15 @@ import contextlib
import datetime
import functools
import importlib
import os
import warnings
from importlib.metadata import version
from typing import Any, Callable, Dict, Optional, Set, Tuple, Union
from typing import Any, Callable, Dict, Optional, Sequence, Set, Tuple, Union, overload
from packaging.version import parse
from pydantic import SecretStr
from requests import HTTPError, Response
from langchain_core.pydantic_v1 import SecretStr
from langchain_core.utils.pydantic import (
is_pydantic_v1_subclass,
)
@@ -130,12 +131,12 @@ def guard_import(
"""
try:
module = importlib.import_module(module_name, package)
except (ImportError, ModuleNotFoundError):
except (ImportError, ModuleNotFoundError) as e:
pip_name = pip_name or module_name.split(".")[0].replace("_", "-")
raise ImportError(
f"Could not import {module_name} python package. "
f"Please install it with `pip install {pip_name}`."
)
) from e
return module
@@ -234,7 +235,8 @@ def build_extra_kwargs(
warnings.warn(
f"""WARNING! {field_name} is not default parameter.
{field_name} was transferred to model_kwargs.
Please confirm that {field_name} is what you intended."""
Please confirm that {field_name} is what you intended.""",
stacklevel=7,
)
extra_kwargs[field_name] = values.pop(field_name)
@@ -260,3 +262,155 @@ def convert_to_secret_str(value: Union[SecretStr, str]) -> SecretStr:
if isinstance(value, SecretStr):
return value
return SecretStr(value)
class _NoDefaultType:
"""Type to indicate no default value is provided."""
pass
_NoDefault = _NoDefaultType()
@overload
def from_env(key: str, /) -> Callable[[], str]: ...
@overload
def from_env(key: str, /, *, default: str) -> Callable[[], str]: ...
@overload
def from_env(key: Sequence[str], /, *, default: str) -> Callable[[], str]: ...
@overload
def from_env(key: str, /, *, error_message: str) -> Callable[[], str]: ...
@overload
def from_env(
key: Union[str, Sequence[str]], /, *, default: str, error_message: Optional[str]
) -> Callable[[], str]: ...
@overload
def from_env(
key: str, /, *, default: None, error_message: Optional[str]
) -> Callable[[], Optional[str]]: ...
@overload
def from_env(
key: Union[str, Sequence[str]], /, *, default: None
) -> Callable[[], Optional[str]]: ...
def from_env(
key: Union[str, Sequence[str]],
/,
*,
default: Union[str, _NoDefaultType, None] = _NoDefault,
error_message: Optional[str] = None,
) -> Union[Callable[[], str], Callable[[], Optional[str]]]:
"""Create a factory method that gets a value from an environment variable.
Args:
key: The environment variable to look up. If a list of keys is provided,
the first key found in the environment will be used.
If no key is found, the default value will be used if set,
otherwise an error will be raised.
default: The default value to return if the environment variable is not set.
error_message: the error message which will be raised if the key is not found
and no default value is provided.
This will be raised as a ValueError.
"""
def get_from_env_fn() -> Optional[str]:
"""Get a value from an environment variable."""
if isinstance(key, (list, tuple)):
for k in key:
if k in os.environ:
return os.environ[k]
if isinstance(key, str):
if key in os.environ:
return os.environ[key]
if isinstance(default, (str, type(None))):
return default
else:
if error_message:
raise ValueError(error_message)
else:
raise ValueError(
f"Did not find {key}, please add an environment variable"
f" `{key}` which contains it, or pass"
f" `{key}` as a named parameter."
)
return get_from_env_fn
@overload
def secret_from_env(key: str, /) -> Callable[[], SecretStr]: ...
@overload
def secret_from_env(key: str, /, *, default: str) -> Callable[[], SecretStr]: ...
@overload
def secret_from_env(
key: Union[str, Sequence[str]], /, *, default: None
) -> Callable[[], Optional[SecretStr]]: ...
@overload
def secret_from_env(key: str, /, *, error_message: str) -> Callable[[], SecretStr]: ...
def secret_from_env(
key: Union[str, Sequence[str]],
/,
*,
default: Union[str, _NoDefaultType, None] = _NoDefault,
error_message: Optional[str] = None,
) -> Union[Callable[[], Optional[SecretStr]], Callable[[], SecretStr]]:
"""Secret from env.
Args:
key: The environment variable to look up.
default: The default value to return if the environment variable is not set.
error_message: the error message which will be raised if the key is not found
and no default value is provided.
This will be raised as a ValueError.
Returns:
factory method that will look up the secret from the environment.
"""
def get_secret_from_env() -> Optional[SecretStr]:
"""Get a value from an environment variable."""
if isinstance(key, (list, tuple)):
for k in key:
if k in os.environ:
return SecretStr(os.environ[k])
if isinstance(key, str):
if key in os.environ:
return SecretStr(os.environ[key])
if isinstance(default, str):
return SecretStr(default)
elif isinstance(default, type(None)):
return None
else:
if error_message:
raise ValueError(error_message)
else:
raise ValueError(
f"Did not find {key}, please add an environment variable"
f" `{key}` which contains it, or pass"
f" `{key}` as a named parameter."
)
return get_secret_from_env
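Both factories are typically used as lazy ``default_factory`` callables, so the environment is read when the model is instantiated rather than when the class is defined. A hedged sketch of that pattern (the ``MY_API_*`` variable names and the ``ClientSettings`` model are invented, and the import path assumes these helpers live in ``langchain_core.utils.utils`` as shown here):

.. code-block:: python

    import os

    from pydantic import BaseModel, Field, SecretStr

    from langchain_core.utils.utils import from_env, secret_from_env

    os.environ["MY_API_BASE"] = "https://api.example.com"
    os.environ["MY_API_KEY"] = "sk-not-a-real-key"

    class ClientSettings(BaseModel):
        # Resolved from the environment at instantiation time; the string
        # default is used only if MY_API_BASE is unset.
        api_base: str = Field(
            default_factory=from_env("MY_API_BASE", default="http://localhost")
        )
        # Wrapped in SecretStr so the raw value is not echoed in reprs.
        api_key: SecretStr = Field(default_factory=secret_from_env("MY_API_KEY"))

    settings = ClientSettings()
    assert settings.api_base == "https://api.example.com"
    assert settings.api_key.get_secret_value() == "sk-not-a-real-key"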

View File

@@ -29,30 +29,24 @@ from itertools import cycle
from typing import (
TYPE_CHECKING,
Any,
AsyncIterable,
AsyncIterator,
Callable,
ClassVar,
Collection,
Dict,
Iterable,
Iterator,
List,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
)
from langchain_core._api import beta
from pydantic import ConfigDict, Field, model_validator
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.retrievers import BaseRetriever
from langchain_core.retrievers import BaseRetriever, LangSmithRetrieverParams
from langchain_core.runnables.config import run_in_executor
from langchain_core.utils.aiter import abatch_iterate
from langchain_core.utils.iter import batch_iterate
if TYPE_CHECKING:
from langchain_core.callbacks.manager import (
@@ -60,7 +54,6 @@ if TYPE_CHECKING:
CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.indexing import UpsertResponse
logger = logging.getLogger(__name__)
@@ -96,7 +89,7 @@ class VectorStore(ABC):
ValueError: If the number of metadatas does not match the number of texts.
ValueError: If the number of ids does not match the number of texts.
"""
if type(self).upsert != VectorStore.upsert:
if type(self).add_documents != VectorStore.add_documents:
# Import document in local scope to avoid circular imports
from langchain_core.documents import Document
@@ -109,190 +102,19 @@ class VectorStore(ABC):
if metadatas and len(metadatas) != len(texts_):
raise ValueError(
"The number of metadatas must match the number of texts."
"Got {len(metadatas)} metadatas and {len(texts_)} texts."
f"Got {len(metadatas)} metadatas and {len(texts_)} texts."
)
if "ids" in kwargs:
ids = kwargs.pop("ids")
if ids and len(ids) != len(texts_):
raise ValueError(
"The number of ids must match the number of texts."
"Got {len(ids)} ids and {len(texts_)} texts."
)
else:
ids = None
metadatas_ = iter(metadatas) if metadatas else cycle([{}])
ids_: Iterable[Union[str, None]] = ids if ids is not None else cycle([None])
docs = [
Document(page_content=text, metadata=metadata_, id=id_)
for text, metadata_, id_ in zip(texts, metadatas_, ids_)
Document(page_content=text, metadata=metadata_)
for text, metadata_ in zip(texts, metadatas_)
]
upsert_response = self.upsert(docs, **kwargs)
return upsert_response["succeeded"]
return self.add_documents(docs, **kwargs)
raise NotImplementedError(
f"`add_texts` has not been implemented for {self.__class__.__name__} "
)
# Developer guidelines:
# Do not override streaming_upsert!
@beta(message="Added in 0.2.11. The API is subject to change.")
def streaming_upsert(
self, items: Iterable[Document], /, batch_size: int, **kwargs: Any
) -> Iterator[UpsertResponse]:
"""Upsert documents in a streaming fashion.
Args:
items: Iterable of Documents to add to the vectorstore.
batch_size: The size of each batch to upsert.
kwargs: Additional keyword arguments.
kwargs should only include parameters that are common to all
documents. (e.g., timeout for indexing, retry policy, etc.)
kwargs should not include ids to avoid ambiguous semantics.
Instead, the ID should be provided as part of the Document object.
Yields:
UpsertResponse: A response object that contains the list of IDs that were
successfully added or updated in the vectorstore and the list of IDs that
failed to be added or updated.
.. versionadded:: 0.2.11
"""
# The default implementation of this method breaks the input into
# batches of size `batch_size` and calls the `upsert` method on each batch.
# Subclasses can override this method to provide a more efficient
# implementation.
for item_batch in batch_iterate(batch_size, items):
yield self.upsert(item_batch, **kwargs)
# Please note that we've added a new method `upsert` instead of re-using the
# existing `add_documents` method.
# This was done to resolve potential ambiguities around the behavior of **kwargs
# in existing add_documents / add_texts methods which could include per document
# information (e.g., the `ids` parameter).
# Over time the `add_documents` could be denoted as legacy and deprecated
# in favor of the `upsert` method.
@beta(message="Added in 0.2.11. The API is subject to change.")
def upsert(self, items: Sequence[Document], /, **kwargs: Any) -> UpsertResponse:
"""Add or update documents in the vectorstore.
The upsert functionality should utilize the ID field of the Document object
if it is provided. If the ID is not provided, the upsert method is free
to generate an ID for the document.
When an ID is specified and the document already exists in the vectorstore,
the upsert method should update the document with the new data. If the document
does not exist, the upsert method should add the document to the vectorstore.
Args:
items: Sequence of Documents to add to the vectorstore.
kwargs: Additional keyword arguments.
Returns:
UpsertResponse: A response object that contains the list of IDs that were
successfully added or updated in the vectorstore and the list of IDs that
failed to be added or updated.
.. versionadded:: 0.2.11
"""
# Developer guidelines:
#
# Vectorstores implementations are free to extend `upsert` implementation
# to take in additional data per document.
#
# This data **SHOULD NOT** be part of the **kwargs** parameter, instead
# sub-classes can use a Union type on `documents` to include additional
# supported formats for the input data stream.
#
# For example,
#
# .. code-block:: python
# from typing import TypedDict
#
# class DocumentWithVector(TypedDict):
# document: Document
# vector: List[float]
#
# def upsert(
# self,
# documents: Union[Iterable[Document], Iterable[DocumentWithVector]],
# /,
# **kwargs
# ) -> UpsertResponse:
# \"\"\"Add or update documents in the vectorstore.\"\"\"
# # Implementation should check if documents is an
# # iterable of DocumentWithVector or Document
# pass
#
# Implementations that override upsert should include a new doc-string
# that explains the semantics of upsert and includes in code
# examples of how to insert using the alternate data formats.
# The implementation does not delegate to the `add_texts` method or
# the `add_documents` method by default since those implementations
raise NotImplementedError(
f"upsert has not been implemented for {self.__class__.__name__}"
)
@beta(message="Added in 0.2.11. The API is subject to change.")
async def astreaming_upsert(
self,
items: AsyncIterable[Document],
/,
batch_size: int,
**kwargs: Any,
) -> AsyncIterator[UpsertResponse]:
"""Upsert documents in a streaming fashion. Async version of streaming_upsert.
Args:
items: Iterable of Documents to add to the vectorstore.
batch_size: The size of each batch to upsert.
kwargs: Additional keyword arguments.
kwargs should only include parameters that are common to all
documents. (e.g., timeout for indexing, retry policy, etc.)
kwargs should not include ids to avoid ambiguous semantics.
Instead the ID should be provided as part of the Document object.
Yields:
UpsertResponse: A response object that contains the list of IDs that were
successfully added or updated in the vectorstore and the list of IDs that
failed to be added or updated.
.. versionadded:: 0.2.11
"""
async for batch in abatch_iterate(batch_size, items):
yield await self.aupsert(batch, **kwargs)
@beta(message="Added in 0.2.11. The API is subject to change.")
async def aupsert(
self, items: Sequence[Document], /, **kwargs: Any
) -> UpsertResponse:
"""Add or update documents in the vectorstore. Async version of upsert.
The upsert functionality should utilize the ID field of the Document object
if it is provided. If the ID is not provided, the upsert method is free
to generate an ID for the document.
When an ID is specified and the document already exists in the vectorstore,
the upsert method should update the document with the new data. If the document
does not exist, the upsert method should add the document to the vectorstore.
Args:
items: Sequence of Documents to add to the vectorstore.
kwargs: Additional keyword arguments.
Returns:
UpsertResponse: A response object that contains the list of IDs that were
successfully added or updated in the vectorstore and the list of IDs that
failed to be added or updated.
.. versionadded:: 0.2.11
"""
# Developer guidelines: See guidelines for the `upsert` method.
# The implementation does not delegate to the `add_texts` method or
# the `add_documents` method by default since those implementations
return await run_in_executor(None, self.upsert, items, **kwargs)
@property
def embeddings(self) -> Optional[Embeddings]:
"""Access the query embedding object if available."""
@@ -407,7 +229,7 @@ class VectorStore(ABC):
ValueError: If the number of metadatas does not match the number of texts.
ValueError: If the number of ids does not match the number of texts.
"""
if type(self).aupsert != VectorStore.aupsert:
if type(self).aadd_documents != VectorStore.aadd_documents:
# Import document in local scope to avoid circular imports
from langchain_core.documents import Document
@@ -420,27 +242,16 @@ class VectorStore(ABC):
if metadatas and len(metadatas) != len(texts_):
raise ValueError(
"The number of metadatas must match the number of texts."
"Got {len(metadatas)} metadatas and {len(texts_)} texts."
f"Got {len(metadatas)} metadatas and {len(texts_)} texts."
)
if "ids" in kwargs:
ids = kwargs.pop("ids")
if ids and len(ids) != len(texts_):
raise ValueError(
"The number of ids must match the number of texts."
"Got {len(ids)} ids and {len(texts_)} texts."
)
else:
ids = None
metadatas_ = iter(metadatas) if metadatas else cycle([{}])
ids_: Iterable[Union[str, None]] = ids if ids is not None else cycle([None])
docs = [
Document(page_content=text, metadata=metadata_, id=id_)
for text, metadata_, id_ in zip(texts, metadatas_, ids_)
Document(page_content=text, metadata=metadata_)
for text, metadata_ in zip(texts, metadatas_)
]
upsert_response = await self.aupsert(docs, **kwargs)
return upsert_response["succeeded"]
return await self.aadd_documents(docs, **kwargs)
return await run_in_executor(None, self.add_texts, texts, metadatas, **kwargs)
def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
@@ -458,37 +269,22 @@ class VectorStore(ABC):
Raises:
ValueError: If the number of ids does not match the number of documents.
"""
if type(self).upsert != VectorStore.upsert:
from langchain_core.documents import Document
if type(self).add_texts != VectorStore.add_texts:
if "ids" not in kwargs:
ids = [doc.id for doc in documents]
if "ids" in kwargs:
ids = kwargs.pop("ids")
if ids and len(ids) != len(documents):
raise ValueError(
"The number of ids must match the number of documents. "
"Got {len(ids)} ids and {len(documents)} documents."
)
# If there's at least one valid ID, we'll assume that IDs
# should be used.
if any(ids):
kwargs["ids"] = ids
documents_ = []
for id_, document in zip(ids, documents):
doc_with_id = Document(
page_content=document.page_content,
metadata=document.metadata,
id=id_,
)
documents_.append(doc_with_id)
else:
documents_ = documents
# If upsert has been implemented, we can use it to add documents
return self.upsert(documents_, **kwargs)["succeeded"]
# Code path that delegates to add_text for backwards compatibility
# TODO: Handle the case where the user doesn't provide ids on the Collection
texts = [doc.page_content for doc in documents]
metadatas = [doc.metadata for doc in documents]
return self.add_texts(texts, metadatas, **kwargs)
texts = [doc.page_content for doc in documents]
metadatas = [doc.metadata for doc in documents]
return self.add_texts(texts, metadatas, **kwargs)
raise NotImplementedError(
f"`add_documents` and `add_texts` has not been implemented "
f"for {self.__class__.__name__} "
)
async def aadd_documents(
self, documents: List[Document], **kwargs: Any
@@ -506,41 +302,21 @@ class VectorStore(ABC):
Raises:
ValueError: If the number of IDs does not match the number of documents.
"""
# If either upsert or aupsert has been implemented, we delegate to them!
if (
type(self).aupsert != VectorStore.aupsert
or type(self).upsert != VectorStore.upsert
):
# If aupsert has been implemented, we can use it to add documents
from langchain_core.documents import Document
# If the async method has been overridden, we'll use that.
if type(self).aadd_texts != VectorStore.aadd_texts:
if "ids" not in kwargs:
ids = [doc.id for doc in documents]
if "ids" in kwargs:
ids = kwargs.pop("ids")
if ids and len(ids) != len(documents):
raise ValueError(
"The number of ids must match the number of documents."
"Got {len(ids)} ids and {len(documents)} documents."
)
# If there's at least one valid ID, we'll assume that IDs
# should be used.
if any(ids):
kwargs["ids"] = ids
documents_ = []
texts = [doc.page_content for doc in documents]
metadatas = [doc.metadata for doc in documents]
return await self.aadd_texts(texts, metadatas, **kwargs)
for id_, document in zip(ids, documents):
doc_with_id = Document(
page_content=document.page_content,
metadata=document.metadata,
id=id_,
)
documents_.append(doc_with_id)
else:
documents_ = documents
# The default implementation of aupsert delegates to upsert.
upsert_response = await self.aupsert(documents_, **kwargs)
return upsert_response["succeeded"]
texts = [doc.page_content for doc in documents]
metadatas = [doc.metadata for doc in documents]
return await self.aadd_texts(texts, metadatas, **kwargs)
return await run_in_executor(None, self.add_documents, documents, **kwargs)
def search(self, query: str, search_type: str, **kwargs: Any) -> List[Document]:
"""Return docs most similar to query using a specified search type.
@@ -783,7 +559,8 @@ class VectorStore(ABC):
):
warnings.warn(
"Relevance scores must be between"
f" 0 and 1, got {docs_and_similarities}"
f" 0 and 1, got {docs_and_similarities}",
stacklevel=2,
)
if score_threshold is not None:
@@ -793,7 +570,7 @@ class VectorStore(ABC):
if similarity >= score_threshold
]
if len(docs_and_similarities) == 0:
warnings.warn(
logger.warning(
"No relevant docs were retrieved using the relevance score"
f" threshold {score_threshold}"
)
@@ -830,7 +607,8 @@ class VectorStore(ABC):
):
warnings.warn(
"Relevance scores must be between"
f" 0 and 1, got {docs_and_similarities}"
f" 0 and 1, got {docs_and_similarities}",
stacklevel=2,
)
if score_threshold is not None:
@@ -840,7 +618,7 @@ class VectorStore(ABC):
if similarity >= score_threshold
]
if len(docs_and_similarities) == 0:
warnings.warn(
logger.warning(
"No relevant docs were retrieved using the relevance score"
f" threshold {score_threshold}"
)
@@ -1207,11 +985,13 @@ class VectorStoreRetriever(BaseRetriever):
"mmr",
)
class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
@root_validator(pre=True)
def validate_search_type(cls, values: Dict) -> Dict:
@model_validator(mode="before")
@classmethod
def validate_search_type(cls, values: Dict) -> Any:
"""Validate search type.
Args:
@@ -1239,6 +1019,25 @@ class VectorStoreRetriever(BaseRetriever):
)
return values
def _get_ls_params(self, **kwargs: Any) -> LangSmithRetrieverParams:
"""Get standard params for tracing."""
ls_params = super()._get_ls_params(**kwargs)
ls_params["ls_vector_store_provider"] = self.vectorstore.__class__.__name__
if self.vectorstore.embeddings:
ls_params["ls_embedding_provider"] = (
self.vectorstore.embeddings.__class__.__name__
)
elif hasattr(self.vectorstore, "embedding") and isinstance(
self.vectorstore.embedding, Embeddings
):
ls_params["ls_embedding_provider"] = (
self.vectorstore.embedding.__class__.__name__
)
return ls_params
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
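The large removal above reverses the earlier ``upsert``-centric design: ``add_documents``/``aadd_documents`` are the extension points again, ``add_texts`` wraps raw strings into ``Document`` objects and delegates to an overridden ``add_documents``, and the base ``add_documents`` forwards ``Document.id`` values as ``ids`` when no explicit ``ids`` kwarg is given. A rough sketch of the resulting contract (``ToyStore`` and its trivial search are invented for illustration):

.. code-block:: python

    from typing import Any, List, Optional

    from langchain_core.documents import Document
    from langchain_core.embeddings import Embeddings
    from langchain_core.vectorstores import VectorStore

    class ToyStore(VectorStore):
        """Stores documents in a dict; only add_documents does real work."""

        def __init__(self) -> None:
            self.docs: dict = {}

        def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
            ids = []
            for doc in documents:
                doc_id = doc.id or f"doc-{len(self.docs)}"
                self.docs[doc_id] = doc
                ids.append(doc_id)
            return ids

        def similarity_search(self, query: str, k: int = 4, **kwargs: Any) -> List[Document]:
            return list(self.docs.values())[:k]

        @classmethod
        def from_texts(cls, texts: List[str], embedding: Embeddings,
                       metadatas: Optional[List[dict]] = None, **kwargs: Any) -> "ToyStore":
            store = cls()
            store.add_texts(texts, metadatas)
            return store

    store = ToyStore()
    # add_texts wraps the strings into Documents and, because add_documents
    # is overridden, routes through it instead of raising NotImplementedError.
    assert store.add_texts(["hello", "world"]) == ["doc-0", "doc-1"]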

View File

@@ -8,30 +8,142 @@ from typing import (
Any,
Callable,
Dict,
Iterator,
List,
Optional,
Sequence,
Tuple,
)
from langchain_core._api import deprecated
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.load import dumpd, load
from langchain_core.vectorstores import VectorStore
from langchain_core.vectorstores.utils import _cosine_similarity as cosine_similarity
from langchain_core.vectorstores.utils import (
_maximal_marginal_relevance as maximal_marginal_relevance,
)
from langchain_core.vectorstores.utils import maximal_marginal_relevance
if TYPE_CHECKING:
from langchain_core.indexing import UpsertResponse
class InMemoryVectorStore(VectorStore):
"""In-memory implementation of VectorStore using a dictionary.
"""In-memory vector store implementation.
Uses numpy to compute cosine similarity for search.
"""
Uses a dictionary, and computes cosine similarity for search using numpy.
Setup:
Install ``langchain-core``.
.. code-block:: bash
pip install -U langchain-core
Key init args — indexing params:
embedding_function: Embeddings
Embedding function to use.
Instantiate:
.. code-block:: python
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_openai import OpenAIEmbeddings
vector_store = InMemoryVectorStore(OpenAIEmbeddings())
Add Documents:
.. code-block:: python
from langchain_core.documents import Document
document_1 = Document(id="1", page_content="foo", metadata={"baz": "bar"})
document_2 = Document(id="2", page_content="thud", metadata={"bar": "baz"})
document_3 = Document(id="3", page_content="i will be deleted :(")
documents = [document_1, document_2, document_3]
vector_store.add_documents(documents=documents)
Delete Documents:
.. code-block:: python
vector_store.delete(ids=["3"])
Search:
.. code-block:: python
results = vector_store.similarity_search(query="thud",k=1)
for doc in results:
print(f"* {doc.page_content} [{doc.metadata}]")
.. code-block:: none
* thud [{'bar': 'baz'}]
Search with filter:
.. code-block:: python
def _filter_function(doc: Document) -> bool:
return doc.metadata.get("bar") == "baz"
results = vector_store.similarity_search(
query="thud", k=1, filter=_filter_function
)
for doc in results:
print(f"* {doc.page_content} [{doc.metadata}]")
.. code-block:: none
* thud [{'bar': 'baz'}]
Search with score:
.. code-block:: python
results = vector_store.similarity_search_with_score(
query="qux", k=1
)
for doc, score in results:
print(f"* [SIM={score:3f}] {doc.page_content} [{doc.metadata}]")
.. code-block:: none
* [SIM=0.832268] foo [{'baz': 'bar'}]
Async:
.. code-block:: python
# add documents
# await vector_store.aadd_documents(documents=documents)
# delete documents
# await vector_store.adelete(ids=["3"])
# search
# results = vector_store.asimilarity_search(query="thud", k=1)
# search with score
results = await vector_store.asimilarity_search_with_score(query="qux", k=1)
for doc,score in results:
print(f"* [SIM={score:3f}] {doc.page_content} [{doc.metadata}]")
.. code-block:: none
* [SIM=0.832268] foo [{'baz': 'bar'}]
Use as Retriever:
.. code-block:: python
retriever = vector_store.as_retriever(
search_type="mmr",
search_kwargs={"k": 1, "fetch_k": 2, "lambda_mult": 0.5},
)
retriever.invoke("thud")
.. code-block:: none
[Document(id='2', metadata={'bar': 'baz'}, page_content='thud')]
""" # noqa: E501
def __init__(self, embedding: Embeddings) -> None:
"""Initialize with the given embedding function.
@@ -56,43 +168,71 @@ class InMemoryVectorStore(VectorStore):
async def adelete(self, ids: Optional[Sequence[str]] = None, **kwargs: Any) -> None:
self.delete(ids)
def upsert(self, items: Sequence[Document], /, **kwargs: Any) -> UpsertResponse:
vectors = self.embedding.embed_documents([item.page_content for item in items])
ids = []
for item, vector in zip(items, vectors):
doc_id = item.id if item.id else str(uuid.uuid4())
ids.append(doc_id)
self.store[doc_id] = {
"id": doc_id,
"vector": vector,
"text": item.page_content,
"metadata": item.metadata,
}
return {
"succeeded": ids,
"failed": [],
}
def add_documents(
self,
documents: List[Document],
ids: Optional[List[str]] = None,
**kwargs: Any,
) -> List[str]:
"""Add documents to the store."""
texts = [doc.page_content for doc in documents]
vectors = self.embedding.embed_documents(texts)
async def aupsert(
self, items: Sequence[Document], /, **kwargs: Any
) -> UpsertResponse:
vectors = await self.embedding.aembed_documents(
[item.page_content for item in items]
if ids and len(ids) != len(texts):
raise ValueError(
f"ids must be the same length as texts. "
f"Got {len(ids)} ids and {len(texts)} texts."
)
id_iterator: Iterator[Optional[str]] = (
iter(ids) if ids else iter(doc.id for doc in documents)
)
ids = []
for item, vector in zip(items, vectors):
doc_id = item.id if item.id else str(uuid.uuid4())
ids.append(doc_id)
self.store[doc_id] = {
"id": doc_id,
ids_ = []
for doc, vector in zip(documents, vectors):
doc_id = next(id_iterator)
doc_id_ = doc_id if doc_id else str(uuid.uuid4())
ids_.append(doc_id_)
self.store[doc_id_] = {
"id": doc_id_,
"vector": vector,
"text": item.page_content,
"metadata": item.metadata,
"text": doc.page_content,
"metadata": doc.metadata,
}
return {
"succeeded": ids,
"failed": [],
}
return ids_
async def aadd_documents(
self, documents: List[Document], ids: Optional[List[str]] = None, **kwargs: Any
) -> List[str]:
"""Add documents to the store."""
texts = [doc.page_content for doc in documents]
vectors = await self.embedding.aembed_documents(texts)
if ids and len(ids) != len(texts):
raise ValueError(
f"ids must be the same length as texts. "
f"Got {len(ids)} ids and {len(texts)} texts."
)
id_iterator: Iterator[Optional[str]] = (
iter(ids) if ids else iter(doc.id for doc in documents)
)
ids_: List[str] = []
for doc, vector in zip(documents, vectors):
doc_id = next(id_iterator)
doc_id_ = doc_id if doc_id else str(uuid.uuid4())
ids_.append(doc_id_)
self.store[doc_id_] = {
"id": doc_id_,
"vector": vector,
"text": doc.page_content,
"metadata": doc.metadata,
}
return ids_
def get_by_ids(self, ids: Sequence[str], /) -> List[Document]:
"""Get documents by their ids.
@@ -117,6 +257,62 @@ class InMemoryVectorStore(VectorStore):
)
return documents
@deprecated(
alternative="VectorStore.add_documents",
message=(
"This was a beta API that was added in 0.2.11. "
"It'll be removed in 0.3.0."
),
since="0.2.29",
removal="1.0",
)
def upsert(self, items: Sequence[Document], /, **kwargs: Any) -> UpsertResponse:
vectors = self.embedding.embed_documents([item.page_content for item in items])
ids = []
for item, vector in zip(items, vectors):
doc_id = item.id if item.id else str(uuid.uuid4())
ids.append(doc_id)
self.store[doc_id] = {
"id": doc_id,
"vector": vector,
"text": item.page_content,
"metadata": item.metadata,
}
return {
"succeeded": ids,
"failed": [],
}
@deprecated(
alternative="VectorStore.aadd_documents",
message=(
"This was a beta API that was added in 0.2.11. "
"It'll be removed in 0.3.0."
),
since="0.2.29",
removal="1.0",
)
async def aupsert(
self, items: Sequence[Document], /, **kwargs: Any
) -> UpsertResponse:
vectors = await self.embedding.aembed_documents(
[item.page_content for item in items]
)
ids = []
for item, vector in zip(items, vectors):
doc_id = item.id if item.id else str(uuid.uuid4())
ids.append(doc_id)
self.store[doc_id] = {
"id": doc_id,
"vector": vector,
"text": item.page_content,
"metadata": item.metadata,
}
return {
"succeeded": ids,
"failed": [],
}
async def aget_by_ids(self, ids: Sequence[str], /) -> List[Document]:
"""Async get documents by their ids.
@@ -239,11 +435,11 @@ class InMemoryVectorStore(VectorStore):
try:
import numpy as np
except ImportError:
except ImportError as e:
raise ImportError(
"numpy must be installed to use max_marginal_relevance_search "
"pip install numpy"
)
) from e
mmr_chosen_indices = maximal_marginal_relevance(
np.array(embedding, dtype=np.float32),
@@ -294,7 +490,7 @@ class InMemoryVectorStore(VectorStore):
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
**kwargs: Any,
) -> "InMemoryVectorStore":
) -> InMemoryVectorStore:
store = cls(
embedding=embedding,
)
@@ -308,7 +504,7 @@ class InMemoryVectorStore(VectorStore):
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
**kwargs: Any,
) -> "InMemoryVectorStore":
) -> InMemoryVectorStore:
store = cls(
embedding=embedding,
)
@@ -318,7 +514,7 @@ class InMemoryVectorStore(VectorStore):
@classmethod
def load(
cls, path: str, embedding: Embeddings, **kwargs: Any
) -> "InMemoryVectorStore":
) -> InMemoryVectorStore:
"""Load a vector store from a file.
Args:
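With the deprecations above, ``add_documents``/``aadd_documents`` replace ``upsert``/``aupsert`` as the supported write path, and IDs can come either from an ``ids`` argument or from ``Document.id``. A brief sketch, assuming the deterministic fake embeddings that ship with ``langchain-core`` are available (the ids and texts are made up):

.. code-block:: python

    from langchain_core.documents import Document
    from langchain_core.embeddings import DeterministicFakeEmbedding
    from langchain_core.vectorstores import InMemoryVectorStore

    store = InMemoryVectorStore(embedding=DeterministicFakeEmbedding(size=8))

    docs = [
        Document(page_content="foo", metadata={"group": "a"}),
        Document(page_content="bar", metadata={"group": "b"}),
    ]
    # Explicit ids win; otherwise Document.id (or a fresh uuid4) is used,
    # mirroring the add_documents implementation shown earlier.
    ids = store.add_documents(docs, ids=["id-1", "id-2"])
    assert ids == ["id-1", "id-2"]
    assert store.get_by_ids(["id-1"])[0].page_content == "foo"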

View File

@@ -34,11 +34,11 @@ def _cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
"""
try:
import numpy as np
except ImportError:
except ImportError as e:
raise ImportError(
"cosine_similarity requires numpy to be installed. "
"Please install numpy with `pip install numpy`."
)
) from e
if len(X) == 0 or len(Y) == 0:
return np.array([])
@@ -51,7 +51,7 @@ def _cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
f"and Y has shape {Y.shape}."
)
try:
import simsimd as simd # type: ignore
import simsimd as simd # type: ignore[import-not-found]
X = np.array(X, dtype=np.float32)
Y = np.array(Y, dtype=np.float32)
@@ -71,7 +71,7 @@ def _cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
return similarity
def _maximal_marginal_relevance(
def maximal_marginal_relevance(
query_embedding: np.ndarray,
embedding_list: list,
lambda_mult: float = 0.5,
@@ -93,11 +93,11 @@ def _maximal_marginal_relevance(
"""
try:
import numpy as np
except ImportError:
except ImportError as e:
raise ImportError(
"maximal_marginal_relevance requires numpy to be installed. "
"Please install numpy with `pip install numpy`."
)
) from e
if min(k, len(embedding_list)) <= 0:
return []
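Since ``maximal_marginal_relevance`` loses its leading underscore here, it can be exercised directly. A toy sketch of the relevance-versus-redundancy trade-off it computes, with hand-picked 2-d vectors (the numbers are arbitrary):

.. code-block:: python

    import numpy as np

    from langchain_core.vectorstores.utils import maximal_marginal_relevance

    query = np.array([1.0, 0.0], dtype=np.float32)
    candidates = [
        [1.0, 0.0],    # duplicate of the query
        [0.95, 0.05],  # near-duplicate of the first candidate
        [0.0, 1.0],    # orthogonal, i.e. maximally "diverse"
    ]
    # The first pick is always the most query-similar vector; with a
    # lambda_mult below 0.5 the redundancy penalty dominates, so the
    # orthogonal vector is expected to beat the near-duplicate next.
    picks = maximal_marginal_relevance(query, candidates, lambda_mult=0.4, k=2)
    assert picks[0] == 0 and len(picks) == 2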

libs/core/poetry.lock (generated, 1168 changed lines): diff suppressed because it is too large.

View File

@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry]
name = "langchain-core"
version = "0.2.29rc1"
version = "0.2.37"
description = "Building applications with LLMs through composability"
authors = []
license = "MIT"
@@ -41,7 +41,8 @@ python = ">=3.12.4"
[tool.poetry.extras]
[tool.ruff.lint]
select = [ "E", "F", "I", "T201",]
select = [ "B", "E", "F", "I", "T201", "UP",]
ignore = [ "UP006", "UP007",]
[tool.coverage.run]
omit = [ "tests/*",]
@@ -79,6 +80,7 @@ mypy = ">=1.10,<1.11"
types-pyyaml = "^6.0.12.2"
types-requests = "^2.28.11.5"
types-jinja2 = "^2.11.9"
simsimd = "^5.0.0"
[tool.poetry.group.dev.dependencies]
jupyter = "^1.0.0"

View File

@@ -3,9 +3,9 @@ import warnings
from typing import Any, Dict
import pytest
from pydantic import BaseModel
from langchain_core._api.beta_decorator import beta, warn_beta
from langchain_core.pydantic_v1 import BaseModel
@pytest.mark.parametrize(

View File

@@ -3,9 +3,13 @@ import warnings
from typing import Any, Dict
import pytest
from pydantic import BaseModel
from langchain_core._api.deprecation import deprecated, warn_deprecated
from langchain_core.pydantic_v1 import BaseModel
from langchain_core._api.deprecation import (
deprecated,
rename_parameter,
warn_deprecated,
)
@pytest.mark.parametrize(
@@ -412,3 +416,84 @@ def test_raise_error_for_bad_decorator() -> None:
def deprecated_function() -> str:
"""original doc"""
return "This is a deprecated function."
def test_rename_parameter() -> None:
"""Test rename parameter."""
@rename_parameter(since="2.0.0", removal="3.0.0", old="old_name", new="new_name")
def foo(new_name: str) -> str:
"""original doc"""
return new_name
with warnings.catch_warnings(record=True) as warning_list:
warnings.simplefilter("always")
assert foo(old_name="hello") == "hello" # type: ignore[call-arg]
assert len(warning_list) == 1
assert foo(new_name="hello") == "hello"
assert foo("hello") == "hello"
assert foo.__doc__ == "original doc"
with pytest.raises(TypeError):
foo(meow="hello") # type: ignore[call-arg]
with pytest.raises(TypeError):
assert foo("hello", old_name="hello") # type: ignore[call-arg]
with pytest.raises(TypeError):
assert foo(old_name="goodbye", new_name="hello") # type: ignore[call-arg]
async def test_rename_parameter_for_async_func() -> None:
"""Test rename parameter."""
@rename_parameter(since="2.0.0", removal="3.0.0", old="old_name", new="new_name")
async def foo(new_name: str) -> str:
"""original doc"""
return new_name
with warnings.catch_warnings(record=True) as warning_list:
warnings.simplefilter("always")
assert await foo(old_name="hello") == "hello" # type: ignore[call-arg]
assert len(warning_list) == 1
assert await foo(new_name="hello") == "hello"
assert await foo("hello") == "hello"
assert foo.__doc__ == "original doc"
with pytest.raises(TypeError):
await foo(meow="hello") # type: ignore[call-arg]
with pytest.raises(TypeError):
assert await foo("hello", old_name="hello") # type: ignore[call-arg]
with pytest.raises(TypeError):
assert await foo(old_name="a", new_name="hello") # type: ignore[call-arg]
def test_rename_parameter_method() -> None:
"""Test that it works for a method."""
class Foo:
@rename_parameter(
since="2.0.0", removal="3.0.0", old="old_name", new="new_name"
)
def a(self, new_name: str) -> str:
return new_name
foo = Foo()
with warnings.catch_warnings(record=True) as warning_list:
warnings.simplefilter("always")
assert foo.a(old_name="hello") == "hello" # type: ignore[call-arg]
assert len(warning_list) == 1
assert str(warning_list[0].message) == (
"The parameter `old_name` of `a` was deprecated in 2.0.0 and will be "
"removed "
"in 3.0.0 Use `new_name` instead."
)
assert foo.a(new_name="hello") == "hello"
assert foo.a("hello") == "hello"
with pytest.raises(TypeError):
foo.a(meow="hello") # type: ignore[call-arg]
with pytest.raises(TypeError):
assert foo.a("hello", old_name="hello") # type: ignore[call-arg]

View File

@@ -4,9 +4,10 @@ from itertools import chain
from typing import Any, Dict, List, Optional, Union
from uuid import UUID
from pydantic import BaseModel
from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
from langchain_core.messages import BaseMessage
from langchain_core.pydantic_v1 import BaseModel
class BaseFakeCallbackHandler(BaseModel):
@@ -256,7 +257,8 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any:
self.on_retriever_error_common()
def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler":
# Overriding since BaseModel has __deepcopy__ method as well
def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler": # type: ignore
return self
@@ -390,5 +392,6 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None:
self.on_text_common()
def __deepcopy__(self, memo: dict) -> "FakeAsyncCallbackHandler":
# Overriding since BaseModel has __deepcopy__ method as well
def __deepcopy__(self, memo: dict) -> "FakeAsyncCallbackHandler": # type: ignore
return self

View File

@@ -9,7 +9,6 @@ from langchain_core.language_models import GenericFakeChatModel, ParrotFakeChatM
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
from tests.unit_tests.stubs import (
AnyStr,
_AnyIdAIMessage,
_AnyIdAIMessageChunk,
_AnyIdHumanMessage,
@@ -70,8 +69,8 @@ async def test_generic_fake_chat_model_stream() -> None:
model = GenericFakeChatModel(messages=cycle([message]))
chunks = [chunk async for chunk in model.astream("meow")]
assert chunks == [
AIMessageChunk(content="", additional_kwargs={"foo": 42}, id=AnyStr()),
AIMessageChunk(content="", additional_kwargs={"bar": 24}, id=AnyStr()),
_AnyIdAIMessageChunk(content="", additional_kwargs={"foo": 42}),
_AnyIdAIMessageChunk(content="", additional_kwargs={"bar": 24}),
]
assert len({chunk.id for chunk in chunks}) == 1
@@ -89,29 +88,23 @@ async def test_generic_fake_chat_model_stream() -> None:
chunks = [chunk async for chunk in model.astream("meow")]
assert chunks == [
AIMessageChunk(
content="",
additional_kwargs={"function_call": {"name": "move_file"}},
id=AnyStr(),
_AnyIdAIMessageChunk(
content="", additional_kwargs={"function_call": {"name": "move_file"}}
),
AIMessageChunk(
_AnyIdAIMessageChunk(
content="",
additional_kwargs={
"function_call": {"arguments": '{\n "source_path": "foo"'},
},
id=AnyStr(),
),
AIMessageChunk(
content="",
additional_kwargs={"function_call": {"arguments": ","}},
id=AnyStr(),
_AnyIdAIMessageChunk(
content="", additional_kwargs={"function_call": {"arguments": ","}}
),
AIMessageChunk(
_AnyIdAIMessageChunk(
content="",
additional_kwargs={
"function_call": {"arguments": '\n "destination_path": "bar"\n}'},
},
id=AnyStr(),
),
]
assert len({chunk.id for chunk in chunks}) == 1

View File

@@ -1,12 +1,13 @@
"""Test base chat model."""
import uuid
from typing import Any, AsyncIterator, Iterator, List, Optional
from typing import Any, AsyncIterator, Iterator, List, Literal, Optional, Union
import pytest
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel, FakeListChatModel
from langchain_core.language_models.fake_chat_models import FakeListChatModelError
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
@@ -18,6 +19,7 @@ from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResu
from langchain_core.outputs.llm_result import LLMResult
from langchain_core.tracers.base import BaseTracer
from langchain_core.tracers.context import collect_runs
from langchain_core.tracers.event_stream import _AstreamEventsCallbackHandler
from langchain_core.tracers.schemas import Run
from tests.unit_tests.fake.callbacks import (
BaseFakeCallbackHandler,
@@ -109,7 +111,7 @@ async def test_stream_error_callback() -> None:
responses=[message],
error_on_chunk_number=i,
)
with pytest.raises(Exception):
with pytest.raises(FakeListChatModelError):
cb_async = FakeAsyncCallbackHandler()
async for _ in llm.astream("Dummy message", callbacks=[cb_async]):
pass
@@ -272,3 +274,96 @@ async def test_async_pass_run_id() -> None:
uid3 = uuid.uuid4()
await llm.abatch([["Dummy message"]], {"callbacks": [cb], "run_id": uid3})
assert cb.traced_run_ids == [uid1, uid2, uid3]
class NoStreamingModel(BaseChatModel):
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
return ChatResult(generations=[ChatGeneration(message=AIMessage("invoke"))])
@property
def _llm_type(self) -> str:
return "model1"
class StreamingModel(NoStreamingModel):
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
yield ChatGenerationChunk(message=AIMessageChunk(content="stream"))
@pytest.mark.parametrize("disable_streaming", [True, False, "tool_calling"])
async def test_disable_streaming(
disable_streaming: Union[bool, Literal["tool_calling"]],
) -> None:
model = StreamingModel(disable_streaming=disable_streaming)
assert model.invoke([]).content == "invoke"
assert (await model.ainvoke([])).content == "invoke"
expected = "invoke" if disable_streaming is True else "stream"
assert next(model.stream([])).content == expected
async for c in model.astream([]):
assert c.content == expected
break
assert (
model.invoke(
[], config={"callbacks": [_AstreamEventsCallbackHandler()]}
).content
== expected
)
assert (
await model.ainvoke([], config={"callbacks": [_AstreamEventsCallbackHandler()]})
).content == expected
expected = "invoke" if disable_streaming in ("tool_calling", True) else "stream"
assert next(model.stream([], tools=[{"type": "function"}])).content == expected
async for c in model.astream([], tools=[{}]):
assert c.content == expected
break
assert (
model.invoke(
[], config={"callbacks": [_AstreamEventsCallbackHandler()]}, tools=[{}]
).content
== expected
)
assert (
await model.ainvoke(
[], config={"callbacks": [_AstreamEventsCallbackHandler()]}, tools=[{}]
)
).content == expected
@pytest.mark.parametrize("disable_streaming", [True, False, "tool_calling"])
async def test_disable_streaming_no_streaming_model(
disable_streaming: Union[bool, Literal["tool_calling"]],
) -> None:
model = NoStreamingModel(disable_streaming=disable_streaming)
assert model.invoke([]).content == "invoke"
assert (await model.ainvoke([])).content == "invoke"
assert next(model.stream([])).content == "invoke"
async for c in model.astream([]):
assert c.content == "invoke"
break
assert (
model.invoke(
[], config={"callbacks": [_AstreamEventsCallbackHandler()]}
).content
== "invoke"
)
assert (
await model.ainvoke([], config={"callbacks": [_AstreamEventsCallbackHandler()]})
).content == "invoke"
assert next(model.stream([], tools=[{}])).content == "invoke"
async for c in model.astream([], tools=[{}]):
assert c.content == "invoke"
break
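The two parametrized tests above pin down what the new ``disable_streaming`` flag is supposed to do: ``True`` always routes ``stream``/``astream`` through the non-streaming path, ``"tool_calling"`` only does so when tools are passed, and ``False`` leaves streaming alone. A condensed illustration of what a caller observes, using the generic fake chat model as a stand-in for a real provider (this assumes ``disable_streaming`` is accepted like any other ``BaseChatModel`` field):

.. code-block:: python

    from itertools import cycle

    from langchain_core.language_models import GenericFakeChatModel
    from langchain_core.messages import AIMessage

    model = GenericFakeChatModel(
        messages=cycle([AIMessage(content="hello world")]),
        disable_streaming=True,
    )
    # With streaming disabled, .stream() degrades to a single chunk that
    # carries the full invoke() result instead of token-by-token output.
    chunks = list(model.stream("hi"))
    assert len(chunks) == 1
    assert chunks[0].content == "hello world"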

View File

@@ -1,4 +1,5 @@
import time
from typing import Optional as Optional
from langchain_core.caches import InMemoryCache
from langchain_core.language_models import GenericFakeChatModel
@@ -220,6 +221,9 @@ class SerializableModel(GenericFakeChatModel):
return True
SerializableModel.model_rebuild()
def test_serialization_with_rate_limiter() -> None:
"""Test model serialization with rate limiter."""
from langchain_core.load import dumps

View File

@@ -7,6 +7,7 @@ from langchain_core.callbacks import (
CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseLLM, FakeListLLM, FakeStreamingListLLM
from langchain_core.language_models.fake import FakeListLLMError
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
from langchain_core.tracers.context import collect_runs
from tests.unit_tests.fake.callbacks import (
@@ -108,7 +109,7 @@ async def test_stream_error_callback() -> None:
responses=[message],
error_on_chunk_number=i,
)
with pytest.raises(Exception):
with pytest.raises(FakeListLLMError):
cb_async = FakeAsyncCallbackHandler()
async for _ in llm.astream("Dummy message", callbacks=[cb_async]):
pass

View File

@@ -6,6 +6,7 @@ EXPECTED_ALL = [
"SimpleChatModel",
"BaseLLM",
"LLM",
"LangSmithParams",
"LanguageModelInput",
"LanguageModelOutput",
"LanguageModelLike",

Some files were not shown because too many files have changed in this diff.