Mirror of https://github.com/hwchase17/langchain.git (synced 2026-02-12 12:11:34 +00:00)
Compare commits
97 Commits
sr/valid-c
...
eugene/cor
Commit SHA1s:
3981cda448 979505eb03 1882a139b7 627757e808 06da7da547 3d561c3e6d 09f9d3e972 c1e6e7d020 6f79443ab5 58c4e1ef86
de30b04f37 f7a455299e 76043abd47 61cdb9ccce 1ef3fa54fc 3856e3b02a 035f09f20d 63f7a5ab68 8447d9f6f1 95db9e9258
0d1b93774b bcce3a2865 4a478d82bd a0c3657442 72c5c28b4d fe6f2f724b 88d347e90c 741b50d4fd 24c6825345 32824aa55c
f6924653ea 66e8594b89 3b9f061eac 76b6ee290d 22957311fe f9df75c8cc ece0ab8539 4ddd9e5f23 f8e95e5735 6515b2f77b
63fde4f095 d9bb9125c1 384d9f59a3 fc0fa7e8f0 a1054d06ca c2570a7a7c 97f4128bfd 2434dc8f92 123d61a888 53f6f4a0c0
550bef230a 5a998d36b2 72cd199efc a1d993deb1 e546e21d53 26d6426156 8dffedebd6 60adf8d6e4 d13a1ad5f5 1e5f8a494a
5216131769 8bdaf858b8 c37a0ca672 266cd15511 9debf8144e 78ce0ed337 4aa1932bea b658295b97 8c59b6a026 e35b43a7a7
7288d914a8 1b487e261a 3934663db9 fb639cb49c 1856387e9e a5ad775a90 a321401683 8839220a00 e6b2ca4da3 d0c52d1dec
a5fa6d1c43 7f79bd6e04 339985e39e f4ecd749d5 cb61c6b4bf b42c2c6cd6 da6633bf0d 0193d18bec 0a82192e36 202f6fef95
c49416e908 ec93ea6240 add20dc9a8 7799474746 d98c1f115f d97f70def4 609c6b0963
@@ -13,7 +13,7 @@ tests:
poetry run pytest $(TEST_FILE)

test_watch:
poetry run ptw --snapshot-update --now . -- -vv tests/unit_tests
poetry run ptw --snapshot-update --now . -- -vv $(TEST_FILE)

test_profile:
poetry run pytest -vv tests/unit_tests/ --profile-svg
@@ -39,17 +39,14 @@ lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test

lint lint_diff lint_package lint_tests:
./scripts/check_pydantic.sh .
./scripts/lint_imports.sh
poetry run ruff check .
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff check $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff check --select I $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || poetry run mypy $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)

format format_diff:
poetry run ruff format $(PYTHON_FILES)
poetry run ruff check --select I --fix $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff check --select I --fix $(PYTHON_FILES)

spell_check:
poetry run codespell --toml pyproject.toml
@@ -1,3 +1,13 @@
"""``langchain-core`` defines the base abstractions for the LangChain ecosystem.

The interfaces for core components like chat models, LLMs, vector stores, retrievers,
and more are defined here. The universal invocation protocol (Runnables) along with
a syntax for combining components (LangChain Expression Language) are also defined here.

No third-party integrations are defined here. The dependencies are kept purposefully
very lightweight.
"""

from importlib import metadata

from langchain_core._api import (
@@ -270,7 +270,7 @@ def warn_beta(
message += f" {addendum}"

warning = LangChainBetaWarning(message)
warnings.warn(warning, category=LangChainBetaWarning, stacklevel=2)
warnings.warn(warning, category=LangChainBetaWarning, stacklevel=4)

def surface_langchain_beta_warnings() -> None:
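The stacklevel bump above controls which stack frame the Python warnings machinery attributes the message to. A minimal sketch of the idea, using plain `warnings` and made-up function names (not LangChain code):

```python
import warnings

def _emit() -> None:
    # stacklevel=1 would point at this warn() call; a higher value walks up
    # the call stack so the warning is reported at an outer caller instead.
    warnings.warn("beta API", UserWarning, stacklevel=3)

def library_wrapper() -> None:
    _emit()

def user_code() -> None:
    library_wrapper()  # with stacklevel=3 the warning is attributed here

user_code()
```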
@@ -14,7 +14,17 @@ import contextlib
import functools
import inspect
import warnings
from typing import Any, Callable, Generator, Type, TypeVar, Union, cast
from typing import (
Any,
Callable,
Generator,
Type,
TypeVar,
Union,
cast,
)

from typing_extensions import ParamSpec

from langchain_core._api.internal import is_caller_internal

@@ -30,7 +40,8 @@ class LangChainPendingDeprecationWarning(PendingDeprecationWarning):
# PUBLIC API

T = TypeVar("T", bound=Union[Type, Callable[..., Any]])
# Last Any should be FieldInfoV1 but this leads to circular imports
T = TypeVar("T", bound=Union[Type, Callable[..., Any], Any])

def _validate_deprecation_params(
@@ -133,6 +144,7 @@ def deprecated(
_package: str = package,
) -> T:
"""Implementation of the decorator returned by `deprecated`."""
from langchain_core.utils.pydantic import FieldInfoV1

def emit_warning() -> None:
"""Emit the warning."""
@@ -207,50 +219,73 @@ def deprecated(
)
return cast(T, obj)

elif isinstance(obj, FieldInfoV1):
wrapped = None
if not _obj_type:
_obj_type = "attribute"
if not _name:
raise ValueError(f"Field {obj} must have a name to be deprecated.")
old_doc = obj.description

def finalize(wrapper: Callable[..., Any], new_doc: str) -> T:
return cast(
T,
FieldInfoV1(
default=obj.default,
default_factory=obj.default_factory,
description=new_doc,
alias=obj.alias,
exclude=obj.exclude,
),
)

elif isinstance(obj, property):
if not _obj_type:
_obj_type = "attribute"
wrapped = None
_name = _name or obj.fget.__qualname__
_name = _name or cast(Union[Type, Callable], obj.fget).__qualname__
old_doc = obj.__doc__

class _deprecated_property(property):
"""A deprecated property."""

def __init__(self, fget=None, fset=None, fdel=None, doc=None):
def __init__(self, fget=None, fset=None, fdel=None, doc=None):  # type: ignore[no-untyped-def]
super().__init__(fget, fset, fdel, doc)
self.__orig_fget = fget
self.__orig_fset = fset
self.__orig_fdel = fdel

def __get__(self, instance, owner=None):
def __get__(self, instance, owner=None):  # type: ignore[no-untyped-def]
if instance is not None or owner is not None:
emit_warning()
return self.fget(instance)

def __set__(self, instance, value):
def __set__(self, instance, value):  # type: ignore[no-untyped-def]
if instance is not None:
emit_warning()
return self.fset(instance, value)

def __delete__(self, instance):
def __delete__(self, instance):  # type: ignore[no-untyped-def]
if instance is not None:
emit_warning()
return self.fdel(instance)

def __set_name__(self, owner, set_name):
def __set_name__(self, owner, set_name):  # type: ignore[no-untyped-def]
nonlocal _name
if _name == "<lambda>":
_name = set_name

def finalize(wrapper: Callable[..., Any], new_doc: str) -> Any:
def finalize(wrapper: Callable[..., Any], new_doc: str) -> T:
"""Finalize the property."""
return _deprecated_property(
fget=obj.fget, fset=obj.fset, fdel=obj.fdel, doc=new_doc
return cast(
T,
_deprecated_property(
fget=obj.fget, fset=obj.fset, fdel=obj.fdel, doc=new_doc
),
)

else:
_name = _name or obj.__qualname__
_name = _name or cast(Union[Type, Callable], obj).__qualname__
if not _obj_type:
# edge case: when a function is within another function
# within a test, this will call it a "method" not a "function"
@@ -409,7 +444,7 @@ def warn_deprecated(
LangChainPendingDeprecationWarning if pending else LangChainDeprecationWarning
)
warning = warning_cls(message)
warnings.warn(warning, category=LangChainDeprecationWarning, stacklevel=2)
warnings.warn(warning, category=LangChainDeprecationWarning, stacklevel=4)

def surface_langchain_deprecation_warnings() -> None:
@@ -423,3 +458,51 @@ def surface_langchain_deprecation_warnings() -> None:
"default",
category=LangChainDeprecationWarning,
)

_P = ParamSpec("_P")
_R = TypeVar("_R")

def rename_parameter(
*,
since: str,
removal: str,
old: str,
new: str,
) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
"""Decorator indicating that parameter *old* of *func* is renamed to *new*.

The actual implementation of *func* should use *new*, not *old*. If *old*
is passed to *func*, a DeprecationWarning is emitted, and its value is
used, even if *new* is also passed by keyword.

Example:

.. code-block:: python

@_api.rename_parameter("3.1", "bad_name", "good_name")
def func(good_name): ...
"""

def decorator(f: Callable[_P, _R]) -> Callable[_P, _R]:
@functools.wraps(f)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
if new in kwargs and old in kwargs:
raise TypeError(
f"{f.__name__}() got multiple values for argument {new!r}"
)
if old in kwargs:
warn_deprecated(
since,
removal=removal,
message=f"The parameter `{old}` of `{f.__name__}` was "
f"deprecated in {since} and will be removed "
f"in {removal} Use `{new}` instead.",
)
kwargs[new] = kwargs.pop(old)
return f(*args, **kwargs)

return wrapper

return decorator
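A short sketch of how the new `rename_parameter` decorator might be applied. The function and parameter names below are made up, and the import path assumes the decorator is exposed from `langchain_core._api.deprecation`:

```python
from langchain_core._api.deprecation import rename_parameter

@rename_parameter(since="0.2.0", removal="1.0", old="docs", new="documents")
def summarize(documents: list) -> int:
    # The implementation only knows about the new name.
    return len(documents)

summarize(documents=[1, 2])  # normal call, no warning
summarize(docs=[1, 2])       # emits a deprecation warning, value is forwarded to `documents`
```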
@@ -18,6 +18,8 @@ from typing import (
Union,
)

from pydantic import ConfigDict

from langchain_core._api.beta_decorator import beta
from langchain_core.runnables.base import (
Runnable,
@@ -229,8 +231,9 @@ class ContextSet(RunnableSerializable):

keys: Mapping[str, Optional[Runnable]]

class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
)

def __init__(
self,
@@ -2,6 +2,7 @@

from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, TypeVar, Union
from uuid import UUID

@@ -13,6 +14,8 @@ if TYPE_CHECKING:
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk, LLMResult

_LOGGER = logging.getLogger(__name__)

class RetrieverManagerMixin:
"""Mixin for Retriever callbacks."""
@@ -911,15 +914,67 @@ class BaseCallbackManager(CallbackManagerMixin):
def copy(self: T) -> T:
"""Copy the callback manager."""
return self.__class__(
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
handlers=self.handlers.copy(),
inheritable_handlers=self.inheritable_handlers.copy(),
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
metadata=self.metadata,
inheritable_metadata=self.inheritable_metadata,
tags=self.tags.copy(),
inheritable_tags=self.inheritable_tags.copy(),
metadata=self.metadata.copy(),
inheritable_metadata=self.inheritable_metadata.copy(),
)

def merge(self: T, other: BaseCallbackManager) -> T:
"""Merge the callback manager with another callback manager.

May be overwritten in subclasses. Primarily used internally
within merge_configs.

Returns:
BaseCallbackManager: The merged callback manager of the same type
as the current object.

Example: Merging two callback managers.

.. code-block:: python

from langchain_core.callbacks.manager import CallbackManager, trace_as_chain_group
from langchain_core.callbacks.stdout import StdOutCallbackHandler

manager = CallbackManager(handlers=[StdOutCallbackHandler()], tags=["tag2"])
with trace_as_chain_group("My Group Name", tags=["tag1"]) as group_manager:
merged_manager = group_manager.merge(manager)
print(merged_manager.handlers)
# [
# <langchain_core.callbacks.stdout.StdOutCallbackHandler object at ...>,
# <langchain_core.callbacks.streaming_stdout.StreamingStdOutCallbackHandler object at ...>,
# ]

print(merged_manager.tags)
# ['tag2', 'tag1']

"""  # noqa: E501
manager = self.__class__(
parent_run_id=self.parent_run_id or other.parent_run_id,
handlers=[],
inheritable_handlers=[],
tags=list(set(self.tags + other.tags)),
inheritable_tags=list(set(self.inheritable_tags + other.inheritable_tags)),
metadata={
**self.metadata,
**other.metadata,
},
)

handlers = self.handlers + other.handlers
inheritable_handlers = self.inheritable_handlers + other.inheritable_handlers

for handler in handlers:
manager.add_handler(handler)

for handler in inheritable_handlers:
manager.add_handler(handler, inherit=True)
return manager

@property
def is_async(self) -> bool:
"""Whether the callback manager is async."""
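The switch from passing the handler and tag containers directly to passing `.copy()` of them means a copied manager no longer shares mutable state with the original. A minimal sketch of the aliasing problem being avoided, in plain Python rather than the actual classes:

```python
# Without copying, both "managers" share one list object:
original_handlers = ["handler_a"]
aliased_handlers = original_handlers      # same list
aliased_handlers.append("handler_b")
print(original_handlers)                  # ['handler_a', 'handler_b'] - original mutated too

# With .copy(), the copy can be mutated independently:
original_tags = ["tag1"]
copied_tags = original_tags.copy()
copied_tags.append("tag2")
print(original_tags)                      # ['tag1'] - unchanged
```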
@@ -1612,16 +1612,75 @@ class CallbackManagerForChainGroup(CallbackManager):
def copy(self) -> CallbackManagerForChainGroup:
"""Copy the callback manager."""
return self.__class__(
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
handlers=self.handlers.copy(),
inheritable_handlers=self.inheritable_handlers.copy(),
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
metadata=self.metadata,
inheritable_metadata=self.inheritable_metadata,
tags=self.tags.copy(),
inheritable_tags=self.inheritable_tags.copy(),
metadata=self.metadata.copy(),
inheritable_metadata=self.inheritable_metadata.copy(),
parent_run_manager=self.parent_run_manager,
)

def merge(
self: CallbackManagerForChainGroup, other: BaseCallbackManager
) -> CallbackManagerForChainGroup:
"""Merge the group callback manager with another callback manager.

Overwrites the merge method in the base class to ensure that the
parent run manager is preserved. Keeps the parent_run_manager
from the current object.

Returns:
CallbackManagerForChainGroup: A copy of the current object with the
handlers, tags, and other attributes merged from the other object.

Example: Merging two callback managers.

.. code-block:: python

from langchain_core.callbacks.manager import CallbackManager, trace_as_chain_group
from langchain_core.callbacks.stdout import StdOutCallbackHandler

manager = CallbackManager(handlers=[StdOutCallbackHandler()], tags=["tag2"])
with trace_as_chain_group("My Group Name", tags=["tag1"]) as group_manager:
merged_manager = group_manager.merge(manager)
print(type(merged_manager))
# <class 'langchain_core.callbacks.manager.CallbackManagerForChainGroup'>

print(merged_manager.handlers)
# [
# <langchain_core.callbacks.stdout.LangChainTracer object at ...>,
# <langchain_core.callbacks.streaming_stdout.StdOutCallbackHandler object at ...>,
# ]

print(merged_manager.tags)
# ['tag2', 'tag1']

"""  # noqa: E501
manager = self.__class__(
parent_run_id=self.parent_run_id or other.parent_run_id,
handlers=[],
inheritable_handlers=[],
tags=list(set(self.tags + other.tags)),
inheritable_tags=list(set(self.inheritable_tags + other.inheritable_tags)),
metadata={
**self.metadata,
**other.metadata,
},
parent_run_manager=self.parent_run_manager,
)

handlers = self.handlers + other.handlers
inheritable_handlers = self.inheritable_handlers + other.inheritable_handlers

for handler in handlers:
manager.add_handler(handler)

for handler in inheritable_handlers:
manager.add_handler(handler, inherit=True)
return manager

def on_chain_end(self, outputs: Union[Dict[str, Any], Any], **kwargs: Any) -> None:
"""Run when traced chain group ends.

@@ -2040,16 +2099,75 @@ class AsyncCallbackManagerForChainGroup(AsyncCallbackManager):
def copy(self) -> AsyncCallbackManagerForChainGroup:
"""Copy the async callback manager."""
return self.__class__(
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
handlers=self.handlers.copy(),
inheritable_handlers=self.inheritable_handlers.copy(),
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
metadata=self.metadata,
inheritable_metadata=self.inheritable_metadata,
tags=self.tags.copy(),
inheritable_tags=self.inheritable_tags.copy(),
metadata=self.metadata.copy(),
inheritable_metadata=self.inheritable_metadata.copy(),
parent_run_manager=self.parent_run_manager,
)

def merge(
self: AsyncCallbackManagerForChainGroup, other: BaseCallbackManager
) -> AsyncCallbackManagerForChainGroup:
"""Merge the group callback manager with another callback manager.

Overwrites the merge method in the base class to ensure that the
parent run manager is preserved. Keeps the parent_run_manager
from the current object.

Returns:
AsyncCallbackManagerForChainGroup: A copy of the current AsyncCallbackManagerForChainGroup
with the handlers, tags, etc. of the other callback manager merged in.

Example: Merging two callback managers.

.. code-block:: python

from langchain_core.callbacks.manager import CallbackManager, atrace_as_chain_group
from langchain_core.callbacks.stdout import StdOutCallbackHandler

manager = CallbackManager(handlers=[StdOutCallbackHandler()], tags=["tag2"])
async with atrace_as_chain_group("My Group Name", tags=["tag1"]) as group_manager:
merged_manager = group_manager.merge(manager)
print(type(merged_manager))
# <class 'langchain_core.callbacks.manager.AsyncCallbackManagerForChainGroup'>

print(merged_manager.handlers)
# [
# <langchain_core.callbacks.stdout.LangChainTracer object at ...>,
# <langchain_core.callbacks.streaming_stdout.StdOutCallbackHandler object at ...>,
# ]

print(merged_manager.tags)
# ['tag2', 'tag1']

"""  # noqa: E501
manager = self.__class__(
parent_run_id=self.parent_run_id or other.parent_run_id,
handlers=[],
inheritable_handlers=[],
tags=list(set(self.tags + other.tags)),
inheritable_tags=list(set(self.inheritable_tags + other.inheritable_tags)),
metadata={
**self.metadata,
**other.metadata,
},
parent_run_manager=self.parent_run_manager,
)

handlers = self.handlers + other.handlers
inheritable_handlers = self.inheritable_handlers + other.inheritable_handlers

for handler in handlers:
manager.add_handler(handler)

for handler in inheritable_handlers:
manager.add_handler(handler, inherit=True)
return manager

async def on_chain_end(
self, outputs: Union[Dict[str, Any], Any], **kwargs: Any
) -> None:
@@ -20,13 +20,14 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from typing import List, Sequence, Union

from pydantic import BaseModel, Field

from langchain_core.messages import (
AIMessage,
BaseMessage,
HumanMessage,
get_buffer_string,
)
from langchain_core.pydantic_v1 import BaseModel, Field

class BaseChatMessageHistory(ABC):
@@ -1,5 +1,6 @@
from langchain_core.document_loaders.base import BaseBlobParser, BaseLoader
from langchain_core.document_loaders.blob_loaders import Blob, BlobLoader, PathLike
from langchain_core.document_loaders.langsmith import LangSmithLoader

__all__ = [
"BaseBlobParser",
@@ -7,4 +8,5 @@ __all__ = [
"Blob",
"BlobLoader",
"PathLike",
"LangSmithLoader",
]
@@ -14,7 +14,7 @@ if TYPE_CHECKING:
from langchain_core.documents.base import Blob

class BaseLoader(ABC):
class BaseLoader(ABC):  # noqa: B024
"""Interface for Document Loader.

Implementations should implement the lazy-loading method using generators
@@ -4,10 +4,12 @@ import contextlib
import mimetypes
from io import BufferedReader, BytesIO
from pathlib import PurePath
from typing import Any, Generator, List, Literal, Mapping, Optional, Union, cast
from typing import Any, Dict, Generator, List, Literal, Optional, Union, cast

from pydantic import ConfigDict, Field, model_validator

from langchain_core.load.serializable import Serializable
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.utils.pydantic import v1_repr

PathLike = Union[str, PurePath]

@@ -110,9 +112,10 @@ class Blob(BaseMedia):
path: Optional[PathLike] = None
"""Location where the original content was found."""

class Config:
arbitrary_types_allowed = True
frozen = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
frozen=True,
)

@property
def source(self) -> Optional[str]:
@@ -127,8 +130,9 @@ class Blob(BaseMedia):
return cast(Optional[str], self.metadata["source"])
return str(self.path) if self.path else None

@root_validator(pre=True)
def check_blob_is_valid(cls, values: Mapping[str, Any]) -> Mapping[str, Any]:
@model_validator(mode="before")
@classmethod
def check_blob_is_valid(cls, values: Dict[str, Any]) -> Any:
"""Verify that either data or path is provided."""
if "data" not in values and "path" not in values:
raise ValueError("Either data or path must be provided")
@@ -137,7 +141,7 @@ class Blob(BaseMedia):
def as_string(self) -> str:
"""Read data as a string."""
if self.data is None and self.path:
with open(str(self.path), "r", encoding=self.encoding) as f:
with open(str(self.path), encoding=self.encoding) as f:
return f.read()
elif isinstance(self.data, bytes):
return self.data.decode(self.encoding)
@@ -293,3 +297,7 @@ class Document(BaseMedia):
return f"page_content='{self.page_content}' metadata={self.metadata}"
else:
return f"page_content='{self.page_content}'"

def __repr__(self) -> str:
# TODO(0.3): Remove this override after confirming unit tests!
return v1_repr(self)
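The Blob changes above follow the pydantic v1 to v2 migration pattern used throughout this diff: `class Config` becomes `model_config = ConfigDict(...)` and `@root_validator(pre=True)` becomes `@model_validator(mode="before")` on a classmethod. A standalone sketch of that pattern on an illustrative model (not the real Blob class):

```python
from typing import Any, Dict, Optional
from pydantic import BaseModel, ConfigDict, model_validator

class Media(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True)

    data: Optional[bytes] = None
    path: Optional[str] = None

    @model_validator(mode="before")
    @classmethod
    def check_valid(cls, values: Dict[str, Any]) -> Any:
        # Runs before field validation, like root_validator(pre=True) did in v1.
        if "data" not in values and "path" not in values:
            raise ValueError("Either data or path must be provided")
        return values

Media(path="/tmp/example.txt")  # ok; Media() would raise the ValueError above
```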
@@ -3,9 +3,10 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Optional, Sequence

from pydantic import BaseModel

from langchain_core.callbacks import Callbacks
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import run_in_executor
@@ -4,8 +4,9 @@
import hashlib
from typing import List

from pydantic import BaseModel

from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel

class FakeEmbeddings(Embeddings, BaseModel):
@@ -15,14 +16,36 @@ class FakeEmbeddings(Embeddings, BaseModel):

Do not use this outside of testing, as it is not a real embedding model.

Example:

Instantiate:
.. code-block:: python

from langchain_core.embeddings import FakeEmbeddings
embed = FakeEmbeddings(size=100)

fake_embeddings = FakeEmbeddings(size=100)
fake_embeddings.embed_documents(["hello world", "foo bar"])
Embed single text:
.. code-block:: python

input_text = "The meaning of life is 42"
vector = embed.embed_query(input_text)
print(vector[:3])

.. code-block:: python

[-0.700234640213188, -0.581266257710429, -1.1328482266445354]

Embed multiple texts:
.. code-block:: python

input_texts = ["Document 1...", "Document 2..."]
vectors = embed.embed_documents(input_texts)
print(len(vectors))
# The first 3 coordinates for the first vector
print(vectors[0][:3])

.. code-block:: python

2
[-0.5670477847544458, -0.31403828652395727, -0.5840547508955257]
"""

size: int
@@ -48,14 +71,36 @@ class DeterministicFakeEmbedding(Embeddings, BaseModel):

Do not use this outside of testing, as it is not a real embedding model.

Example:

Instantiate:
.. code-block:: python

from langchain_core.embeddings import DeterministicFakeEmbedding
embed = DeterministicFakeEmbedding(size=100)

fake_embeddings = DeterministicFakeEmbedding(size=100)
fake_embeddings.embed_documents(["hello world", "foo bar"])
Embed single text:
.. code-block:: python

input_text = "The meaning of life is 42"
vector = embed.embed_query(input_text)
print(vector[:3])

.. code-block:: python

[-0.700234640213188, -0.581266257710429, -1.1328482266445354]

Embed multiple texts:
.. code-block:: python

input_texts = ["Document 1...", "Document 2..."]
vectors = embed.embed_documents(input_texts)
print(len(vectors))
# The first 3 coordinates for the first vector
print(vectors[0][:3])

.. code-block:: python

2
[-0.5670477847544458, -0.31403828652395727, -0.5840547508955257]
"""

size: int
@@ -3,9 +3,10 @@
import re
from typing import Callable, Dict, List

from pydantic import BaseModel, validator

from langchain_core.example_selectors.base import BaseExampleSelector
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, validator

def _get_length_based(text: str) -> int:
@@ -5,9 +5,10 @@ from __future__ import annotations
from abc import ABC
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type

from pydantic import BaseModel, ConfigDict, Extra

from langchain_core.documents import Document
from langchain_core.example_selectors.base import BaseExampleSelector
from langchain_core.pydantic_v1 import BaseModel, Extra
from langchain_core.vectorstores import VectorStore

if TYPE_CHECKING:
@@ -42,9 +43,10 @@ class _VectorStoreExampleSelector(BaseExampleSelector, BaseModel, ABC):
vectorstore_kwargs: Optional[Dict[str, Any]] = None
"""Extra arguments passed to similarity_search function of the vectorstore."""

class Config:
arbitrary_types_allowed = True
extra = Extra.forbid
model_config = ConfigDict(
arbitrary_types_allowed=True,
extra=Extra.forbid,
)

@staticmethod
def _example_to_text(
@@ -39,7 +39,7 @@ class OutputParserException(ValueError, LangChainException):
llm_output: Optional[str] = None,
send_to_llm: bool = False,
):
super(OutputParserException, self).__init__(error)
super().__init__(error)
if send_to_llm:
if observation is None or llm_output is None:
raise ValueError(
@@ -12,6 +12,8 @@ from typing import (
Optional,
)

from pydantic import Field

from langchain_core._api import beta
from langchain_core.callbacks import (
AsyncCallbackManagerForRetrieverRun,
@@ -20,7 +22,6 @@ from langchain_core.callbacks import (
from langchain_core.documents import Document
from langchain_core.graph_vectorstores.links import METADATA_LINKS_KEY, Link
from langchain_core.load import Serializable
from langchain_core.pydantic_v1 import Field
from langchain_core.runnables import run_in_executor
from langchain_core.vectorstores import VectorStore, VectorStoreRetriever

@@ -32,15 +33,17 @@ def _has_next(iterator: Iterator) -> bool:
return next(iterator, sentinel) is not sentinel

@beta()
class Node(Serializable):
"""Node in the GraphVectorStore.

Edges exist from nodes with an outgoing link to nodes with a matching incoming link.

For instance two nodes `a` and `b` connected over a hyperlink `https://some-url`
For instance two nodes `a` and `b` connected over a hyperlink ``https://some-url``
would look like:

.. code-block:: python

[
Node(
id="a",
@@ -79,12 +82,12 @@ def _texts_to_nodes(
for text in texts:
try:
_metadata = next(metadatas_it).copy() if metadatas_it else {}
except StopIteration:
raise ValueError("texts iterable longer than metadatas")
except StopIteration as e:
raise ValueError("texts iterable longer than metadatas") from e
try:
_id = next(ids_it) if ids_it else None
except StopIteration:
raise ValueError("texts iterable longer than ids")
except StopIteration as e:
raise ValueError("texts iterable longer than ids") from e

links = _metadata.pop(METADATA_LINKS_KEY, [])
if not isinstance(links, list):
@@ -115,7 +118,15 @@ def _documents_to_nodes(documents: Iterable[Document]) -> Iterator[Node]:
)

@beta()
def nodes_to_documents(nodes: Iterable[Node]) -> Iterator[Document]:
"""Convert nodes to documents.

Args:
nodes: The nodes to convert to documents.
Returns:
The documents generated from the nodes.
"""
for node in nodes:
metadata = node.metadata.copy()
metadata[METADATA_LINKS_KEY] = [
@@ -588,23 +599,28 @@ class GraphVectorStore(VectorStore):
"'mmr' or 'traversal'."
)

def as_retriever(self, **kwargs: Any) -> "GraphVectorStoreRetriever":
def as_retriever(self, **kwargs: Any) -> GraphVectorStoreRetriever:
"""Return GraphVectorStoreRetriever initialized from this GraphVectorStore.

Args:
search_type (Optional[str]): Defines the type of search that
the Retriever should perform.
Can be "traversal" (default), "similarity", "mmr", or
"similarity_score_threshold".
search_kwargs (Optional[Dict]): Keyword arguments to pass to the
search function. Can include things like:
k: Amount of documents to return (Default: 4)
depth: The maximum depth of edges to traverse (Default: 1)
score_threshold: Minimum relevance threshold
for similarity_score_threshold
fetch_k: Amount of documents to pass to MMR algorithm (Default: 20)
lambda_mult: Diversity of results returned by MMR;
1 for minimum diversity and 0 for maximum. (Default: 0.5)
**kwargs: Keyword arguments to pass to the search function.
Can include:

- search_type (Optional[str]): Defines the type of search that
the Retriever should perform.
Can be ``traversal`` (default), ``similarity``, ``mmr``, or
``similarity_score_threshold``.
- search_kwargs (Optional[Dict]): Keyword arguments to pass to the
search function. Can include things like:

- k(int): Amount of documents to return (Default: 4).
- depth(int): The maximum depth of edges to traverse (Default: 1).
- score_threshold(float): Minimum relevance threshold
for similarity_score_threshold.
- fetch_k(int): Amount of documents to pass to MMR algorithm
(Default: 20).
- lambda_mult(float): Diversity of results returned by MMR;
1 for minimum diversity and 0 for maximum. (Default: 0.5).
Returns:
Retriever for this GraphVectorStore.
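Given the expanded docstring above, a call to `as_retriever` might look like the following sketch, assuming `store` is some concrete `GraphVectorStore` implementation that already holds documents:

```python
# Hypothetical usage; search_type and search_kwargs follow the docstring above.
retriever = store.as_retriever(
    search_type="mmr",
    search_kwargs={"k": 4, "depth": 2, "fetch_k": 20, "lambda_mult": 0.5},
)
docs = retriever.invoke("What is a graph vector store?")
```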
@@ -1,9 +1,11 @@
from dataclasses import dataclass
from typing import Iterable, List, Literal, Union

from langchain_core._api import beta
from langchain_core.documents import Document

@beta()
@dataclass(frozen=True)
class Link:
"""A link to/from a tag of a given tag.
@@ -38,8 +40,10 @@ class Link:
METADATA_LINKS_KEY = "links"

@beta()
def get_links(doc: Document) -> List[Link]:
"""Get the links from a document.

Args:
doc: The document to get the link tags from.
Returns:
@@ -54,8 +58,10 @@ def get_links(doc: Document) -> List[Link]:
return links

@beta()
def add_links(doc: Document, *links: Union[Link, Iterable[Link]]) -> None:
"""Add links to the given metadata.

Args:
doc: The document to add the links to.
*links: The links to add to the document.
@@ -68,6 +74,7 @@ def add_links(doc: Document, *links: Union[Link, Iterable[Link]]) -> None:
links_in_metadata.append(link)

@beta()
def copy_with_links(doc: Document, *links: Union[Link, Iterable[Link]]) -> Document:
"""Return a document with the given links added.
@@ -25,10 +25,11 @@ from typing import (
cast,
)

from pydantic import model_validator

from langchain_core.document_loaders.base import BaseLoader
from langchain_core.documents import Document
from langchain_core.indexing.base import DocumentIndex, RecordManager
from langchain_core.pydantic_v1 import root_validator
from langchain_core.vectorstores import VectorStore

# Magic UUID to use as a namespace for hashing.
@@ -68,8 +69,9 @@ class _HashedDocument(Document):
def is_lc_serializable(cls) -> bool:
return False

@root_validator(pre=True)
def calculate_hashes(cls, values: Dict[str, Any]) -> Dict[str, Any]:
@model_validator(mode="before")
@classmethod
def calculate_hashes(cls, values: Dict[str, Any]) -> Any:
"""Root validator to calculate content and metadata hash."""
content = values.get("page_content", "")
metadata = values.get("metadata", {})
@@ -91,7 +93,7 @@ class _HashedDocument(Document):
raise ValueError(
f"Failed to hash metadata: {e}. "
f"Please use a dict that can be serialized using json."
)
) from e

values["content_hash"] = content_hash
values["metadata_hash"] = metadata_hash
@@ -435,7 +437,7 @@ async def _to_async_iterator(iterator: Iterable[T]) -> AsyncIterator[T]:
async def aindex(
docs_source: Union[BaseLoader, Iterable[Document], AsyncIterator[Document]],
record_manager: RecordManager,
vectorstore: Union[VectorStore, DocumentIndex],
vector_store: Union[VectorStore, DocumentIndex],
*,
batch_size: int = 100,
cleanup: Literal["incremental", "full", None] = None,
@@ -506,7 +508,7 @@ async def aindex(
if cleanup == "incremental" and source_id_key is None:
raise ValueError("Source id key is required when cleanup mode is incremental.")

destination = vectorstore  # Renaming internally for clarity
destination = vector_store  # Renaming internally for clarity

# If it's a vectorstore, let's check if it has the required methods.
if isinstance(destination, VectorStore):
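With the `vectorstore` to `vector_store` parameter rename in `aindex` above, a caller would pass the destination by its new keyword. A hedged sketch; the import path and the `docs`, `record_manager`, and `store` objects are assumed to exist in the caller's code:

```python
from langchain_core.indexing import aindex  # assumed import path

async def run_indexing(docs, record_manager, store) -> None:
    # `vector_store` is the renamed keyword; incremental cleanup also
    # requires source_id_key, per the validation shown in the diff.
    result = await aindex(
        docs,
        record_manager,
        vector_store=store,
        cleanup="incremental",
        source_id_key="source",
    )
    print(result)
```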
@@ -1,12 +1,13 @@
import uuid
from typing import Any, Dict, List, Optional, Sequence, cast

from pydantic import Field

from langchain_core._api import beta
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.indexing import UpsertResponse
from langchain_core.indexing.base import DeleteResponse, DocumentIndex
from langchain_core.pydantic_v1 import Field

@beta(message="Introduced in version 0.2.29. Underlying abstraction subject to change.")
@@ -39,6 +39,7 @@ https://python.langchain.com/v0.2/docs/how_to/custom_llm/

from langchain_core.language_models.base import (
BaseLanguageModel,
LangSmithParams,
LanguageModelInput,
LanguageModelLike,
LanguageModelOutput,
@@ -62,6 +63,7 @@ __all__ = [
"LLM",
"LanguageModelInput",
"get_tokenizer",
"LangSmithParams",
"LanguageModelOutput",
"LanguageModelLike",
"FakeListLLM",
@@ -8,6 +8,7 @@ from typing import (
Callable,
Dict,
List,
Literal,
Mapping,
Optional,
Sequence,
@@ -17,7 +18,8 @@ from typing import (
Union,
)

from typing_extensions import TypeAlias
from pydantic import BaseModel, ConfigDict, Field, validator
from typing_extensions import TypeAlias, TypedDict

from langchain_core._api import deprecated
from langchain_core.messages import (
@@ -27,7 +29,6 @@ from langchain_core.messages import (
get_buffer_string,
)
from langchain_core.prompt_values import PromptValue
from langchain_core.pydantic_v1 import BaseModel, Field, validator
from langchain_core.runnables import Runnable, RunnableSerializable
from langchain_core.utils import get_pydantic_field_names

@@ -37,6 +38,23 @@ if TYPE_CHECKING:
from langchain_core.outputs import LLMResult

class LangSmithParams(TypedDict, total=False):
"""LangSmith parameters for tracing."""

ls_provider: str
"""Provider of the model."""
ls_model_name: str
"""Name of the model."""
ls_model_type: Literal["chat", "llm"]
"""Type of the model. Should be 'chat' or 'llm'."""
ls_temperature: Optional[float]
"""Temperature for generation."""
ls_max_tokens: Optional[int]
"""Max tokens for generation."""
ls_stop: Optional[List[str]]
"""Stop words for generation."""

@lru_cache(maxsize=None)  # Cache the tokenizer
def get_tokenizer() -> Any:
"""Get a GPT-2 tokenizer instance.
@@ -46,12 +64,12 @@ def get_tokenizer() -> Any:
"""
try:
from transformers import GPT2TokenizerFast  # type: ignore[import]
except ImportError:
except ImportError as e:
raise ImportError(
"Could not import transformers python package. "
"This is needed in order to calculate get_token_ids. "
"Please install it with `pip install transformers`."
)
) from e
# create a GPT-2 tokenizer instance
return GPT2TokenizerFast.from_pretrained("gpt2")

@@ -95,7 +113,11 @@ class BaseLanguageModel(

Caching is not currently supported for streaming methods of models.
"""
verbose: bool = Field(default_factory=_get_verbosity)
# Repr = False is consistent with pydantic 1 if verbose = False
# We can relax this for pydantic 2?
# TODO(Team): decide what to do here.
# Modified just to get unit tests to pass.
verbose: bool = Field(default_factory=_get_verbosity, exclude=True, repr=False)
"""Whether to print out response text."""
callbacks: Callbacks = Field(default=None, exclude=True)
"""Callbacks to add to the run trace."""
@@ -108,6 +130,10 @@ class BaseLanguageModel(
)
"""Optional encoder to use for counting tokens."""

model_config = ConfigDict(
arbitrary_types_allowed=True,
)

@validator("verbose", pre=True, always=True, allow_reuse=True)
def set_verbose(cls, verbose: Optional[bool]) -> bool:
"""If verbose is None, set it.
@@ -220,7 +246,7 @@ class BaseLanguageModel(
# generate responses that match a given schema.
raise NotImplementedError()

@deprecated("0.1.7", alternative="invoke", removal="0.3.0")
@deprecated("0.1.7", alternative="invoke", removal="1.0")
@abstractmethod
def predict(
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
@@ -241,7 +267,7 @@ class BaseLanguageModel(
Top model prediction as a string.
"""

@deprecated("0.1.7", alternative="invoke", removal="0.3.0")
@deprecated("0.1.7", alternative="invoke", removal="1.0")
@abstractmethod
def predict_messages(
self,
@@ -266,7 +292,7 @@ class BaseLanguageModel(
Top model prediction as a message.
"""

@deprecated("0.1.7", alternative="ainvoke", removal="0.3.0")
@deprecated("0.1.7", alternative="ainvoke", removal="1.0")
@abstractmethod
async def apredict(
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
@@ -287,7 +313,7 @@ class BaseLanguageModel(
Top model prediction as a string.
"""

@deprecated("0.1.7", alternative="ainvoke", removal="0.3.0")
@deprecated("0.1.7", alternative="ainvoke", removal="1.0")
@abstractmethod
async def apredict_messages(
self,
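Since `LangSmithParams` defined above is a `TypedDict` with `total=False`, any subset of its keys can be supplied and it behaves like a plain dict at runtime. A small sketch (the concrete values are illustrative only):

```python
from langchain_core.language_models import LangSmithParams

params = LangSmithParams(
    ls_provider="openai",
    ls_model_name="example-model",
    ls_model_type="chat",
    ls_temperature=0.0,
)
print(params["ls_provider"])  # plain dict access; missing keys are simply absent
```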
@@ -23,7 +23,12 @@ from typing import (
cast,
)

from typing_extensions import TypedDict
from pydantic import (
BaseModel,
ConfigDict,
Field,
model_validator,
)

from langchain_core._api import deprecated
from langchain_core.caches import BaseCache
@@ -36,7 +41,11 @@ from langchain_core.callbacks import (
Callbacks,
)
from langchain_core.globals import get_llm_cache
from langchain_core.language_models.base import BaseLanguageModel, LanguageModelInput
from langchain_core.language_models.base import (
BaseLanguageModel,
LangSmithParams,
LanguageModelInput,
)
from langchain_core.load import dumpd, dumps
from langchain_core.messages import (
AIMessage,
@@ -55,11 +64,6 @@ from langchain_core.outputs import (
RunInfo,
)
from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
from langchain_core.pydantic_v1 import (
BaseModel,
Field,
root_validator,
)
from langchain_core.rate_limiters import BaseRateLimiter
from langchain_core.runnables import RunnableMap, RunnablePassthrough
from langchain_core.runnables.config import ensure_config, run_in_executor
@@ -73,23 +77,6 @@ if TYPE_CHECKING:
from langchain_core.tools import BaseTool

class LangSmithParams(TypedDict, total=False):
"""LangSmith parameters for tracing."""

ls_provider: str
"""Provider of the model."""
ls_model_name: str
"""Name of the model."""
ls_model_type: Literal["chat"]
"""Type of the model. Should be 'chat'."""
ls_temperature: Optional[float]
"""Temperature for generation."""
ls_max_tokens: Optional[int]
"""Max tokens for generation."""
ls_stop: Optional[List[str]]
"""Stop words for generation."""

def generate_from_stream(stream: Iterator[ChatGenerationChunk]) -> ChatResult:
"""Generate from a stream.

@@ -208,14 +195,40 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):

"""  # noqa: E501

callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
"""[DEPRECATED] Callback manager to add to the run trace."""
# TODO(0.3): Figure out how to re-apply deprecated decorator
# callback_manager: Optional[BaseCallbackManager] = deprecated(
#     name="callback_manager", since="0.1.7", removal="1.0", alternative="callbacks"
# )(
#     Field(
#         default=None,
#         exclude=True,
#         description="Callback manager to add to the run trace.",
#     )
# )
callback_manager: Optional[BaseCallbackManager] = Field(
default=None,
exclude=True,
description="Callback manager to add to the run trace.",
)

rate_limiter: Optional[BaseRateLimiter] = Field(default=None, exclude=True)
"""An optional rate limiter to use for limiting the number of requests."""
"An optional rate limiter to use for limiting the number of requests."

@root_validator(pre=True)
def raise_deprecation(cls, values: Dict) -> Dict:
disable_streaming: Union[bool, Literal["tool_calling"]] = False
"""Whether to disable streaming for this model.

If streaming is bypassed, then ``stream()/astream()`` will defer to
``invoke()/ainvoke()``.

- If True, will always bypass streaming case.
- If "tool_calling", will bypass streaming case only when the model is called
with a ``tools`` keyword argument.
- If False (default), will always use streaming case if available.
"""

@model_validator(mode="before")
@classmethod
def raise_deprecation(cls, values: Dict) -> Any:
"""Raise deprecation warning if callback_manager is used.

Args:
@@ -231,12 +244,14 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
warnings.warn(
"callback_manager is deprecated. Please use callbacks instead.",
DeprecationWarning,
stacklevel=5,
)
values["callbacks"] = values.pop("callback_manager", None)
return values

class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
)

# --- Runnable methods ---

@@ -302,6 +317,41 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
)
return cast(ChatGeneration, llm_result.generations[0][0]).message

def _should_stream(
self,
*,
async_api: bool,
run_manager: Optional[
Union[CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun]
] = None,
**kwargs: Any,
) -> bool:
"""Determine if a given model call should hit the streaming API."""
sync_not_implemented = type(self)._stream == BaseChatModel._stream
async_not_implemented = type(self)._astream == BaseChatModel._astream

# Check if streaming is implemented.
if (not async_api) and sync_not_implemented:
return False
# Note, since async falls back to sync we check both here.
if async_api and async_not_implemented and sync_not_implemented:
return False

# Check if streaming has been disabled on this instance.
if self.disable_streaming is True:
return False
# We assume tools are passed in via "tools" kwarg in all models.
if self.disable_streaming == "tool_calling" and kwargs.get("tools"):
return False

# Check if a runtime streaming flag has been passed in.
if "stream" in kwargs:
return kwargs["stream"]

# Check if any streaming callback handlers have been passed in.
handlers = run_manager.handlers if run_manager else []
return any(isinstance(h, _StreamingCallbackHandler) for h in handlers)

def stream(
self,
input: LanguageModelInput,
@@ -310,7 +360,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Iterator[BaseMessageChunk]:
if type(self)._stream == BaseChatModel._stream:
if not self._should_stream(async_api=False, **{**kwargs, **{"stream": True}}):
# model doesn't implement streaming, so use default implementation
yield cast(
BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs)
@@ -380,10 +430,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> AsyncIterator[BaseMessageChunk]:
if (
type(self)._astream is BaseChatModel._astream
and type(self)._stream is BaseChatModel._stream
):
if not self._should_stream(async_api=True, **{**kwargs, **{"stream": True}}):
# No async or sync stream is implemented, so fall back to ainvoke
yield cast(
BaseMessageChunk,
@@ -471,9 +518,37 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
**kwargs: Any,
) -> LangSmithParams:
"""Get standard params for tracing."""
ls_params = LangSmithParams(ls_model_type="chat")

# get default provider from class name
default_provider = self.__class__.__name__
if default_provider.startswith("Chat"):
default_provider = default_provider[4:].lower()
elif default_provider.endswith("Chat"):
default_provider = default_provider[:-4]
default_provider = default_provider.lower()

ls_params = LangSmithParams(ls_provider=default_provider, ls_model_type="chat")
if stop:
ls_params["ls_stop"] = stop

# model
if hasattr(self, "model") and isinstance(self.model, str):
ls_params["ls_model_name"] = self.model
elif hasattr(self, "model_name") and isinstance(self.model_name, str):
ls_params["ls_model_name"] = self.model_name

# temperature
if "temperature" in kwargs and isinstance(kwargs["temperature"], float):
ls_params["ls_temperature"] = kwargs["temperature"]
elif hasattr(self, "temperature") and isinstance(self.temperature, float):
ls_params["ls_temperature"] = self.temperature

# max_tokens
if "max_tokens" in kwargs and isinstance(kwargs["max_tokens"], int):
ls_params["ls_max_tokens"] = kwargs["max_tokens"]
elif hasattr(self, "max_tokens") and isinstance(self.max_tokens, int):
ls_params["ls_max_tokens"] = self.max_tokens

return ls_params

def _get_llm_string(self, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
@@ -760,20 +835,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):

# If stream is not explicitly set, check if implicitly requested by
# astream_events() or astream_log(). Bail out if _stream not implemented
if type(self)._stream != BaseChatModel._stream and kwargs.pop(
"stream",
(
next(
(
True
for h in run_manager.handlers
if isinstance(h, _StreamingCallbackHandler)
),
False,
)
if run_manager
else False
),
if self._should_stream(
async_api=False,
run_manager=run_manager,
**kwargs,
):
chunks: List[ChatGenerationChunk] = []
for chunk in self._stream(messages, stop=stop, **kwargs):
@@ -847,23 +912,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):

# If stream is not explicitly set, check if implicitly requested by
# astream_events() or astream_log(). Bail out if _astream not implemented
if (
type(self)._astream != BaseChatModel._astream
or type(self)._stream != BaseChatModel._stream
) and kwargs.pop(
"stream",
(
next(
(
True
for h in run_manager.handlers
if isinstance(h, _StreamingCallbackHandler)
),
False,
)
if run_manager
else False
),
if self._should_stream(
async_api=True,
run_manager=run_manager,
**kwargs,
):
chunks: List[ChatGenerationChunk] = []
async for chunk in self._astream(messages, stop=stop, **kwargs):
@@ -963,7 +1015,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
break
yield item  # type: ignore[misc]

@deprecated("0.1.7", alternative="invoke", removal="0.3.0")
@deprecated("0.1.7", alternative="invoke", removal="1.0")
def __call__(
self,
messages: List[BaseMessage],
@@ -995,13 +1047,13 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
else:
raise ValueError("Unexpected generation type")

@deprecated("0.1.7", alternative="invoke", removal="0.3.0")
@deprecated("0.1.7", alternative="invoke", removal="1.0")
def call_as_llm(
self, message: str, stop: Optional[List[str]] = None, **kwargs: Any
) -> str:
return self.predict(message, stop=stop, **kwargs)

@deprecated("0.1.7", alternative="invoke", removal="0.3.0")
@deprecated("0.1.7", alternative="invoke", removal="1.0")
def predict(
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
) -> str:
@@ -1015,7 +1067,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
else:
raise ValueError("Cannot use predict when output is not a string.")

@deprecated("0.1.7", alternative="invoke", removal="0.3.0")
@deprecated("0.1.7", alternative="invoke", removal="1.0")
def predict_messages(
self,
messages: List[BaseMessage],
@@ -1029,7 +1081,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
_stop = list(stop)
return self(messages, stop=_stop, **kwargs)

@deprecated("0.1.7", alternative="ainvoke", removal="0.3.0")
@deprecated("0.1.7", alternative="ainvoke", removal="1.0")
async def apredict(
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
) -> str:
@@ -1045,7 +1097,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
else:
raise ValueError("Cannot use predict when output is not a string.")

@deprecated("0.1.7", alternative="ainvoke", removal="0.3.0")
@deprecated("0.1.7", alternative="ainvoke", removal="1.0")
async def apredict_messages(
self,
messages: List[BaseMessage],
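The new `disable_streaming` field and `_should_stream` helper above decide whether `stream()`/`astream()` actually hit a model's streaming path or fall back to `invoke()`/`ainvoke()`. A rough sketch of the behavior as described in the field's docstring, using the fake chat model that also appears in this diff (assuming it is importable as shown):

```python
from langchain_core.language_models import FakeListChatModel

chat = FakeListChatModel(responses=["hello world"], disable_streaming=True)

# With disable_streaming=True, stream() defers to invoke() and yields the
# whole message instead of per-token chunks.
for chunk in chat.stream("hi"):
    print(chunk.content)
```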
@@ -69,6 +69,10 @@ class FakeListLLM(LLM):
return {"responses": self.responses}

class FakeListLLMError(Exception):
"""Fake error for testing purposes."""

class FakeStreamingListLLM(FakeListLLM):
"""Fake streaming list LLM for testing purposes.

@@ -98,7 +102,7 @@ class FakeStreamingListLLM(FakeListLLM):
self.error_on_chunk_number is not None
and i_c == self.error_on_chunk_number
):
raise Exception("Fake error")
raise FakeListLLMError
yield c

async def astream(
@@ -118,5 +122,5 @@ class FakeStreamingListLLM(FakeListLLM):
self.error_on_chunk_number is not None
and i_c == self.error_on_chunk_number
):
raise Exception("Fake error")
raise FakeListLLMError
yield c
@@ -44,6 +44,10 @@ class FakeMessagesListChatModel(BaseChatModel):
         return "fake-messages-list-chat-model"


+class FakeListChatModelError(Exception):
+    pass
+
+
 class FakeListChatModel(SimpleChatModel):
     """Fake ChatModel for testing purposes."""

@@ -93,7 +97,7 @@ class FakeListChatModel(SimpleChatModel):
                 self.error_on_chunk_number is not None
                 and i_c == self.error_on_chunk_number
             ):
-                raise Exception("Fake error")
+                raise FakeListChatModelError

             yield ChatGenerationChunk(message=AIMessageChunk(content=c))

@@ -116,7 +120,7 @@ class FakeListChatModel(SimpleChatModel):
                 self.error_on_chunk_number is not None
                 and i_c == self.error_on_chunk_number
             ):
-                raise Exception("Fake error")
+                raise FakeListChatModelError
             yield ChatGenerationChunk(message=AIMessageChunk(content=c))

     @property

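The dedicated exception types above make failure injection explicit instead of a bare Exception; a hypothetical test sketch (the import path for the error class is an assumption for illustration):

import pytest

from langchain_core.language_models import FakeListChatModel
from langchain_core.language_models.fake_chat_models import FakeListChatModelError


def test_error_on_second_chunk() -> None:
    # error_on_chunk_number=1 makes the fake model fail on its second chunk.
    model = FakeListChatModel(responses=["hello"], error_on_chunk_number=1)
    with pytest.raises(FakeListChatModelError):
        for _ in model.stream("hi"):
            pass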
@@ -27,6 +27,7 @@ from typing import (
|
||||
)
|
||||
|
||||
import yaml
|
||||
from pydantic import ConfigDict, Field, model_validator
|
||||
from tenacity import (
|
||||
RetryCallState,
|
||||
before_sleep_log,
|
||||
@@ -48,7 +49,11 @@ from langchain_core.callbacks import (
|
||||
Callbacks,
|
||||
)
|
||||
from langchain_core.globals import get_llm_cache
|
||||
from langchain_core.language_models.base import BaseLanguageModel, LanguageModelInput
|
||||
from langchain_core.language_models.base import (
|
||||
BaseLanguageModel,
|
||||
LangSmithParams,
|
||||
LanguageModelInput,
|
||||
)
|
||||
from langchain_core.load import dumpd
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
@@ -58,7 +63,6 @@ from langchain_core.messages import (
|
||||
)
|
||||
from langchain_core.outputs import Generation, GenerationChunk, LLMResult, RunInfo
|
||||
from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
|
||||
from langchain_core.pydantic_v1 import Field, root_validator
|
||||
from langchain_core.runnables import RunnableConfig, ensure_config, get_config_list
|
||||
from langchain_core.runnables.config import run_in_executor
|
||||
|
||||
@@ -110,13 +114,12 @@ def create_base_retry_decorator(
|
||||
_log_error_once(f"Error in on_retry: {e}")
|
||||
else:
|
||||
run_manager.on_retry(retry_state)
|
||||
return None
|
||||
|
||||
min_seconds = 4
|
||||
max_seconds = 10
|
||||
# Wait 2^x * 1 second between each retry starting with
|
||||
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
|
||||
retry_instance: "retry_base" = retry_if_exception_type(error_types[0])
|
||||
retry_instance: retry_base = retry_if_exception_type(error_types[0])
|
||||
for error in error_types[1:]:
|
||||
retry_instance = retry_instance | retry_if_exception_type(error)
|
||||
return retry(
|
||||
@@ -297,16 +300,19 @@ class BaseLLM(BaseLanguageModel[str], ABC):
     callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
     """[DEPRECATED]"""

-    class Config:
-        arbitrary_types_allowed = True
+    model_config = ConfigDict(
+        arbitrary_types_allowed=True,
+    )

-    @root_validator(pre=True)
-    def raise_deprecation(cls, values: Dict) -> Dict:
+    @model_validator(mode="before")
+    @classmethod
+    def raise_deprecation(cls, values: Dict) -> Any:
         """Raise deprecation warning if callback_manager is used."""
         if values.get("callback_manager") is not None:
             warnings.warn(
                 "callback_manager is deprecated. Please use callbacks instead.",
                 DeprecationWarning,
+                stacklevel=5,
             )
             values["callbacks"] = values.pop("callback_manager", None)
         return values
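This hunk follows the pydantic v1 -> v2 migration pattern applied throughout the changeset; a minimal, self-contained sketch of that pattern (illustrative class, not code from this diff):

from typing import Optional

from pydantic import BaseModel, ConfigDict, model_validator


class Example(BaseModel):
    # v2: `model_config = ConfigDict(...)` replaces the inner `class Config`.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    callbacks: Optional[list] = None

    # v2: `@model_validator(mode="before")` + @classmethod replaces
    # `@root_validator(pre=True)` for pre-construction input rewriting.
    @model_validator(mode="before")
    @classmethod
    def migrate_old_field(cls, values: dict) -> dict:
        if "callback_manager" in values:
            values["callbacks"] = values.pop("callback_manager")
        return values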
@@ -331,6 +337,43 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                 "Must be a PromptValue, str, or list of BaseMessages."
             )

+    def _get_ls_params(
+        self,
+        stop: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> LangSmithParams:
+        """Get standard params for tracing."""
+
+        # get default provider from class name
+        default_provider = self.__class__.__name__
+        if default_provider.endswith("LLM"):
+            default_provider = default_provider[:-3]
+        default_provider = default_provider.lower()
+
+        ls_params = LangSmithParams(ls_provider=default_provider, ls_model_type="llm")
+        if stop:
+            ls_params["ls_stop"] = stop
+
+        # model
+        if hasattr(self, "model") and isinstance(self.model, str):
+            ls_params["ls_model_name"] = self.model
+        elif hasattr(self, "model_name") and isinstance(self.model_name, str):
+            ls_params["ls_model_name"] = self.model_name
+
+        # temperature
+        if "temperature" in kwargs and isinstance(kwargs["temperature"], float):
+            ls_params["ls_temperature"] = kwargs["temperature"]
+        elif hasattr(self, "temperature") and isinstance(self.temperature, float):
+            ls_params["ls_temperature"] = self.temperature
+
+        # max_tokens
+        if "max_tokens" in kwargs and isinstance(kwargs["max_tokens"], int):
+            ls_params["ls_max_tokens"] = kwargs["max_tokens"]
+        elif hasattr(self, "max_tokens") and isinstance(self.max_tokens, int):
+            ls_params["ls_max_tokens"] = self.max_tokens
+
+        return ls_params
+
     def invoke(
         self,
         input: LanguageModelInput,
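For reference, the tracing metadata assembled above ends up as flat `ls_`-prefixed keys; roughly, for a hypothetical subclass (values are illustrative only):

# Hypothetical: class MyLLM(BaseLLM) with model_name="my-model" and
# temperature=0.2, invoked with stop=["\n"] -> _get_ls_params(...) returns:
{
    "ls_provider": "my",  # "MyLLM" with the "LLM" suffix stripped, lowercased
    "ls_model_type": "llm",
    "ls_stop": ["\n"],
    "ls_model_name": "my-model",
    "ls_temperature": 0.2,
}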
@@ -487,13 +530,17 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
params["stop"] = stop
|
||||
params = {**params, **kwargs}
|
||||
options = {"stop": stop}
|
||||
inheritable_metadata = {
|
||||
**(config.get("metadata") or {}),
|
||||
**self._get_ls_params(stop=stop, **kwargs),
|
||||
}
|
||||
callback_manager = CallbackManager.configure(
|
||||
config.get("callbacks"),
|
||||
self.callbacks,
|
||||
self.verbose,
|
||||
config.get("tags"),
|
||||
self.tags,
|
||||
config.get("metadata"),
|
||||
inheritable_metadata,
|
||||
self.metadata,
|
||||
)
|
||||
(run_manager,) = callback_manager.on_llm_start(
|
||||
@@ -548,13 +595,17 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
params["stop"] = stop
|
||||
params = {**params, **kwargs}
|
||||
options = {"stop": stop}
|
||||
inheritable_metadata = {
|
||||
**(config.get("metadata") or {}),
|
||||
**self._get_ls_params(stop=stop, **kwargs),
|
||||
}
|
||||
callback_manager = AsyncCallbackManager.configure(
|
||||
config.get("callbacks"),
|
||||
self.callbacks,
|
||||
self.verbose,
|
||||
config.get("tags"),
|
||||
self.tags,
|
||||
config.get("metadata"),
|
||||
inheritable_metadata,
|
||||
self.metadata,
|
||||
)
|
||||
(run_manager,) = await callback_manager.on_llm_start(
|
||||
@@ -796,6 +847,21 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
f" argument of type {type(prompts)}."
|
||||
)
|
||||
# Create callback managers
|
||||
if isinstance(metadata, list):
|
||||
metadata = [
|
||||
{
|
||||
**(meta or {}),
|
||||
**self._get_ls_params(stop=stop, **kwargs),
|
||||
}
|
||||
for meta in metadata
|
||||
]
|
||||
elif isinstance(metadata, dict):
|
||||
metadata = {
|
||||
**(metadata or {}),
|
||||
**self._get_ls_params(stop=stop, **kwargs),
|
||||
}
|
||||
else:
|
||||
pass
|
||||
if (
|
||||
isinstance(callbacks, list)
|
||||
and callbacks
|
||||
@@ -1017,6 +1083,21 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
An LLMResult, which contains a list of candidate Generations for each input
|
||||
prompt and additional model provider-specific output.
|
||||
"""
|
||||
if isinstance(metadata, list):
|
||||
metadata = [
|
||||
{
|
||||
**(meta or {}),
|
||||
**self._get_ls_params(stop=stop, **kwargs),
|
||||
}
|
||||
for meta in metadata
|
||||
]
|
||||
elif isinstance(metadata, dict):
|
||||
metadata = {
|
||||
**(metadata or {}),
|
||||
**self._get_ls_params(stop=stop, **kwargs),
|
||||
}
|
||||
else:
|
||||
pass
|
||||
# Create callback managers
|
||||
if isinstance(callbacks, list) and (
|
||||
isinstance(callbacks[0], (list, BaseCallbackManager))
|
||||
@@ -1150,7 +1231,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
generations = [existing_prompts[i] for i in range(len(prompts))]
|
||||
return LLMResult(generations=generations, llm_output=llm_output, run=run_info)
|
||||
|
||||
@deprecated("0.1.7", alternative="invoke", removal="0.3.0")
|
||||
@deprecated("0.1.7", alternative="invoke", removal="1.0")
|
||||
def __call__(
|
||||
self,
|
||||
prompt: str,
|
||||
@@ -1220,7 +1301,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
)
|
||||
return result.generations[0][0].text
|
||||
|
||||
@deprecated("0.1.7", alternative="invoke", removal="0.3.0")
|
||||
@deprecated("0.1.7", alternative="invoke", removal="1.0")
|
||||
def predict(
|
||||
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
|
||||
) -> str:
|
||||
@@ -1230,7 +1311,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
_stop = list(stop)
|
||||
return self(text, stop=_stop, **kwargs)
|
||||
|
||||
@deprecated("0.1.7", alternative="invoke", removal="0.3.0")
|
||||
@deprecated("0.1.7", alternative="invoke", removal="1.0")
|
||||
def predict_messages(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
@@ -1246,7 +1327,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
content = self(text, stop=_stop, **kwargs)
|
||||
return AIMessage(content=content)
|
||||
|
||||
@deprecated("0.1.7", alternative="ainvoke", removal="0.3.0")
|
||||
@deprecated("0.1.7", alternative="ainvoke", removal="1.0")
|
||||
async def apredict(
|
||||
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
|
||||
) -> str:
|
||||
@@ -1256,7 +1337,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
_stop = list(stop)
|
||||
return await self._call_async(text, stop=_stop, **kwargs)
|
||||
|
||||
@deprecated("0.1.7", alternative="ainvoke", removal="0.3.0")
|
||||
@deprecated("0.1.7", alternative="ainvoke", removal="1.0")
|
||||
async def apredict_messages(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
|
||||
@@ -10,9 +10,10 @@ from typing import (
|
||||
cast,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel, ConfigDict # pydantic: ignore
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.utils.pydantic import v1_repr
|
||||
|
||||
|
||||
class BaseSerialized(TypedDict):
|
||||
@@ -80,7 +81,7 @@ def try_neq_default(value: Any, key: str, model: BaseModel) -> bool:
|
||||
Exception: If the key is not in the model.
|
||||
"""
|
||||
try:
|
||||
return model.__fields__[key].get_default() != value
|
||||
return model.model_fields[key].get_default() != value
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
@@ -161,16 +162,25 @@ class Serializable(BaseModel, ABC):
|
||||
For example, for the class `langchain.llms.openai.OpenAI`, the id is
|
||||
["langchain", "llms", "openai", "OpenAI"].
|
||||
"""
|
||||
return [*cls.get_lc_namespace(), cls.__name__]
|
||||
# Pydantic generics change the class name. So we need to do the following
|
||||
if (
|
||||
"origin" in cls.__pydantic_generic_metadata__
|
||||
and cls.__pydantic_generic_metadata__["origin"] is not None
|
||||
):
|
||||
original_name = cls.__pydantic_generic_metadata__["origin"].__name__
|
||||
else:
|
||||
original_name = cls.__name__
|
||||
return [*cls.get_lc_namespace(), original_name]
|
||||
|
||||
class Config:
|
||||
extra = "ignore"
|
||||
model_config = ConfigDict(
|
||||
extra="ignore",
|
||||
)
|
||||
|
||||
def __repr_args__(self) -> Any:
|
||||
return [
|
||||
(k, v)
|
||||
for k, v in super().__repr_args__()
|
||||
if (k not in self.__fields__ or try_neq_default(v, k, self))
|
||||
if (k not in self.model_fields or try_neq_default(v, k, self))
|
||||
]
|
||||
|
||||
def to_json(self) -> Union[SerializedConstructor, SerializedNotImplemented]:
|
||||
@@ -184,12 +194,15 @@ class Serializable(BaseModel, ABC):
|
||||
|
||||
secrets = dict()
|
||||
# Get latest values for kwargs if there is an attribute with same name
|
||||
lc_kwargs = {
|
||||
k: getattr(self, k, v)
|
||||
for k, v in self
|
||||
if not (self.__exclude_fields__ or {}).get(k, False) # type: ignore
|
||||
and _is_field_useful(self, k, v)
|
||||
}
|
||||
lc_kwargs = {}
|
||||
for k, v in self:
|
||||
if not _is_field_useful(self, k, v):
|
||||
continue
|
||||
# Do nothing if the field is excluded
|
||||
if k in self.model_fields and self.model_fields[k].exclude:
|
||||
continue
|
||||
|
||||
lc_kwargs[k] = getattr(self, k, v)
|
||||
|
||||
# Merge the lc_secrets and lc_attributes from every class in the MRO
|
||||
for cls in [None, *self.__class__.mro()]:
|
||||
@@ -221,8 +234,10 @@ class Serializable(BaseModel, ABC):
|
||||
# that are not present in the fields.
|
||||
for key in list(secrets):
|
||||
value = secrets[key]
|
||||
if key in this.__fields__:
|
||||
secrets[this.__fields__[key].alias] = value
|
||||
if key in this.model_fields:
|
||||
alias = this.model_fields[key].alias
|
||||
if alias is not None:
|
||||
secrets[alias] = value
|
||||
lc_kwargs.update(this.lc_attributes)
|
||||
|
||||
# include all secrets, even if not specified in kwargs
|
||||
@@ -244,6 +259,10 @@ class Serializable(BaseModel, ABC):
|
||||
def to_json_not_implemented(self) -> SerializedNotImplemented:
|
||||
return to_json_not_implemented(self)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
# TODO(0.3): Remove this override after confirming unit tests!
|
||||
return v1_repr(self)
|
||||
|
||||
|
||||
def _is_field_useful(inst: Serializable, key: str, value: Any) -> bool:
|
||||
"""Check if a field is useful as a constructor argument.
|
||||
@@ -259,10 +278,46 @@ def _is_field_useful(inst: Serializable, key: str, value: Any) -> bool:
|
||||
If the field is not required and the value is None, it is useful if the
|
||||
default value is different from the value.
|
||||
"""
|
||||
field = inst.__fields__.get(key)
|
||||
field = inst.model_fields.get(key)
|
||||
if not field:
|
||||
return False
|
||||
return field.required is True or value or field.get_default() != value
|
||||
|
||||
if field.is_required():
|
||||
return True
|
||||
|
||||
# Handle edge case: a value cannot be converted to a boolean (e.g. a
|
||||
# Pandas DataFrame).
|
||||
try:
|
||||
value_is_truthy = bool(value)
|
||||
except Exception as _:
|
||||
value_is_truthy = False
|
||||
|
||||
if value_is_truthy:
|
||||
return True
|
||||
|
||||
# Value is still falsy here!
|
||||
if field.default_factory is dict and isinstance(value, dict):
|
||||
return False
|
||||
|
||||
# Value is still falsy here!
|
||||
if field.default_factory is list and isinstance(value, list):
|
||||
return False
|
||||
|
||||
# Handle edge case: inequality of two objects does not evaluate to a bool (e.g. two
|
||||
# Pandas DataFrames).
|
||||
try:
|
||||
value_neq_default = bool(field.get_default() != value)
|
||||
except Exception as _:
|
||||
try:
|
||||
value_neq_default = all(field.get_default() != value)
|
||||
except Exception as _:
|
||||
try:
|
||||
value_neq_default = value is not field.default
|
||||
except Exception as _:
|
||||
value_neq_default = False
|
||||
|
||||
# If value is falsy and does not match the default
|
||||
return value_is_truthy or value_neq_default
|
||||
|
||||
|
||||
def _replace_secrets(
|
||||
|
||||
@@ -13,6 +13,8 @@ from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from langchain_core.load.serializable import Serializable
|
||||
from langchain_core.runnables import run_in_executor
|
||||
|
||||
@@ -47,8 +49,9 @@ class BaseMemory(Serializable, ABC):
|
||||
pass
|
||||
""" # noqa: E501
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import json
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
from pydantic import model_validator
|
||||
from typing_extensions import Self, TypedDict
|
||||
|
||||
from langchain_core.messages.base import (
|
||||
BaseMessage,
|
||||
@@ -24,7 +25,6 @@ from langchain_core.messages.tool import (
|
||||
from langchain_core.messages.tool import (
|
||||
tool_call_chunk as create_tool_call_chunk,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import root_validator
|
||||
from langchain_core.utils._merge import merge_dicts, merge_lists
|
||||
from langchain_core.utils.json import parse_partial_json
|
||||
|
||||
@@ -111,8 +111,9 @@ class AIMessage(BaseMessage):
|
||||
"invalid_tool_calls": self.invalid_tool_calls,
|
||||
}
|
||||
|
||||
@root_validator(pre=True)
|
||||
def _backwards_compat_tool_calls(cls, values: dict) -> dict:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _backwards_compat_tool_calls(cls, values: dict) -> Any:
|
||||
check_additional_kwargs = not any(
|
||||
values.get(k)
|
||||
for k in ("tool_calls", "invalid_tool_calls", "tool_call_chunks")
|
||||
@@ -204,7 +205,7 @@ class AIMessage(BaseMessage):
|
||||
return (base.strip() + "\n" + "\n".join(lines)).strip()
|
||||
|
||||
|
||||
AIMessage.update_forward_refs()
|
||||
AIMessage.model_rebuild()
|
||||
|
||||
|
||||
class AIMessageChunk(AIMessage, BaseMessageChunk):
|
||||
@@ -238,8 +239,8 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
|
||||
"invalid_tool_calls": self.invalid_tool_calls,
|
||||
}
|
||||
|
||||
@root_validator(pre=False, skip_on_failure=True)
|
||||
def init_tool_calls(cls, values: dict) -> dict:
|
||||
@model_validator(mode="after")
|
||||
def init_tool_calls(self) -> Self:
|
||||
"""Initialize tool calls from tool call chunks.
|
||||
|
||||
Args:
|
||||
@@ -251,35 +252,35 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
|
||||
Raises:
|
||||
ValueError: If the tool call chunks are malformed.
|
||||
"""
|
||||
if not values["tool_call_chunks"]:
|
||||
if values["tool_calls"]:
|
||||
values["tool_call_chunks"] = [
|
||||
if not self.tool_call_chunks:
|
||||
if self.tool_calls:
|
||||
self.tool_call_chunks = [
|
||||
create_tool_call_chunk(
|
||||
name=tc["name"],
|
||||
args=json.dumps(tc["args"]),
|
||||
id=tc["id"],
|
||||
index=None,
|
||||
)
|
||||
for tc in values["tool_calls"]
|
||||
for tc in self.tool_calls
|
||||
]
|
||||
if values["invalid_tool_calls"]:
|
||||
tool_call_chunks = values.get("tool_call_chunks", [])
|
||||
if self.invalid_tool_calls:
|
||||
tool_call_chunks = self.tool_call_chunks
|
||||
tool_call_chunks.extend(
|
||||
[
|
||||
create_tool_call_chunk(
|
||||
name=tc["name"], args=tc["args"], id=tc["id"], index=None
|
||||
)
|
||||
for tc in values["invalid_tool_calls"]
|
||||
for tc in self.invalid_tool_calls
|
||||
]
|
||||
)
|
||||
values["tool_call_chunks"] = tool_call_chunks
|
||||
self.tool_call_chunks = tool_call_chunks
|
||||
|
||||
return values
|
||||
return self
|
||||
tool_calls = []
|
||||
invalid_tool_calls = []
|
||||
for chunk in values["tool_call_chunks"]:
|
||||
for chunk in self.tool_call_chunks:
|
||||
try:
|
||||
args_ = parse_partial_json(chunk["args"]) if chunk["args"] != "" else {}
|
||||
args_ = parse_partial_json(chunk["args"]) if chunk["args"] != "" else {} # type: ignore[arg-type]
|
||||
if isinstance(args_, dict):
|
||||
tool_calls.append(
|
||||
create_tool_call(
|
||||
@@ -299,9 +300,9 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
|
||||
error=None,
|
||||
)
|
||||
)
|
||||
values["tool_calls"] = tool_calls
|
||||
values["invalid_tool_calls"] = invalid_tool_calls
|
||||
return values
|
||||
self.tool_calls = tool_calls
|
||||
self.invalid_tool_calls = invalid_tool_calls
|
||||
return self
|
||||
|
||||
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
|
||||
if isinstance(other, AIMessageChunk):
|
||||
|
||||
@@ -2,11 +2,13 @@ from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union, cast
|
||||
|
||||
from pydantic import ConfigDict, Extra, Field
|
||||
|
||||
from langchain_core.load.serializable import Serializable
|
||||
from langchain_core.pydantic_v1 import Extra, Field
|
||||
from langchain_core.utils import get_bolded_text
|
||||
from langchain_core.utils._merge import merge_dicts, merge_lists
|
||||
from langchain_core.utils.interactive_env import is_interactive_env
|
||||
from langchain_core.utils.pydantic import v1_repr
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.prompts.chat import ChatPromptTemplate
|
||||
@@ -51,8 +53,9 @@ class BaseMessage(Serializable):
|
||||
"""An optional unique identifier for the message. This should ideally be
|
||||
provided by the provider/model which created the message."""
|
||||
|
||||
class Config:
|
||||
extra = Extra.allow
|
||||
model_config = ConfigDict(
|
||||
extra=Extra.allow,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self, content: Union[str, List[Union[str, Dict]]], **kwargs: Any
|
||||
@@ -108,6 +111,10 @@ class BaseMessage(Serializable):
|
||||
def pretty_print(self) -> None:
|
||||
print(self.pretty_repr(html=is_interactive_env())) # noqa: T201
|
||||
|
||||
def __repr__(self) -> str:
|
||||
# TODO(0.3): Remove this override after confirming unit tests!
|
||||
return v1_repr(self)
|
||||
|
||||
|
||||
def merge_content(
|
||||
first_content: Union[str, List[Union[str, Dict]]],
|
||||
|
||||
@@ -25,7 +25,7 @@ class ChatMessage(BaseMessage):
|
||||
return ["langchain", "schema", "messages"]
|
||||
|
||||
|
||||
ChatMessage.update_forward_refs()
|
||||
ChatMessage.model_rebuild()
|
||||
|
||||
|
||||
class ChatMessageChunk(ChatMessage, BaseMessageChunk):
|
||||
|
||||
@@ -32,7 +32,7 @@ class FunctionMessage(BaseMessage):
|
||||
return ["langchain", "schema", "messages"]
|
||||
|
||||
|
||||
FunctionMessage.update_forward_refs()
|
||||
FunctionMessage.model_rebuild()
|
||||
|
||||
|
||||
class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
|
||||
|
||||
@@ -56,7 +56,7 @@ class HumanMessage(BaseMessage):
|
||||
super().__init__(content=content, **kwargs)
|
||||
|
||||
|
||||
HumanMessage.update_forward_refs()
|
||||
HumanMessage.model_rebuild()
|
||||
|
||||
|
||||
class HumanMessageChunk(HumanMessage, BaseMessageChunk):
|
||||
|
||||
@@ -33,4 +33,4 @@ class RemoveMessage(BaseMessage):
|
||||
return ["langchain", "schema", "messages"]
|
||||
|
||||
|
||||
RemoveMessage.update_forward_refs()
|
||||
RemoveMessage.model_rebuild()
|
||||
|
||||
@@ -50,7 +50,7 @@ class SystemMessage(BaseMessage):
|
||||
super().__init__(content=content, **kwargs)
|
||||
|
||||
|
||||
SystemMessage.update_forward_refs()
|
||||
SystemMessage.model_rebuild()
|
||||
|
||||
|
||||
class SystemMessageChunk(SystemMessage, BaseMessageChunk):
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import json
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
||||
|
||||
from pydantic import Field
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
from langchain_core.messages.base import BaseMessage, BaseMessageChunk, merge_content
|
||||
@@ -70,6 +71,11 @@ class ToolMessage(BaseMessage):
|
||||
.. versionadded:: 0.2.24
|
||||
"""
|
||||
|
||||
additional_kwargs: dict = Field(default_factory=dict, repr=False)
|
||||
"""Currently inherited from BaseMessage, but not used."""
|
||||
response_metadata: dict = Field(default_factory=dict, repr=False)
|
||||
"""Currently inherited from BaseMessage, but not used."""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object.
|
||||
@@ -88,7 +94,7 @@ class ToolMessage(BaseMessage):
|
||||
super().__init__(content=content, **kwargs)
|
||||
|
||||
|
||||
ToolMessage.update_forward_refs()
|
||||
ToolMessage.model_rebuild()
|
||||
|
||||
|
||||
class ToolMessageChunk(ToolMessage, BaseMessageChunk):
|
||||
|
||||
@@ -10,6 +10,7 @@ Some examples of what you can do with these functions include:
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import json
|
||||
from functools import partial
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
@@ -213,7 +214,23 @@ def _create_message_from_message_type(
     if id is not None:
         kwargs["id"] = id
     if tool_calls is not None:
-        kwargs["tool_calls"] = tool_calls
+        kwargs["tool_calls"] = []
+        for tool_call in tool_calls:
+            # Convert OpenAI-format tool call to LangChain format.
+            if "function" in tool_call:
+                args = tool_call["function"]["arguments"]
+                if isinstance(args, str):
+                    args = json.loads(args, strict=False)
+                kwargs["tool_calls"].append(
+                    {
+                        "name": tool_call["function"]["name"],
+                        "args": args,
+                        "id": tool_call["id"],
+                        "type": "tool_call",
+                    }
+                )
+            else:
+                kwargs["tool_calls"].append(tool_call)
     if message_type in ("human", "user"):
         message: BaseMessage = HumanMessage(content=content, **kwargs)
     elif message_type in ("ai", "assistant"):
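The conversion above accepts both shapes; for reference, the two tool-call layouts involved look roughly like this (values are illustrative):

# OpenAI-style tool call, as it may appear in an incoming message dict:
openai_tool_call = {
    "id": "call_123",
    "type": "function",
    "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
}

# Equivalent LangChain-format tool call produced by the loop above:
langchain_tool_call = {
    "name": "get_weather",
    "args": {"city": "Paris"},
    "id": "call_123",
    "type": "tool_call",
}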
@@ -271,11 +288,12 @@ def _convert_to_message(message: MessageLikeRepresentation) -> BaseMessage:
|
||||
msg_type = msg_kwargs.pop("role")
|
||||
except KeyError:
|
||||
msg_type = msg_kwargs.pop("type")
|
||||
msg_content = msg_kwargs.pop("content")
|
||||
except KeyError:
|
||||
# None msg content is not allowed
|
||||
msg_content = msg_kwargs.pop("content") or ""
|
||||
except KeyError as e:
|
||||
raise ValueError(
|
||||
f"Message dict must contain 'role' and 'content' keys, got {message}"
|
||||
)
|
||||
) from e
|
||||
_message = _create_message_from_message_type(
|
||||
msg_type, msg_content, **msg_kwargs
|
||||
)
|
||||
@@ -326,9 +344,7 @@ def _runnable_support(func: Callable) -> Callable:
|
||||
if messages is not None:
|
||||
return func(messages, **kwargs)
|
||||
else:
|
||||
return RunnableLambda(
|
||||
partial(func, **kwargs), name=getattr(func, "__name__")
|
||||
)
|
||||
return RunnableLambda(partial(func, **kwargs), name=func.__name__)
|
||||
|
||||
wrapped.__doc__ = func.__doc__
|
||||
return wrapped
|
||||
@@ -425,6 +441,8 @@ def filter_messages(
|
||||
@_runnable_support
|
||||
def merge_message_runs(
|
||||
messages: Union[Iterable[MessageLikeRepresentation], PromptValue],
|
||||
*,
|
||||
chunk_separator: str = "\n",
|
||||
) -> List[BaseMessage]:
|
||||
"""Merge consecutive Messages of the same type.
|
||||
|
||||
@@ -433,13 +451,16 @@ def merge_message_runs(
|
||||
|
||||
Args:
|
||||
messages: Sequence Message-like objects to merge.
|
||||
chunk_separator: Specify the string to be inserted between message chunks.
|
||||
Default is "\n".
|
||||
|
||||
Returns:
|
||||
List of BaseMessages with consecutive runs of message types merged into single
|
||||
messages. If two messages being merged both have string contents, the merged
|
||||
content is a concatenation of the two strings with a new-line separator. If at
|
||||
least one of the messages has a list of content blocks, the merged content is a
|
||||
list of content blocks.
|
||||
messages. By default, if two messages being merged both have string contents,
|
||||
the merged content is a concatenation of the two strings with a new-line separator.
|
||||
The separator inserted between message chunks can be controlled by specifying
|
||||
any string with ``chunk_separator``. If at least one of the messages has a list of
|
||||
content blocks, the merged content is a list of content blocks.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
@@ -509,11 +530,13 @@ def merge_message_runs(
|
||||
and last_chunk.content
|
||||
and curr_chunk.content
|
||||
):
|
||||
last_chunk.content += "\n"
|
||||
last_chunk.content += chunk_separator
|
||||
merged.append(_chunk_to_msg(last_chunk + curr_chunk))
|
||||
return merged
|
||||
|
||||
|
||||
# TODO: Update so validation errors (for token_counter, for example) are raised on
|
||||
# init not at runtime.
|
||||
@_runnable_support
|
||||
def trim_messages(
|
||||
messages: Union[Iterable[MessageLikeRepresentation], PromptValue],
|
||||
@@ -759,24 +782,30 @@ def trim_messages(
|
||||
AIMessage("This is a 4 token text. The full message is 10 tokens.", id="fourth"),
|
||||
]
|
||||
""" # noqa: E501
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
|
||||
if start_on and strategy == "first":
|
||||
raise ValueError
|
||||
if include_system and strategy == "first":
|
||||
raise ValueError
|
||||
messages = convert_to_messages(messages)
|
||||
if isinstance(token_counter, BaseLanguageModel):
|
||||
if hasattr(token_counter, "get_num_tokens_from_messages"):
|
||||
list_token_counter = token_counter.get_num_tokens_from_messages
|
||||
elif (
|
||||
list(inspect.signature(token_counter).parameters.values())[0].annotation
|
||||
is BaseMessage
|
||||
):
|
||||
elif callable(token_counter):
|
||||
if (
|
||||
list(inspect.signature(token_counter).parameters.values())[0].annotation
|
||||
is BaseMessage
|
||||
):
|
||||
|
||||
def list_token_counter(messages: Sequence[BaseMessage]) -> int:
|
||||
return sum(token_counter(msg) for msg in messages) # type: ignore[arg-type, misc]
|
||||
def list_token_counter(messages: Sequence[BaseMessage]) -> int:
|
||||
return sum(token_counter(msg) for msg in messages) # type: ignore[arg-type, misc]
|
||||
else:
|
||||
list_token_counter = token_counter # type: ignore[assignment]
|
||||
else:
|
||||
list_token_counter = token_counter # type: ignore[assignment]
|
||||
raise ValueError(
|
||||
f"'token_counter' expected to be a model that implements "
|
||||
f"'get_num_tokens_from_messages()' or a function. Received object of type "
|
||||
f"{type(token_counter)}."
|
||||
)
|
||||
|
||||
try:
|
||||
from langchain_text_splitters import TextSplitter
|
||||
|
||||
@@ -13,8 +13,6 @@ from typing import (
|
||||
Union,
|
||||
)
|
||||
|
||||
from typing_extensions import get_args
|
||||
|
||||
from langchain_core.language_models import LanguageModelOutput
|
||||
from langchain_core.messages import AnyMessage, BaseMessage
|
||||
from langchain_core.outputs import ChatGeneration, Generation
|
||||
@@ -166,10 +164,11 @@ class BaseOutputParser(
|
||||
Raises:
|
||||
TypeError: If the class doesn't have an inferable OutputType.
|
||||
"""
|
||||
for cls in self.__class__.__orig_bases__: # type: ignore[attr-defined]
|
||||
type_args = get_args(cls)
|
||||
if type_args and len(type_args) == 1:
|
||||
return type_args[0]
|
||||
for base in self.__class__.mro():
|
||||
if hasattr(base, "__pydantic_generic_metadata__"):
|
||||
metadata = base.__pydantic_generic_metadata__
|
||||
if "args" in metadata and len(metadata["args"]) > 0:
|
||||
return metadata["args"][0]
|
||||
|
||||
raise TypeError(
|
||||
f"Runnable {self.__class__.__name__} doesn't have an inferable OutputType. "
|
||||
|
||||
@@ -2,10 +2,11 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
from json import JSONDecodeError
|
||||
from typing import Any, List, Optional, Type, TypeVar, Union
|
||||
from typing import Annotated, Any, List, Optional, Type, TypeVar, Union
|
||||
|
||||
import jsonpatch # type: ignore[import]
|
||||
import pydantic # pydantic: ignore
|
||||
from pydantic import SkipValidation # pydantic: ignore
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.output_parsers.format_instructions import JSON_FORMAT_INSTRUCTIONS
|
||||
@@ -40,7 +41,7 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
|
||||
describing the difference between the previous and the current object.
|
||||
"""
|
||||
|
||||
pydantic_object: Optional[Type[TBaseModel]] = None # type: ignore
|
||||
pydantic_object: Annotated[Optional[Type[TBaseModel]], SkipValidation()] = None # type: ignore
|
||||
"""The Pydantic object to use for validation.
|
||||
If None, no validation is performed."""
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import re
|
||||
from abc import abstractmethod
|
||||
from collections import deque
|
||||
from typing import AsyncIterator, Deque, Iterator, List, TypeVar, Union
|
||||
from typing import Optional as Optional
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.output_parsers.transform import BaseTransformOutputParser
|
||||
@@ -122,6 +123,9 @@ class ListOutputParser(BaseTransformOutputParser[List[str]]):
|
||||
yield [part]
|
||||
|
||||
|
||||
ListOutputParser.model_rebuild()
|
||||
|
||||
|
||||
class CommaSeparatedListOutputParser(ListOutputParser):
|
||||
"""Parse the output of an LLM call to a comma-separated list."""
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ import json
|
||||
from typing import Any, Dict, List, Optional, Type, Union
|
||||
|
||||
import jsonpatch # type: ignore[import]
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.output_parsers import (
|
||||
@@ -11,7 +12,6 @@ from langchain_core.output_parsers import (
|
||||
)
|
||||
from langchain_core.output_parsers.json import parse_partial_json
|
||||
from langchain_core.outputs import ChatGeneration, Generation
|
||||
from langchain_core.pydantic_v1 import BaseModel, root_validator
|
||||
|
||||
|
||||
class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
|
||||
@@ -42,7 +42,9 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
|
||||
try:
|
||||
func_call = copy.deepcopy(message.additional_kwargs["function_call"])
|
||||
except KeyError as exc:
|
||||
raise OutputParserException(f"Could not parse function call: {exc}")
|
||||
raise OutputParserException(
|
||||
f"Could not parse function call: {exc}"
|
||||
) from exc
|
||||
|
||||
if self.args_only:
|
||||
return func_call["arguments"]
|
||||
@@ -100,7 +102,9 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
|
||||
if partial:
|
||||
return None
|
||||
else:
|
||||
raise OutputParserException(f"Could not parse function call: {exc}")
|
||||
raise OutputParserException(
|
||||
f"Could not parse function call: {exc}"
|
||||
) from exc
|
||||
try:
|
||||
if partial:
|
||||
try:
|
||||
@@ -126,7 +130,7 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
|
||||
except (json.JSONDecodeError, TypeError) as exc:
|
||||
raise OutputParserException(
|
||||
f"Could not parse function call data: {exc}"
|
||||
)
|
||||
) from exc
|
||||
else:
|
||||
try:
|
||||
return {
|
||||
@@ -138,7 +142,7 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
|
||||
except (json.JSONDecodeError, TypeError) as exc:
|
||||
raise OutputParserException(
|
||||
f"Could not parse function call data: {exc}"
|
||||
)
|
||||
) from exc
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
@@ -226,8 +230,9 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
|
||||
determine which schema to use.
|
||||
"""
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_schema(cls, values: Dict) -> Dict:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_schema(cls, values: Dict) -> Any:
|
||||
"""Validate the pydantic schema.
|
||||
|
||||
Args:
|
||||
@@ -263,11 +268,17 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
|
||||
"""
|
||||
_result = super().parse_result(result)
|
||||
if self.args_only:
|
||||
pydantic_args = self.pydantic_schema.parse_raw(_result) # type: ignore
|
||||
if hasattr(self.pydantic_schema, "model_validate_json"):
|
||||
pydantic_args = self.pydantic_schema.model_validate_json(_result)
|
||||
else:
|
||||
pydantic_args = self.pydantic_schema.parse_raw(_result) # type: ignore
|
||||
else:
|
||||
fn_name = _result["name"]
|
||||
_args = _result["arguments"]
|
||||
pydantic_args = self.pydantic_schema[fn_name].parse_raw(_args) # type: ignore
|
||||
if hasattr(self.pydantic_schema, "model_validate_json"):
|
||||
pydantic_args = self.pydantic_schema[fn_name].model_validate_json(_args) # type: ignore
|
||||
else:
|
||||
pydantic_args = self.pydantic_schema[fn_name].parse_raw(_args) # type: ignore
|
||||
return pydantic_args
|
||||
|
||||
|
||||
|
||||
@@ -1,19 +1,16 @@
|
||||
import copy
|
||||
import json
|
||||
from json import JSONDecodeError
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Annotated, Any, Dict, List, Optional
|
||||
|
||||
from pydantic import SkipValidation, ValidationError # pydantic: ignore
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.messages import AIMessage, InvalidToolCall
|
||||
from langchain_core.messages.tool import (
|
||||
invalid_tool_call,
|
||||
)
|
||||
from langchain_core.messages.tool import (
|
||||
tool_call as create_tool_call,
|
||||
)
|
||||
from langchain_core.messages.tool import invalid_tool_call
|
||||
from langchain_core.messages.tool import tool_call as create_tool_call
|
||||
from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser
|
||||
from langchain_core.outputs import ChatGeneration, Generation
|
||||
from langchain_core.pydantic_v1 import ValidationError
|
||||
from langchain_core.utils.json import parse_partial_json
|
||||
from langchain_core.utils.pydantic import TypeBaseModel
|
||||
|
||||
@@ -59,7 +56,7 @@ def parse_tool_call(
|
||||
f"Function {raw_tool_call['function']['name']} arguments:\n\n"
|
||||
f"{raw_tool_call['function']['arguments']}\n\nare not valid JSON. "
|
||||
f"Received JSONDecodeError {e}"
|
||||
)
|
||||
) from e
|
||||
parsed = {
|
||||
"name": raw_tool_call["function"]["name"] or "",
|
||||
"args": function_args or {},
|
||||
@@ -256,7 +253,7 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
|
||||
class PydanticToolsParser(JsonOutputToolsParser):
|
||||
"""Parse tools from OpenAI response."""
|
||||
|
||||
tools: List[TypeBaseModel]
|
||||
tools: Annotated[List[TypeBaseModel], SkipValidation()]
|
||||
"""The tools to parse."""
|
||||
|
||||
# TODO: Support more granular streaming of objects. Currently only streams once all
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import json
|
||||
from typing import Generic, List, Type
|
||||
from typing import Annotated, Generic, List, Optional, Type
|
||||
|
||||
import pydantic # pydantic: ignore
|
||||
from pydantic import SkipValidation # pydantic: ignore
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.output_parsers import JsonOutputParser
|
||||
@@ -16,7 +17,7 @@ from langchain_core.utils.pydantic import (
|
||||
class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
|
||||
"""Parse an output using a pydantic model."""
|
||||
|
||||
pydantic_object: Type[TBaseModel] # type: ignore
|
||||
pydantic_object: Annotated[Type[TBaseModel], SkipValidation()] # type: ignore
|
||||
"""The pydantic model to parse."""
|
||||
|
||||
def _parse_obj(self, obj: dict) -> TBaseModel:
|
||||
@@ -32,12 +33,12 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
|
||||
{self.pydantic_object.__class__}"
|
||||
)
|
||||
except (pydantic.ValidationError, pydantic.v1.ValidationError) as e:
|
||||
raise self._parser_exception(e, obj)
|
||||
raise self._parser_exception(e, obj) from e
|
||||
else: # pydantic v1
|
||||
try:
|
||||
return self.pydantic_object.parse_obj(obj)
|
||||
except pydantic.ValidationError as e:
|
||||
raise self._parser_exception(e, obj)
|
||||
raise self._parser_exception(e, obj) from e
|
||||
|
||||
def _parser_exception(
|
||||
self, e: Exception, json_object: dict
|
||||
@@ -49,7 +50,7 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
|
||||
|
||||
def parse_result(
|
||||
self, result: List[Generation], *, partial: bool = False
|
||||
) -> TBaseModel:
|
||||
) -> Optional[TBaseModel]:
|
||||
"""Parse the result of an LLM call to a pydantic object.
|
||||
|
||||
Args:
|
||||
@@ -62,8 +63,13 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
|
||||
Returns:
|
||||
The parsed pydantic object.
|
||||
"""
|
||||
json_object = super().parse_result(result)
|
||||
return self._parse_obj(json_object)
|
||||
try:
|
||||
json_object = super().parse_result(result)
|
||||
return self._parse_obj(json_object)
|
||||
except OutputParserException as e:
|
||||
if partial:
|
||||
return None
|
||||
raise e
|
||||
|
||||
def parse(self, text: str) -> TBaseModel:
|
||||
"""Parse the output of an LLM call to a pydantic object.
|
||||
@@ -92,7 +98,7 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
|
||||
if "type" in reduced_schema:
|
||||
del reduced_schema["type"]
|
||||
# Ensure json in context is well-formed with double quotes.
|
||||
schema_str = json.dumps(reduced_schema)
|
||||
schema_str = json.dumps(reduced_schema, ensure_ascii=False)
|
||||
|
||||
return _PYDANTIC_FORMAT_INSTRUCTIONS.format(schema=schema_str)
|
||||
|
||||
@@ -106,6 +112,9 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
|
||||
return self.pydantic_object
|
||||
|
||||
|
||||
PydanticOutputParser.model_rebuild()
|
||||
|
||||
|
||||
_PYDANTIC_FORMAT_INSTRUCTIONS = """The output should be formatted as a JSON instance that conforms to the JSON schema below.
|
||||
|
||||
As an example, for the schema {{"properties": {{"foo": {{"title": "Foo", "description": "a list of strings", "type": "array", "items": {{"type": "string"}}}}}}, "required": ["foo"]}}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import List
|
||||
from typing import Optional as Optional
|
||||
|
||||
from langchain_core.output_parsers.transform import BaseTransformOutputParser
|
||||
|
||||
@@ -24,3 +25,6 @@ class StrOutputParser(BaseTransformOutputParser[str]):
|
||||
def parse(self, text: str) -> str:
|
||||
"""Returns the input text with no changes."""
|
||||
return text
|
||||
|
||||
|
||||
StrOutputParser.model_rebuild()
|
||||
|
||||
@@ -46,12 +46,12 @@ class _StreamingParser:
|
||||
if parser == "defusedxml":
|
||||
try:
|
||||
from defusedxml import ElementTree as DET # type: ignore
|
||||
except ImportError:
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"defusedxml is not installed. "
|
||||
"Please install it to use the defusedxml parser."
|
||||
"You can install it with `pip install defusedxml` "
|
||||
)
|
||||
) from e
|
||||
_parser = DET.DefusedXMLParser(target=TreeBuilder())
|
||||
else:
|
||||
_parser = None
|
||||
@@ -189,13 +189,13 @@ class XMLOutputParser(BaseTransformOutputParser):
|
||||
if self.parser == "defusedxml":
|
||||
try:
|
||||
from defusedxml import ElementTree as DET # type: ignore
|
||||
except ImportError:
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"defusedxml is not installed. "
|
||||
"Please install it to use the defusedxml parser."
|
||||
"You can install it with `pip install defusedxml`"
|
||||
"See https://github.com/tiran/defusedxml for more details"
|
||||
)
|
||||
) from e
|
||||
_ET = DET # Use the defusedxml parser
|
||||
else:
|
||||
_ET = ET # Use the standard library parser
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Literal, Union
|
||||
from typing import List, Literal, Union
|
||||
|
||||
from pydantic import model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from langchain_core.messages import BaseMessage, BaseMessageChunk
|
||||
from langchain_core.outputs.generation import Generation
|
||||
from langchain_core.pydantic_v1 import root_validator
|
||||
from langchain_core.utils._merge import merge_dicts
|
||||
|
||||
|
||||
@@ -30,8 +32,8 @@ class ChatGeneration(Generation):
|
||||
type: Literal["ChatGeneration"] = "ChatGeneration" # type: ignore[assignment]
|
||||
"""Type is used exclusively for serialization purposes."""
|
||||
|
||||
@root_validator(pre=False, skip_on_failure=True)
|
||||
def set_text(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
@model_validator(mode="after")
|
||||
def set_text(self) -> Self:
|
||||
"""Set the text attribute to be the contents of the message.
|
||||
|
||||
Args:
|
||||
@@ -45,12 +47,12 @@ class ChatGeneration(Generation):
|
||||
"""
|
||||
try:
|
||||
text = ""
|
||||
if isinstance(values["message"].content, str):
|
||||
text = values["message"].content
|
||||
if isinstance(self.message.content, str):
|
||||
text = self.message.content
|
||||
# HACK: Assumes text in content blocks in OpenAI format.
|
||||
# Uses first text block.
|
||||
elif isinstance(values["message"].content, list):
|
||||
for block in values["message"].content:
|
||||
elif isinstance(self.message.content, list):
|
||||
for block in self.message.content:
|
||||
if isinstance(block, str):
|
||||
text = block
|
||||
break
|
||||
@@ -61,10 +63,10 @@ class ChatGeneration(Generation):
|
||||
pass
|
||||
else:
|
||||
pass
|
||||
values["text"] = text
|
||||
self.text = text
|
||||
except (KeyError, AttributeError) as e:
|
||||
raise ValueError("Error while initializing ChatGeneration") from e
|
||||
return values
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain_core.outputs.chat_generation import ChatGeneration
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
|
||||
|
||||
class ChatResult(BaseModel):
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from langchain_core.outputs.generation import Generation
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain_core.outputs.chat_generation import ChatGeneration, ChatGenerationChunk
|
||||
from langchain_core.outputs.generation import Generation, GenerationChunk
|
||||
from langchain_core.outputs.run_info import RunInfo
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
|
||||
|
||||
class LLMResult(BaseModel):
|
||||
@@ -16,7 +18,9 @@ class LLMResult(BaseModel):
|
||||
wants to return.
|
||||
"""
|
||||
|
||||
generations: List[List[Generation]]
|
||||
generations: List[
|
||||
List[Union[Generation, ChatGeneration, GenerationChunk, ChatGenerationChunk]]
|
||||
]
|
||||
"""Generated outputs.
|
||||
|
||||
The first dimension of the list represents completions for different input
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class RunInfo(BaseModel):
|
||||
|
||||
@@ -18,6 +18,8 @@ from typing import (
|
||||
)
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from langchain_core.output_parsers.base import BaseOutputParser
|
||||
from langchain_core.prompt_values import (
|
||||
@@ -25,7 +27,6 @@ from langchain_core.prompt_values import (
|
||||
PromptValue,
|
||||
StringPromptValue,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
|
||||
from langchain_core.runnables import RunnableConfig, RunnableSerializable
|
||||
from langchain_core.runnables.config import ensure_config
|
||||
from langchain_core.runnables.utils import create_model
|
||||
@@ -64,28 +65,26 @@ class BasePromptTemplate(
|
||||
tags: Optional[List[str]] = None
|
||||
"""Tags to be used for tracing."""
|
||||
|
||||
@root_validator(pre=False, skip_on_failure=True)
|
||||
def validate_variable_names(cls, values: Dict) -> Dict:
|
||||
@model_validator(mode="after")
|
||||
def validate_variable_names(self) -> Self:
|
||||
"""Validate variable names do not include restricted names."""
|
||||
if "stop" in values["input_variables"]:
|
||||
if "stop" in self.input_variables:
|
||||
raise ValueError(
|
||||
"Cannot have an input variable named 'stop', as it is used internally,"
|
||||
" please rename."
|
||||
)
|
||||
if "stop" in values["partial_variables"]:
|
||||
if "stop" in self.partial_variables:
|
||||
raise ValueError(
|
||||
"Cannot have an partial variable named 'stop', as it is used "
|
||||
"internally, please rename."
|
||||
)
|
||||
|
||||
overall = set(values["input_variables"]).intersection(
|
||||
values["partial_variables"]
|
||||
)
|
||||
overall = set(self.input_variables).intersection(self.partial_variables)
|
||||
if overall:
|
||||
raise ValueError(
|
||||
f"Found overlapping input and partial variables: {overall}"
|
||||
)
|
||||
return values
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
@@ -99,8 +98,9 @@ class BasePromptTemplate(
|
||||
Returns True."""
|
||||
return True
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
@property
|
||||
def OutputType(self) -> Any:
|
||||
@@ -129,7 +129,7 @@ class BasePromptTemplate(
             "PromptInput", **{**required_input_variables, **optional_input_variables}
         )

-    def _validate_input(self, inner_input: Dict) -> Dict:
+    def _validate_input(self, inner_input: Any) -> Dict:
         if not isinstance(inner_input, dict):
             if len(self.input_variables) == 1:
                 var_name = self.input_variables[0]
@@ -142,11 +142,18 @@ class BasePromptTemplate(
                 )
         missing = set(self.input_variables).difference(inner_input)
         if missing:
-            raise KeyError(
+            msg = (
                 f"Input to {self.__class__.__name__} is missing variables {missing}. "
                 f" Expected: {self.input_variables}"
                 f" Received: {list(inner_input.keys())}"
             )
+            example_key = missing.pop()
+            msg += (
+                f"\nNote: if you intended {{{example_key}}} to be part of the string"
+                " and not a variable, please escape it with double curly braces like: "
+                f"'{{{{{example_key}}}}}'."
+            )
+            raise KeyError(msg)
         return inner_input
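The improved error message points at a common pitfall with f-string templates; a quick illustration of the escaping it suggests:

from langchain_core.prompts import PromptTemplate

# "{name}" is a variable; "{{placeholder}}" renders as the literal "{placeholder}".
prompt = PromptTemplate.from_template("Hello {name}, this is a {{placeholder}}.")
print(prompt.format(name="Ada"))
# -> Hello Ada, this is a {placeholder}.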
|
||||
def _format_prompt_with_error_handling(self, inner_input: Dict) -> PromptValue:
|
||||
|
||||
@@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
@@ -21,6 +22,13 @@ from typing import (
|
||||
overload,
|
||||
)
|
||||
|
||||
from pydantic import (
|
||||
Field,
|
||||
PositiveInt,
|
||||
SkipValidation,
|
||||
model_validator,
|
||||
)
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.load import Serializable
|
||||
from langchain_core.messages import (
|
||||
@@ -38,7 +46,6 @@ from langchain_core.prompts.base import BasePromptTemplate
|
||||
from langchain_core.prompts.image import ImagePromptTemplate
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.prompts.string import StringPromptTemplate, get_template_variables
|
||||
from langchain_core.pydantic_v1 import Field, PositiveInt, root_validator
|
||||
from langchain_core.utils import get_colored_text
|
||||
from langchain_core.utils.interactive_env import is_interactive_env
|
||||
|
||||
@@ -207,8 +214,14 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "prompts", "chat"]
|
||||
|
||||
def __init__(self, variable_name: str, *, optional: bool = False, **kwargs: Any):
|
||||
super().__init__(variable_name=variable_name, optional=optional, **kwargs)
|
||||
def __init__(
|
||||
self, variable_name: str, *, optional: bool = False, **kwargs: Any
|
||||
) -> None:
|
||||
# mypy can't detect the init which is defined in the parent class
|
||||
# b/c these are BaseModel classes.
|
||||
super().__init__( # type: ignore
|
||||
variable_name=variable_name, optional=optional, **kwargs
|
||||
)
|
||||
|
||||
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
"""Format messages from kwargs.
|
||||
@@ -551,13 +564,13 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
|
||||
input_variables=input_variables, template=img_template
|
||||
)
|
||||
else:
|
||||
raise ValueError()
|
||||
raise ValueError(f"Invalid image template: {tmpl}")
|
||||
prompt.append(img_template_obj)
|
||||
else:
|
||||
raise ValueError()
|
||||
raise ValueError(f"Invalid template: {tmpl}")
|
||||
return cls(prompt=prompt, **kwargs)
|
||||
else:
|
||||
raise ValueError()
|
||||
raise ValueError(f"Invalid template: {template}")
|
||||
|
||||
@classmethod
|
||||
def from_template_file(
|
||||
@@ -576,7 +589,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
|
||||
Returns:
|
||||
A new instance of this class.
|
||||
"""
|
||||
with open(str(template_file), "r") as f:
|
||||
with open(str(template_file)) as f:
|
||||
template = f.read()
|
||||
return cls.from_template(template, input_variables=input_variables, **kwargs)
|
||||
|
||||
@@ -922,7 +935,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
messages: List[MessageLike]
|
||||
messages: Annotated[List[MessageLike], SkipValidation]
|
||||
"""List of messages consisting of either message prompt templates or messages."""
|
||||
validate_template: bool = False
|
||||
"""Whether or not to try validating the template."""
|
||||
@@ -1038,8 +1051,9 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported operand type for +: {type(other)}")
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_input_variables(cls, values: dict) -> dict:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_input_variables(cls, values: dict) -> Any:
|
||||
"""Validate input variables.
|
||||
|
||||
If input_variables is not set, it will be set to the union of
|
||||
@@ -1177,7 +1191,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
|
||||
A message can be represented using the following formats:
|
||||
(1) BaseMessagePromptTemplate, (2) BaseMessage, (3) 2-tuple of
|
||||
(message type, template); e.g., ("human", "{user_input}"),
|
||||
(4) 2-tuple of (message class, template), (4) a string which is
|
||||
(4) 2-tuple of (message class, template), (5) a string which is
|
||||
shorthand for ("human", template); e.g., "{user_input}".
|
||||
template_format: format of the template. Defaults to "f-string".
|
||||
|
||||
|
||||
@@ -5,6 +5,15 @@ from __future__ import annotations
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Extra,
|
||||
Field,
|
||||
model_validator,
|
||||
)
|
||||
from typing_extensions import Self
|
||||
|
||||
from langchain_core.example_selectors import BaseExampleSelector
|
||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||
from langchain_core.prompts.chat import (
|
||||
@@ -18,7 +27,6 @@ from langchain_core.prompts.string import (
|
||||
check_valid_template,
|
||||
get_template_variables,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
|
||||
|
||||
|
||||
class _FewShotPromptTemplateMixin(BaseModel):
|
||||
@@ -32,12 +40,14 @@ class _FewShotPromptTemplateMixin(BaseModel):
|
||||
"""ExampleSelector to choose the examples to format into the prompt.
|
||||
Either this or examples should be provided."""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
extra = Extra.forbid
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
extra=Extra.forbid,
|
||||
)
|
||||
|
||||
@root_validator(pre=True)
|
||||
def check_examples_and_selector(cls, values: Dict) -> Dict:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_examples_and_selector(cls, values: Dict) -> Any:
|
||||
"""Check that one and only one of examples/example_selector are provided.
|
||||
|
||||
Args:
|
||||
@@ -139,28 +149,29 @@ class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate):
|
||||
kwargs["input_variables"] = kwargs["example_prompt"].input_variables
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@root_validator(pre=False, skip_on_failure=True)
|
||||
def template_is_valid(cls, values: Dict) -> Dict:
|
||||
@model_validator(mode="after")
|
||||
def template_is_valid(self) -> Self:
|
||||
"""Check that prefix, suffix, and input variables are consistent."""
|
||||
if values["validate_template"]:
|
||||
if self.validate_template:
|
||||
check_valid_template(
|
||||
values["prefix"] + values["suffix"],
|
||||
values["template_format"],
|
||||
values["input_variables"] + list(values["partial_variables"]),
|
||||
self.prefix + self.suffix,
|
||||
self.template_format,
|
||||
self.input_variables + list(self.partial_variables),
|
||||
)
|
||||
elif values.get("template_format"):
|
||||
values["input_variables"] = [
|
||||
elif self.template_format or None:
|
||||
self.input_variables = [
|
||||
var
|
||||
for var in get_template_variables(
|
||||
values["prefix"] + values["suffix"], values["template_format"]
|
||||
self.prefix + self.suffix, self.template_format
|
||||
)
|
||||
if var not in values["partial_variables"]
|
||||
if var not in self.partial_variables
|
||||
]
|
||||
return values
|
||||
return self
|
||||
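`@root_validator(pre=False, skip_on_failure=True)` maps onto `@model_validator(mode="after")`: the validator now runs on a constructed instance, takes `self` instead of a `values` dict, and returns `Self`, which is why the body above switches from `values["..."]` lookups to attribute access. A hedged sketch of the shape (names are illustrative):

from typing import List

from pydantic import BaseModel, model_validator
from typing_extensions import Self


class FewShotLike(BaseModel):
    prefix: str = ""
    suffix: str = ""
    input_variables: List[str] = []

    @model_validator(mode="after")
    def check_consistent(self) -> Self:
        # Runs after field validation, so attribute access replaces values["..."].
        for var in self.input_variables:
            if "{" + var + "}" not in self.prefix + self.suffix:
                raise ValueError(f"variable {var!r} not present in prefix/suffix")
        return self


FewShotLike(suffix="Question: {q}", input_variables=["q"])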
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
extra = Extra.forbid
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
extra=Extra.forbid,
|
||||
)
|
||||
|
||||
def format(self, **kwargs: Any) -> str:
|
||||
"""Format the prompt with inputs generating a string.
|
||||
@@ -365,9 +376,10 @@ class FewShotChatMessagePromptTemplate(
|
||||
"""Return whether or not the class is serializable."""
|
||||
return False
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
extra = Extra.forbid
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
extra=Extra.forbid,
|
||||
)
|
||||
|
||||
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
"""Format kwargs into a list of messages.
|
||||
|
||||
@@ -3,12 +3,14 @@
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from pydantic import ConfigDict, Extra, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.prompts.string import (
|
||||
DEFAULT_FORMATTER_MAPPING,
|
||||
StringPromptTemplate,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import Extra, root_validator
|
||||
|
||||
|
||||
class FewShotPromptWithTemplates(StringPromptTemplate):
|
||||
@@ -45,8 +47,9 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "prompts", "few_shot_with_templates"]
|
||||
|
||||
@root_validator(pre=True)
|
||||
def check_examples_and_selector(cls, values: Dict) -> Dict:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_examples_and_selector(cls, values: Dict) -> Any:
|
||||
"""Check that one and only one of examples/example_selector are provided."""
|
||||
examples = values.get("examples", None)
|
||||
example_selector = values.get("example_selector", None)
|
||||
@@ -62,15 +65,15 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
|
||||
|
||||
return values
|
||||
|
||||
@root_validator(pre=False, skip_on_failure=True)
|
||||
def template_is_valid(cls, values: Dict) -> Dict:
|
||||
@model_validator(mode="after")
|
||||
def template_is_valid(self) -> Self:
|
||||
"""Check that prefix, suffix, and input variables are consistent."""
|
||||
if values["validate_template"]:
|
||||
input_variables = values["input_variables"]
|
||||
expected_input_variables = set(values["suffix"].input_variables)
|
||||
expected_input_variables |= set(values["partial_variables"])
|
||||
if values["prefix"] is not None:
|
||||
expected_input_variables |= set(values["prefix"].input_variables)
|
||||
if self.validate_template:
|
||||
input_variables = self.input_variables
|
||||
expected_input_variables = set(self.suffix.input_variables)
|
||||
expected_input_variables |= set(self.partial_variables)
|
||||
if self.prefix is not None:
|
||||
expected_input_variables |= set(self.prefix.input_variables)
|
||||
missing_vars = expected_input_variables.difference(input_variables)
|
||||
if missing_vars:
|
||||
raise ValueError(
|
||||
@@ -78,16 +81,17 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
|
||||
f"prefix/suffix expected {expected_input_variables}"
|
||||
)
|
||||
else:
|
||||
values["input_variables"] = sorted(
|
||||
set(values["suffix"].input_variables)
|
||||
| set(values["prefix"].input_variables if values["prefix"] else [])
|
||||
- set(values["partial_variables"])
|
||||
self.input_variables = sorted(
|
||||
set(self.suffix.input_variables)
|
||||
| set(self.prefix.input_variables if self.prefix else [])
|
||||
- set(self.partial_variables)
|
||||
)
|
||||
return values
|
||||
return self
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
extra = Extra.forbid
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
extra=Extra.forbid,
|
||||
)
|
||||
|
||||
def _get_examples(self, **kwargs: Any) -> List[dict]:
|
||||
if self.examples is not None:
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from typing import Any, List
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from langchain_core.prompt_values import ImagePromptValue, ImageURL, PromptValue
|
||||
from langchain_core.prompts.base import BasePromptTemplate
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
from langchain_core.runnables import run_in_executor
|
||||
from langchain_core.utils import image as image_utils
|
||||
|
||||
|
||||
@@ -173,7 +173,7 @@ def _load_prompt_from_file(
|
||||
with open(file_path, encoding=encoding) as f:
|
||||
config = json.load(f)
|
||||
elif file_path.suffix.endswith((".yaml", ".yml")):
|
||||
with open(file_path, mode="r", encoding=encoding) as f:
|
||||
with open(file_path, encoding=encoding) as f:
|
||||
config = yaml.safe_load(f)
|
||||
else:
|
||||
raise ValueError(f"Got unsupported file type {file_path.suffix}")
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
from typing import Any, Dict, List, Tuple
|
||||
from typing import Optional as Optional
|
||||
|
||||
from pydantic import model_validator
|
||||
|
||||
from langchain_core.prompt_values import PromptValue
|
||||
from langchain_core.prompts.base import BasePromptTemplate
|
||||
from langchain_core.prompts.chat import BaseChatPromptTemplate
|
||||
from langchain_core.pydantic_v1 import root_validator
|
||||
|
||||
|
||||
def _get_inputs(inputs: dict, input_variables: List[str]) -> dict:
|
||||
@@ -34,8 +36,9 @@ class PipelinePromptTemplate(BasePromptTemplate):
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "prompts", "pipeline"]
|
||||
|
||||
@root_validator(pre=True)
|
||||
def get_input_variables(cls, values: Dict) -> Dict:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def get_input_variables(cls, values: Dict) -> Any:
|
||||
"""Get input variables."""
|
||||
created_variables = set()
|
||||
all_variables = set()
|
||||
@@ -106,3 +109,6 @@ class PipelinePromptTemplate(BasePromptTemplate):
|
||||
@property
|
||||
def _prompt_type(self) -> str:
|
||||
raise ValueError
|
||||
|
||||
|
||||
PipelinePromptTemplate.model_rebuild()
|
||||
|
||||
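`PipelinePromptTemplate.model_rebuild()` is the pydantic v2 counterpart of `update_forward_refs()` (the same swap appears later for `RunnableBindingBase`): it forces re-evaluation of postponed annotations once the referenced names exist. A minimal sketch on an unrelated, self-referencing model:

from typing import List, Optional

from pydantic import BaseModel


class Node(BaseModel):
    value: int
    children: List["Node"] = []     # forward reference to a not-yet-complete class
    parent: Optional["Node"] = None


# Pydantic may defer schema building for self-referencing models; calling
# model_rebuild() resolves the remaining forward references explicitly.
Node.model_rebuild()

root = Node(value=1, children=[Node(value=2)])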
@@ -6,6 +6,8 @@ import warnings
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from langchain_core.prompts.string import (
|
||||
DEFAULT_FORMATTER_MAPPING,
|
||||
StringPromptTemplate,
|
||||
@@ -13,7 +15,6 @@ from langchain_core.prompts.string import (
|
||||
get_template_variables,
|
||||
mustache_schema,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import BaseModel, root_validator
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
|
||||
@@ -73,8 +74,9 @@ class PromptTemplate(StringPromptTemplate):
|
||||
validate_template: bool = False
|
||||
"""Whether or not to try validating the template."""
|
||||
|
||||
@root_validator(pre=True)
|
||||
def pre_init_validation(cls, values: Dict) -> Dict:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def pre_init_validation(cls, values: Dict) -> Any:
|
||||
"""Check that template and input variables are consistent."""
|
||||
if values.get("template") is None:
|
||||
# Will let pydantic fail with a ValidationError if template
|
||||
@@ -231,11 +233,13 @@ class PromptTemplate(StringPromptTemplate):
|
||||
Returns:
|
||||
The prompt loaded from the file.
|
||||
"""
|
||||
with open(str(template_file), "r", encoding=encoding) as f:
|
||||
with open(str(template_file), encoding=encoding) as f:
|
||||
template = f.read()
|
||||
if input_variables:
|
||||
warnings.warn(
|
||||
"`input_variables' is deprecated and ignored.", DeprecationWarning
|
||||
"`input_variables' is deprecated and ignored.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return cls.from_template(template=template, **kwargs)
|
||||
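The deprecation warnings in this and later hunks gain an explicit `stacklevel` so the warning is attributed to the caller's line rather than to library internals (ruff's bugbear rule B028 flags a missing `stacklevel`). A toy illustration of the effect:

import warnings


def deprecated_helper() -> None:
    # stacklevel=2 makes the reported filename/line number point at the code
    # that called deprecated_helper(), not at this warnings.warn() line.
    warnings.warn("deprecated_helper is deprecated", DeprecationWarning, stacklevel=2)


def user_code() -> None:
    deprecated_helper()  # the warning will be attributed to this line


if __name__ == "__main__":
    warnings.simplefilter("always")
    user_code()

The larger `stacklevel` values elsewhere in the diff (4, 7) follow the same idea: skip however many internal frames sit between the warning and user code.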
|
||||
|
||||
@@ -7,10 +7,11 @@ from abc import ABC
|
||||
from string import Formatter
|
||||
from typing import Any, Callable, Dict, List, Set, Tuple, Type
|
||||
|
||||
from pydantic import BaseModel, create_model
|
||||
|
||||
import langchain_core.utils.mustache as mustache
|
||||
from langchain_core.prompt_values import PromptValue, StringPromptValue
|
||||
from langchain_core.prompts.base import BasePromptTemplate
|
||||
from langchain_core.pydantic_v1 import BaseModel, create_model
|
||||
from langchain_core.utils import get_colored_text
|
||||
from langchain_core.utils.formatting import formatter
|
||||
from langchain_core.utils.interactive_env import is_interactive_env
|
||||
@@ -40,14 +41,14 @@ def jinja2_formatter(template: str, **kwargs: Any) -> str:
|
||||
"""
|
||||
try:
|
||||
from jinja2.sandbox import SandboxedEnvironment
|
||||
except ImportError:
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"jinja2 not installed, which is needed to use the jinja2_formatter. "
|
||||
"Please install it with `pip install jinja2`."
|
||||
"Please be cautious when using jinja2 templates. "
|
||||
"Do not expand jinja2 templates using unverified or user-controlled "
|
||||
"inputs as that can result in arbitrary Python code execution."
|
||||
)
|
||||
) from e
|
||||
|
||||
# This uses a sandboxed environment to prevent arbitrary code execution.
|
||||
# Jinja2 uses an opt-out rather than opt-in approach for sand-boxing.
|
||||
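Re-raising with `raise ... from e` instead of a bare `raise` inside `except ImportError:` preserves the original exception as `__cause__`, which ruff's B904 rule enforces. A small sketch of the pattern:

def require_jinja2() -> None:
    try:
        import jinja2  # noqa: F401
    except ImportError as e:
        # Chaining keeps the original ImportError attached as __cause__, so the
        # traceback reads "The above exception was the direct cause of ...".
        raise ImportError(
            "jinja2 not installed. Please install it with `pip install jinja2`."
        ) from e

Without `from e`, Python would instead print the vaguer "During handling of the above exception, another exception occurred" context.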
@@ -81,17 +82,17 @@ def validate_jinja2(template: str, input_variables: List[str]) -> None:
|
||||
warning_message += f"Extra variables: {extra_variables}"
|
||||
|
||||
if warning_message:
|
||||
warnings.warn(warning_message.strip())
|
||||
warnings.warn(warning_message.strip(), stacklevel=7)
|
||||
|
||||
|
||||
def _get_jinja2_variables_from_template(template: str) -> Set[str]:
|
||||
try:
|
||||
from jinja2 import Environment, meta
|
||||
except ImportError:
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"jinja2 not installed, which is needed to use the jinja2_formatter. "
|
||||
"Please install it with `pip install jinja2`."
|
||||
)
|
||||
) from e
|
||||
env = Environment()
|
||||
ast = env.parse(template)
|
||||
variables = meta.find_undeclared_variables(ast)
|
||||
|
||||
@@ -7,28 +7,25 @@ from typing import (
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from langchain_core._api.beta_decorator import beta
|
||||
from langchain_core.language_models.base import BaseLanguageModel
|
||||
from langchain_core.prompts.chat import (
|
||||
BaseChatPromptTemplate,
|
||||
BaseMessagePromptTemplate,
|
||||
ChatPromptTemplate,
|
||||
MessageLikeRepresentation,
|
||||
MessagesPlaceholder,
|
||||
_convert_to_message,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.runnables.base import (
|
||||
Other,
|
||||
Runnable,
|
||||
RunnableSequence,
|
||||
RunnableSerializable,
|
||||
)
|
||||
from langchain_core.utils import get_pydantic_field_names
|
||||
|
||||
|
||||
@beta()
|
||||
@@ -37,6 +34,26 @@ class StructuredPrompt(ChatPromptTemplate):
|
||||
|
||||
schema_: Union[Dict, Type[BaseModel]]
|
||||
"""Schema for the structured prompt."""
|
||||
structured_output_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
messages: Sequence[MessageLikeRepresentation],
|
||||
schema_: Optional[Union[Dict, Type[BaseModel]]] = None,
|
||||
*,
|
||||
structured_output_kwargs: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
schema_ = schema_ or kwargs.pop("schema")
|
||||
structured_output_kwargs = structured_output_kwargs or {}
|
||||
for k in set(kwargs).difference(get_pydantic_field_names(self.__class__)):
|
||||
structured_output_kwargs[k] = kwargs.pop(k)
|
||||
super().__init__(
|
||||
messages=messages,
|
||||
schema_=schema_,
|
||||
structured_output_kwargs=structured_output_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
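The new `__init__` above routes any keyword that is not a declared pydantic field into `structured_output_kwargs`, so extra options (for example a hypothetical `method="json_mode"`) can later be forwarded to `with_structured_output`. A rough sketch of that splitting logic in isolation (field names are illustrative, not the real class):

from typing import Any, Dict

from pydantic import BaseModel, Field


class PromptLike(BaseModel):
    schema_: Dict[str, Any]
    structured_output_kwargs: Dict[str, Any] = Field(default_factory=dict)

    def __init__(self, schema_: Dict[str, Any], **kwargs: Any) -> None:
        extra: Dict[str, Any] = {}
        # Anything that is not a declared field is treated as an option to pass
        # to with_structured_output() later on.
        for k in set(kwargs).difference(type(self).model_fields):
            extra[k] = kwargs.pop(k)
        super().__init__(schema_=schema_, structured_output_kwargs=extra, **kwargs)


p = PromptLike(schema_={"title": "Output"}, method="json_mode")
assert p.structured_output_kwargs == {"method": "json_mode"}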
@@ -52,6 +69,7 @@ class StructuredPrompt(ChatPromptTemplate):
|
||||
cls,
|
||||
messages: Sequence[MessageLikeRepresentation],
|
||||
schema: Union[Dict, Type[BaseModel]],
|
||||
**kwargs: Any,
|
||||
) -> ChatPromptTemplate:
|
||||
"""Create a chat prompt template from a variety of message formats.
|
||||
|
||||
@@ -61,11 +79,13 @@ class StructuredPrompt(ChatPromptTemplate):
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_core.prompts import StructuredPrompt
|
||||
|
||||
class OutputSchema(BaseModel):
|
||||
name: str
|
||||
value: int
|
||||
|
||||
template = ChatPromptTemplate.from_messages(
|
||||
template = StructuredPrompt(
|
||||
[
|
||||
("human", "Hello, how are you?"),
|
||||
("ai", "I'm doing well, thanks!"),
|
||||
@@ -82,29 +102,13 @@ class StructuredPrompt(ChatPromptTemplate):
|
||||
(4) 2-tuple of (message class, template), (5) a string which is
|
||||
shorthand for ("human", template); e.g., "{user_input}"
|
||||
schema: a dictionary representation of function call, or a Pydantic model.
|
||||
kwargs: Any additional kwargs to pass through to
|
||||
``ChatModel.with_structured_output(schema, **kwargs)``.
|
||||
|
||||
Returns:
|
||||
a structured prompt template
|
||||
"""
|
||||
_messages = [_convert_to_message(message) for message in messages]
|
||||
|
||||
# Automatically infer input variables from messages
|
||||
input_vars: Set[str] = set()
|
||||
partial_vars: Dict[str, Any] = {}
|
||||
for _message in _messages:
|
||||
if isinstance(_message, MessagesPlaceholder) and _message.optional:
|
||||
partial_vars[_message.variable_name] = []
|
||||
elif isinstance(
|
||||
_message, (BaseChatPromptTemplate, BaseMessagePromptTemplate)
|
||||
):
|
||||
input_vars.update(_message.input_variables)
|
||||
|
||||
return cls(
|
||||
input_variables=sorted(input_vars),
|
||||
messages=_messages,
|
||||
partial_variables=partial_vars,
|
||||
schema_=schema,
|
||||
)
|
||||
return cls(messages, schema, **kwargs)
|
||||
|
||||
def __or__(
|
||||
self,
|
||||
@@ -115,27 +119,16 @@ class StructuredPrompt(ChatPromptTemplate):
|
||||
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
|
||||
],
|
||||
) -> RunnableSerializable[Dict, Other]:
|
||||
if isinstance(other, BaseLanguageModel) or hasattr(
|
||||
other, "with_structured_output"
|
||||
):
|
||||
try:
|
||||
return RunnableSequence(
|
||||
self, other.with_structured_output(self.schema_)
|
||||
)
|
||||
except NotImplementedError as e:
|
||||
raise NotImplementedError(
|
||||
"Structured prompts must be piped to a language model that "
|
||||
"implements with_structured_output."
|
||||
) from e
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Structured prompts must be piped to a language model that "
|
||||
"implements with_structured_output."
|
||||
)
|
||||
return self.pipe(other)
|
||||
|
||||
def pipe(
|
||||
self,
|
||||
*others: Union[Runnable[Any, Other], Callable[[Any], Other]],
|
||||
*others: Union[
|
||||
Runnable[Any, Other],
|
||||
Callable[[Any], Other],
|
||||
Callable[[Iterator[Any]], Iterator[Other]],
|
||||
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
|
||||
],
|
||||
name: Optional[str] = None,
|
||||
) -> RunnableSerializable[Dict, Other]:
|
||||
"""Pipe the structured prompt to a language model.
|
||||
@@ -158,7 +151,9 @@ class StructuredPrompt(ChatPromptTemplate):
|
||||
):
|
||||
return RunnableSequence(
|
||||
self,
|
||||
others[0].with_structured_output(self.schema_),
|
||||
others[0].with_structured_output(
|
||||
self.schema_, **self.structured_output_kwargs
|
||||
),
|
||||
*others[1:],
|
||||
name=name,
|
||||
)
|
||||
|
||||
@@ -181,7 +181,7 @@ class InMemoryRateLimiter(BaseRateLimiter):
             the caller should try again later.
         """
         with self._consume_lock:
-            now = time.time()
+            now = time.monotonic()

             # initialize on first call to avoid a burst
             if self.last is None:
|
||||
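Switching from `time.time()` to `time.monotonic()` matters for a rate limiter because the wall clock can jump backwards (NTP corrections, manual changes), while the monotonic clock only moves forward, so elapsed-time arithmetic stays non-negative. A toy token-bucket step using the same idea (not the actual implementation):

import time


class TinyBucket:
    """Illustrative token bucket; refills based on a monotonic clock."""

    def __init__(self, rate_per_sec: float, capacity: float) -> None:
        self.rate = rate_per_sec
        self.capacity = capacity
        self.tokens = capacity
        self.last = time.monotonic()  # never goes backwards

    def try_acquire(self) -> bool:
        now = time.monotonic()
        elapsed = now - self.last          # guaranteed >= 0
        self.last = now
        self.tokens = min(self.capacity, self.tokens + elapsed * self.rate)
        if self.tokens >= 1:
            self.tokens -= 1
            return True
        return False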
@@ -26,6 +26,9 @@ from abc import ABC, abstractmethod
|
||||
from inspect import signature
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
from pydantic import ConfigDict
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.load.dump import dumpd
|
||||
@@ -50,6 +53,19 @@ RetrieverLike = Runnable[RetrieverInput, RetrieverOutput]
|
||||
RetrieverOutputLike = Runnable[Any, RetrieverOutput]
|
||||
|
||||
|
||||
class LangSmithRetrieverParams(TypedDict, total=False):
|
||||
"""LangSmith parameters for tracing."""
|
||||
|
||||
ls_retriever_name: str
|
||||
"""Retriever name."""
|
||||
ls_vector_store_provider: Optional[str]
|
||||
"""Vector store provider."""
|
||||
ls_embedding_provider: Optional[str]
|
||||
"""Embedding provider."""
|
||||
ls_embedding_model: Optional[str]
|
||||
"""Embedding model."""
|
||||
|
||||
|
||||
class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
||||
"""Abstract base class for a Document retrieval system.
|
||||
|
||||
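`LangSmithRetrieverParams` is declared with `total=False`, so every key is optional and callers can build the dict incrementally; it exists purely to give tracing metadata a typed shape. A small sketch of how such a TypedDict is typically consumed (the merge into run metadata mirrors the later `invoke` change):

from typing import Optional

from typing_extensions import TypedDict


class TracingParams(TypedDict, total=False):
    # total=False: none of these keys are required.
    ls_retriever_name: str
    ls_vector_store_provider: Optional[str]


def build_params(name: str, provider: Optional[str] = None) -> TracingParams:
    params = TracingParams(ls_retriever_name=name)
    if provider is not None:
        params["ls_vector_store_provider"] = provider
    return params


metadata = {"user": "abc", **build_params("my_docs", "faiss")}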
@@ -111,8 +127,9 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
||||
return [self.docs[i] for i in results.argsort()[-self.k :][::-1]]
|
||||
""" # noqa: E501
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
_new_arg_supported: bool = False
|
||||
_expects_other_args: bool = False
|
||||
@@ -140,6 +157,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
||||
"Retrievers must implement abstract `_get_relevant_documents` method"
|
||||
" instead of `get_relevant_documents`",
|
||||
DeprecationWarning,
|
||||
stacklevel=4,
|
||||
)
|
||||
swap = cls.get_relevant_documents
|
||||
cls.get_relevant_documents = ( # type: ignore[assignment]
|
||||
@@ -154,6 +172,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
||||
"Retrievers must implement abstract `_aget_relevant_documents` method"
|
||||
" instead of `aget_relevant_documents`",
|
||||
DeprecationWarning,
|
||||
stacklevel=4,
|
||||
)
|
||||
aswap = cls.aget_relevant_documents
|
||||
cls.aget_relevant_documents = ( # type: ignore[assignment]
|
||||
@@ -167,6 +186,19 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
||||
len(set(parameters.keys()) - {"self", "query", "run_manager"}) > 0
|
||||
)
|
||||
|
||||
def _get_ls_params(self, **kwargs: Any) -> LangSmithRetrieverParams:
|
||||
"""Get standard params for tracing."""
|
||||
|
||||
default_retriever_name = self.get_name()
|
||||
if default_retriever_name.startswith("Retriever"):
|
||||
default_retriever_name = default_retriever_name[9:]
|
||||
elif default_retriever_name.endswith("Retriever"):
|
||||
default_retriever_name = default_retriever_name[:-9]
|
||||
default_retriever_name = default_retriever_name.lower()
|
||||
|
||||
ls_params = LangSmithRetrieverParams(ls_retriever_name=default_retriever_name)
|
||||
return ls_params
|
||||
|
||||
def invoke(
|
||||
self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
@@ -191,13 +223,17 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
||||
from langchain_core.callbacks.manager import CallbackManager
|
||||
|
||||
config = ensure_config(config)
|
||||
inheritable_metadata = {
|
||||
**(config.get("metadata") or {}),
|
||||
**self._get_ls_params(**kwargs),
|
||||
}
|
||||
callback_manager = CallbackManager.configure(
|
||||
config.get("callbacks"),
|
||||
None,
|
||||
verbose=kwargs.get("verbose", False),
|
||||
inheritable_tags=config.get("tags"),
|
||||
local_tags=self.tags,
|
||||
inheritable_metadata=config.get("metadata"),
|
||||
inheritable_metadata=inheritable_metadata,
|
||||
local_metadata=self.metadata,
|
||||
)
|
||||
run_manager = callback_manager.on_retriever_start(
|
||||
@@ -250,13 +286,17 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
||||
from langchain_core.callbacks.manager import AsyncCallbackManager
|
||||
|
||||
config = ensure_config(config)
|
||||
inheritable_metadata = {
|
||||
**(config.get("metadata") or {}),
|
||||
**self._get_ls_params(**kwargs),
|
||||
}
|
||||
callback_manager = AsyncCallbackManager.configure(
|
||||
config.get("callbacks"),
|
||||
None,
|
||||
verbose=kwargs.get("verbose", False),
|
||||
inheritable_tags=config.get("tags"),
|
||||
local_tags=self.tags,
|
||||
inheritable_metadata=config.get("metadata"),
|
||||
inheritable_metadata=inheritable_metadata,
|
||||
local_metadata=self.metadata,
|
||||
)
|
||||
run_manager = await callback_manager.on_retriever_start(
|
||||
@@ -313,7 +353,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
||||
run_manager=run_manager.get_sync(),
|
||||
)
|
||||
|
||||
@deprecated(since="0.1.46", alternative="invoke", removal="0.3.0")
|
||||
@deprecated(since="0.1.46", alternative="invoke", removal="1.0")
|
||||
def get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
@@ -357,7 +397,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
||||
config["run_name"] = run_name
|
||||
return self.invoke(query, config, **kwargs)
|
||||
|
||||
@deprecated(since="0.1.46", alternative="ainvoke", removal="0.3.0")
|
||||
@deprecated(since="0.1.46", alternative="ainvoke", removal="1.0")
|
||||
async def aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
|
||||
@@ -35,7 +35,8 @@ from typing import (
|
||||
overload,
|
||||
)
|
||||
|
||||
from typing_extensions import Literal, get_args
|
||||
from pydantic import BaseModel, ConfigDict, Field, RootModel
|
||||
from typing_extensions import Literal, get_args, get_type_hints
|
||||
|
||||
from langchain_core._api import beta_decorator
|
||||
from langchain_core.load.dump import dumpd
|
||||
@@ -44,7 +45,6 @@ from langchain_core.load.serializable import (
|
||||
SerializedConstructor,
|
||||
SerializedNotImplemented,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from langchain_core.runnables.config import (
|
||||
RunnableConfig,
|
||||
_set_config_context,
|
||||
@@ -83,7 +83,6 @@ from langchain_core.runnables.utils import (
|
||||
)
|
||||
from langchain_core.utils.aiter import aclosing, atee, py_anext
|
||||
from langchain_core.utils.iter import safetee
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.callbacks.manager import (
|
||||
@@ -236,25 +235,56 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
For a UI (and much more) checkout LangSmith: https://docs.smith.langchain.com/
|
||||
""" # noqa: E501
|
||||
|
||||
name: Optional[str] = None
|
||||
name: Optional[str]
|
||||
"""The name of the Runnable. Used for debugging and tracing."""
|
||||
|
||||
def get_name(
|
||||
self, suffix: Optional[str] = None, *, name: Optional[str] = None
|
||||
) -> str:
|
||||
"""Get the name of the Runnable."""
|
||||
name = name or self.name or self.__class__.__name__
|
||||
if suffix:
|
||||
if name[0].isupper():
|
||||
return name + suffix.title()
|
||||
else:
|
||||
return name + "_" + suffix.lower()
|
||||
if name:
|
||||
name_ = name
|
||||
elif hasattr(self, "name") and self.name:
|
||||
name_ = self.name
|
||||
else:
|
||||
return name
|
||||
# Here we handle a case where the runnable subclass is also a pydantic
|
||||
# model.
|
||||
cls = self.__class__
|
||||
# Then it's a pydantic sub-class, and we have to check
|
||||
# whether it's a generic, and if so recover the original name.
|
||||
if (
|
||||
hasattr(
|
||||
cls,
|
||||
"__pydantic_generic_metadata__",
|
||||
)
|
||||
and "origin" in cls.__pydantic_generic_metadata__
|
||||
and cls.__pydantic_generic_metadata__["origin"] is not None
|
||||
):
|
||||
name_ = cls.__pydantic_generic_metadata__["origin"].__name__
|
||||
else:
|
||||
name_ = cls.__name__
|
||||
|
||||
if suffix:
|
||||
if name_[0].isupper():
|
||||
return name_ + suffix.title()
|
||||
else:
|
||||
return name_ + "_" + suffix.lower()
|
||||
else:
|
||||
return name_
|
||||
|
||||
@property
|
||||
def InputType(self) -> Type[Input]:
|
||||
"""The type of input this Runnable accepts specified as a type annotation."""
|
||||
# First loop through bases -- this will handle
# generic pydantic models.
|
||||
for base in self.__class__.mro():
|
||||
if hasattr(base, "__pydantic_generic_metadata__"):
|
||||
metadata = base.__pydantic_generic_metadata__
|
||||
if "args" in metadata and len(metadata["args"]) == 2:
|
||||
return metadata["args"][0]
|
||||
|
||||
# then loop through __orig_bases__ -- this will handle Runnables that do not
# inherit from pydantic
|
||||
for cls in self.__class__.__orig_bases__: # type: ignore[attr-defined]
|
||||
type_args = get_args(cls)
|
||||
if type_args and len(type_args) == 2:
|
||||
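Both the rewritten `get_name` and the `InputType`/`OutputType` properties lean on `__pydantic_generic_metadata__`, which pydantic v2 attaches to parametrized generic models: `origin` points back at the unparametrized class and `args` carries the type arguments, so a `Wrapper[int, str]` can still report the plain name `Wrapper` and recover its parameters. A hedged illustration:

from typing import Generic, TypeVar

from pydantic import BaseModel

In = TypeVar("In")
Out = TypeVar("Out")


class Wrapper(BaseModel, Generic[In, Out]):
    pass


Parametrized = Wrapper[int, str]
meta = Parametrized.__pydantic_generic_metadata__
print(meta["origin"].__name__)  # "Wrapper" -- the name get_name() wants to report
print(meta["args"])             # (int, str) -- what InputType/OutputType inspect
# On the unparametrized Wrapper itself, origin is None -- hence the explicit
# "origin is not None" check in the diff.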
@@ -268,6 +298,14 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
@property
|
||||
def OutputType(self) -> Type[Output]:
|
||||
"""The type of output this Runnable produces specified as a type annotation."""
|
||||
# First loop through bases -- this will handle
# generic pydantic models.
|
||||
for base in self.__class__.mro():
|
||||
if hasattr(base, "__pydantic_generic_metadata__"):
|
||||
metadata = base.__pydantic_generic_metadata__
|
||||
if "args" in metadata and len(metadata["args"]) == 2:
|
||||
return metadata["args"][1]
|
||||
|
||||
for cls in self.__class__.__orig_bases__: # type: ignore[attr-defined]
|
||||
type_args = get_args(cls)
|
||||
if type_args and len(type_args) == 2:
|
||||
@@ -302,12 +340,12 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
"""
|
||||
root_type = self.InputType
|
||||
|
||||
if inspect.isclass(root_type) and is_basemodel_subclass(root_type):
|
||||
if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
|
||||
return root_type
|
||||
|
||||
return create_model(
|
||||
self.get_name("Input"),
|
||||
__root__=(root_type, None),
|
||||
__root__=root_type,
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -334,12 +372,12 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
"""
|
||||
root_type = self.OutputType
|
||||
|
||||
if inspect.isclass(root_type) and is_basemodel_subclass(root_type):
|
||||
if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
|
||||
return root_type
|
||||
|
||||
return create_model(
|
||||
self.get_name("Output"),
|
||||
__root__=(root_type, None),
|
||||
__root__=root_type,
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -381,15 +419,19 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
else None
|
||||
)
|
||||
|
||||
return create_model( # type: ignore[call-overload]
|
||||
self.get_name("Config"),
|
||||
# May need to create a typed dict instead to implement NotRequired!
|
||||
all_fields = {
|
||||
**({"configurable": (configurable, None)} if configurable else {}),
|
||||
**{
|
||||
field_name: (field_type, None)
|
||||
for field_name, field_type in RunnableConfig.__annotations__.items()
|
||||
for field_name, field_type in get_type_hints(RunnableConfig).items()
|
||||
if field_name in [i for i in include if i != "configurable"]
|
||||
},
|
||||
}
|
||||
model = create_model( # type: ignore[call-overload]
|
||||
self.get_name("Config"), **all_fields
|
||||
)
|
||||
return model
|
||||
|
||||
def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph:
|
||||
"""Return a graph representation of this Runnable."""
|
||||
@@ -579,7 +621,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
"""
|
||||
from langchain_core.runnables.passthrough import RunnableAssign
|
||||
|
||||
return self | RunnableAssign(RunnableParallel(kwargs))
|
||||
return self | RunnableAssign(RunnableParallel[Dict[str, Any]](kwargs))
|
||||
|
||||
""" --- Public API --- """
|
||||
|
||||
@@ -2129,7 +2171,6 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
name=config.get("run_name") or self.get_name(),
|
||||
run_id=config.pop("run_id", None),
|
||||
)
|
||||
iterator_ = None
|
||||
try:
|
||||
child_config = patch_config(config, callbacks=run_manager.get_child())
|
||||
if accepts_config(transformer):
|
||||
@@ -2314,7 +2355,6 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]):
|
||||
"""Runnable that can be serialized to JSON."""
|
||||
|
||||
name: Optional[str] = None
|
||||
"""The name of the Runnable. Used for debugging and tracing."""
|
||||
|
||||
def to_json(self) -> Union[SerializedConstructor, SerializedNotImplemented]:
|
||||
"""Serialize the Runnable to JSON.
|
||||
@@ -2369,10 +2409,10 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]):
|
||||
from langchain_core.runnables.configurable import RunnableConfigurableFields
|
||||
|
||||
for key in kwargs:
|
||||
if key not in self.__fields__:
|
||||
if key not in self.model_fields:
|
||||
raise ValueError(
|
||||
f"Configuration key {key} not found in {self}: "
|
||||
f"available keys are {self.__fields__.keys()}"
|
||||
f"available keys are {self.model_fields.keys()}"
|
||||
)
|
||||
|
||||
return RunnableConfigurableFields(default=self, fields=kwargs)
|
||||
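`__fields__` is deprecated in pydantic v2 in favor of the class-level `model_fields` mapping, and per-field metadata such as the description now lives directly on `FieldInfo` rather than under a nested `field_info` attribute. A quick sketch of the lookups this hunk and the `configurable_fields` changes rely on:

from pydantic import BaseModel, Field


class Settings(BaseModel):
    temperature: float = Field(default=0.7, description="Sampling temperature")


# pydantic v2: class-level mapping of field name -> FieldInfo
info = Settings.model_fields["temperature"]
print(info.description)   # "Sampling temperature"
print(info.annotation)    # <class 'float'>
print(info.default)       # 0.7

unknown = [k for k in ("temperature", "top_p") if k not in Settings.model_fields]
print(unknown)            # ["top_p"] -- the membership test used to reject bad keys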
@@ -2447,13 +2487,13 @@ def _seq_input_schema(
|
||||
return first.get_input_schema(config)
|
||||
elif isinstance(first, RunnableAssign):
|
||||
next_input_schema = _seq_input_schema(steps[1:], config)
|
||||
if not next_input_schema.__custom_root_type__:
|
||||
if not issubclass(next_input_schema, RootModel):
|
||||
# it's a dict as expected
|
||||
return create_model( # type: ignore[call-overload]
|
||||
"RunnableSequenceInput",
|
||||
**{
|
||||
k: (v.annotation, v.default)
|
||||
for k, v in next_input_schema.__fields__.items()
|
||||
for k, v in next_input_schema.model_fields.items()
|
||||
if k not in first.mapper.steps__
|
||||
},
|
||||
)
|
||||
@@ -2474,36 +2514,36 @@ def _seq_output_schema(
|
||||
elif isinstance(last, RunnableAssign):
|
||||
mapper_output_schema = last.mapper.get_output_schema(config)
|
||||
prev_output_schema = _seq_output_schema(steps[:-1], config)
|
||||
if not prev_output_schema.__custom_root_type__:
|
||||
if not issubclass(prev_output_schema, RootModel):
|
||||
# it's a dict as expected
|
||||
return create_model( # type: ignore[call-overload]
|
||||
"RunnableSequenceOutput",
|
||||
**{
|
||||
**{
|
||||
k: (v.annotation, v.default)
|
||||
for k, v in prev_output_schema.__fields__.items()
|
||||
for k, v in prev_output_schema.model_fields.items()
|
||||
},
|
||||
**{
|
||||
k: (v.annotation, v.default)
|
||||
for k, v in mapper_output_schema.__fields__.items()
|
||||
for k, v in mapper_output_schema.model_fields.items()
|
||||
},
|
||||
},
|
||||
)
|
||||
elif isinstance(last, RunnablePick):
|
||||
prev_output_schema = _seq_output_schema(steps[:-1], config)
|
||||
if not prev_output_schema.__custom_root_type__:
|
||||
if not issubclass(prev_output_schema, RootModel):
|
||||
# it's a dict as expected
|
||||
if isinstance(last.keys, list):
|
||||
return create_model( # type: ignore[call-overload]
|
||||
"RunnableSequenceOutput",
|
||||
**{
|
||||
k: (v.annotation, v.default)
|
||||
for k, v in prev_output_schema.__fields__.items()
|
||||
for k, v in prev_output_schema.model_fields.items()
|
||||
if k in last.keys
|
||||
},
|
||||
)
|
||||
else:
|
||||
field = prev_output_schema.__fields__[last.keys]
|
||||
field = prev_output_schema.model_fields[last.keys]
|
||||
return create_model( # type: ignore[call-overload]
|
||||
"RunnableSequenceOutput",
|
||||
__root__=(field.annotation, field.default),
|
||||
@@ -2665,8 +2705,9 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
"""
|
||||
return True
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
@property
|
||||
def InputType(self) -> Type[Input]:
|
||||
@@ -3208,8 +3249,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
else:
|
||||
final_pipeline = step.transform(final_pipeline, config)
|
||||
|
||||
for output in final_pipeline:
|
||||
yield output
|
||||
yield from final_pipeline
|
||||
|
||||
async def _atransform(
|
||||
self,
|
||||
@@ -3403,8 +3443,9 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "runnable"]
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
def get_name(
|
||||
self, suffix: Optional[str] = None, *, name: Optional[str] = None
|
||||
@@ -3451,7 +3492,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
|
||||
**{
|
||||
k: (v.annotation, v.default)
|
||||
for step in self.steps__.values()
|
||||
for k, v in step.get_input_schema(config).__fields__.items()
|
||||
for k, v in step.get_input_schema(config).model_fields.items()
|
||||
if k != "__root__"
|
||||
},
|
||||
)
|
||||
@@ -3469,11 +3510,8 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
|
||||
Returns:
|
||||
The output schema of the Runnable.
|
||||
"""
|
||||
# This is correct, but pydantic typings/mypy don't think so.
|
||||
return create_model( # type: ignore[call-overload]
|
||||
self.get_name("Output"),
|
||||
**{k: (v.OutputType, None) for k, v in self.steps__.items()},
|
||||
)
|
||||
fields = {k: (v.OutputType, ...) for k, v in self.steps__.items()}
|
||||
return create_model(self.get_name("Output"), **fields)
|
||||
|
||||
@property
|
||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||
@@ -3883,6 +3921,8 @@ class RunnableGenerator(Runnable[Input, Output]):
|
||||
atransform: Optional[
|
||||
Callable[[AsyncIterator[Input]], AsyncIterator[Output]]
|
||||
] = None,
|
||||
*,
|
||||
name: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Initialize a RunnableGenerator.
|
||||
|
||||
@@ -3910,13 +3950,13 @@ class RunnableGenerator(Runnable[Input, Output]):
|
||||
)
|
||||
|
||||
try:
|
||||
self.name = func_for_name.__name__
|
||||
self.name = name or func_for_name.__name__
|
||||
except AttributeError:
|
||||
pass
|
||||
self.name = "RunnableGenerator"
|
||||
|
||||
@property
|
||||
def InputType(self) -> Any:
|
||||
func = getattr(self, "_transform", None) or getattr(self, "_atransform")
|
||||
func = getattr(self, "_transform", None) or self._atransform
|
||||
try:
|
||||
params = inspect.signature(func).parameters
|
||||
first_param = next(iter(params.values()), None)
|
||||
@@ -3929,7 +3969,7 @@ class RunnableGenerator(Runnable[Input, Output]):
|
||||
|
||||
@property
|
||||
def OutputType(self) -> Any:
|
||||
func = getattr(self, "_transform", None) or getattr(self, "_atransform")
|
||||
func = getattr(self, "_transform", None) or self._atransform
|
||||
try:
|
||||
sig = inspect.signature(func)
|
||||
return (
|
||||
@@ -4153,7 +4193,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
@property
|
||||
def InputType(self) -> Any:
|
||||
"""The type of the input to this Runnable."""
|
||||
func = getattr(self, "func", None) or getattr(self, "afunc")
|
||||
func = getattr(self, "func", None) or self.afunc
|
||||
try:
|
||||
params = inspect.signature(func).parameters
|
||||
first_param = next(iter(params.values()), None)
|
||||
@@ -4175,7 +4215,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
Returns:
|
||||
The input schema for this Runnable.
|
||||
"""
|
||||
func = getattr(self, "func", None) or getattr(self, "afunc")
|
||||
func = getattr(self, "func", None) or self.afunc
|
||||
|
||||
if isinstance(func, itemgetter):
|
||||
# This is terrible, but afaict it's not possible to access _items
|
||||
@@ -4184,15 +4224,13 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
if all(
|
||||
item[0] == "'" and item[-1] == "'" and len(item) > 2 for item in items
|
||||
):
|
||||
fields = {item[1:-1]: (Any, ...) for item in items}
|
||||
# It's a dict, lol
|
||||
return create_model(
|
||||
self.get_name("Input"),
|
||||
**{item[1:-1]: (Any, None) for item in items}, # type: ignore
|
||||
)
|
||||
return create_model(self.get_name("Input"), **fields)
|
||||
else:
|
||||
return create_model(
|
||||
self.get_name("Input"),
|
||||
__root__=(List[Any], None),
|
||||
__root__=List[Any],
|
||||
)
|
||||
|
||||
if self.InputType != Any:
|
||||
@@ -4201,7 +4239,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
if dict_keys := get_function_first_arg_dict_keys(func):
|
||||
return create_model(
|
||||
self.get_name("Input"),
|
||||
**{key: (Any, None) for key in dict_keys}, # type: ignore
|
||||
**{key: (Any, ...) for key in dict_keys}, # type: ignore
|
||||
)
|
||||
|
||||
return super().get_input_schema(config)
|
||||
@@ -4213,7 +4251,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
Returns:
|
||||
The type of the output of this Runnable.
|
||||
"""
|
||||
func = getattr(self, "func", None) or getattr(self, "afunc")
|
||||
func = getattr(self, "func", None) or self.afunc
|
||||
try:
|
||||
sig = inspect.signature(func)
|
||||
if sig.return_annotation != inspect.Signature.empty:
|
||||
@@ -4577,13 +4615,12 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
**kwargs: Optional[Any],
|
||||
) -> Iterator[Output]:
|
||||
if hasattr(self, "func"):
|
||||
for output in self._transform_stream_with_config(
|
||||
yield from self._transform_stream_with_config(
|
||||
input,
|
||||
self._transform,
|
||||
self._config(config, self.func),
|
||||
**kwargs,
|
||||
):
|
||||
yield output
|
||||
)
|
||||
else:
|
||||
raise TypeError(
|
||||
"Cannot stream a coroutine function synchronously."
|
||||
@@ -4730,8 +4767,9 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
|
||||
|
||||
bound: Runnable[Input, Output]
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
@property
|
||||
def InputType(self) -> Any:
|
||||
@@ -4758,10 +4796,7 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
|
||||
schema = self.bound.get_output_schema(config)
|
||||
return create_model(
|
||||
self.get_name("Output"),
|
||||
__root__=(
|
||||
List[schema], # type: ignore
|
||||
None,
|
||||
),
|
||||
__root__=List[schema], # type: ignore[valid-type]
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -4981,8 +5016,9 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
|
||||
The type can be a pydantic model, or a type annotation (e.g., `List[str]`).
|
||||
"""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -5318,7 +5354,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
|
||||
yield item
|
||||
|
||||
|
||||
RunnableBindingBase.update_forward_refs(RunnableConfig=RunnableConfig)
|
||||
RunnableBindingBase.model_rebuild()
|
||||
|
||||
|
||||
class RunnableBinding(RunnableBindingBase[Input, Output]):
|
||||
@@ -5334,12 +5370,13 @@ class RunnableBinding(RunnableBindingBase[Input, Output]):
|
||||
`RunnableWithFallbacks`) that add additional functionality.
|
||||
|
||||
These methods include:
|
||||
- `bind`: Bind kwargs to pass to the underlying Runnable when running it.
|
||||
- `with_config`: Bind config to pass to the underlying Runnable when running it.
|
||||
- `with_listeners`: Bind lifecycle listeners to the underlying Runnable.
|
||||
- `with_types`: Override the input and output types of the underlying Runnable.
|
||||
- `with_retry`: Bind a retry policy to the underlying Runnable.
|
||||
- `with_fallbacks`: Bind a fallback policy to the underlying Runnable.
|
||||
|
||||
- ``bind``: Bind kwargs to pass to the underlying Runnable when running it.
|
||||
- ``with_config``: Bind config to pass to the underlying Runnable when running it.
|
||||
- ``with_listeners``: Bind lifecycle listeners to the underlying Runnable.
|
||||
- ``with_types``: Override the input and output types of the underlying Runnable.
|
||||
- ``with_retry``: Bind a retry policy to the underlying Runnable.
|
||||
- ``with_fallbacks``: Bind a fallback policy to the underlying Runnable.
|
||||
|
||||
Example:
|
||||
|
||||
|
||||
@@ -14,8 +14,9 @@ from typing import (
|
||||
cast,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from langchain_core.load.dump import dumpd
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.runnables.base import (
|
||||
Runnable,
|
||||
RunnableLike,
|
||||
@@ -134,10 +135,21 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
|
||||
runnable = coerce_to_runnable(runnable)
|
||||
_branches.append((condition, runnable))
|
||||
|
||||
super().__init__(branches=_branches, default=default_) # type: ignore[call-arg]
|
||||
super().__init__(
|
||||
branches=_branches,
|
||||
default=default_,
|
||||
# Hard-coding a name here because RunnableBranch is a generic
|
||||
# and with pydantic 2, the class name would otherwise
# include the parameterized type, which is not what we want.
|
||||
# e.g., we'd get RunnableBranch[Input, Output] instead of RunnableBranch
|
||||
# for the name. This information is already captured in the
|
||||
# input and output types.
|
||||
name="RunnableBranch",
|
||||
) # type: ignore[call-arg]
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
|
||||
@@ -236,6 +236,7 @@ def get_config_list(
|
||||
warnings.warn(
|
||||
"Provided run_id be used only for the first element of the batch.",
|
||||
category=RuntimeWarning,
|
||||
stacklevel=3,
|
||||
)
|
||||
subsequent = cast(
|
||||
RunnableConfig, {k: v for k, v in config.items() if k != "run_id"}
|
||||
@@ -348,37 +349,7 @@ def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig:
|
||||
base["callbacks"] = mngr
|
||||
else:
|
||||
# base_callbacks is also a manager
|
||||
manager = base_callbacks.__class__(
|
||||
parent_run_id=base_callbacks.parent_run_id
|
||||
or these_callbacks.parent_run_id,
|
||||
handlers=[],
|
||||
inheritable_handlers=[],
|
||||
tags=list(set(base_callbacks.tags + these_callbacks.tags)),
|
||||
inheritable_tags=list(
|
||||
set(
|
||||
base_callbacks.inheritable_tags
|
||||
+ these_callbacks.inheritable_tags
|
||||
)
|
||||
),
|
||||
metadata={
|
||||
**base_callbacks.metadata,
|
||||
**these_callbacks.metadata,
|
||||
},
|
||||
)
|
||||
|
||||
handlers = base_callbacks.handlers + these_callbacks.handlers
|
||||
inheritable_handlers = (
|
||||
base_callbacks.inheritable_handlers
|
||||
+ these_callbacks.inheritable_handlers
|
||||
)
|
||||
|
||||
for handler in handlers:
|
||||
manager.add_handler(handler)
|
||||
|
||||
for handler in inheritable_handlers:
|
||||
manager.add_handler(handler, inherit=True)
|
||||
|
||||
base["callbacks"] = manager
|
||||
base["callbacks"] = base_callbacks.merge(these_callbacks)
|
||||
elif key == "recursion_limit":
|
||||
if config["recursion_limit"] != DEFAULT_RECURSION_LIMIT:
|
||||
base["recursion_limit"] = config["recursion_limit"]
|
||||
|
||||
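The long hand-rolled block that rebuilt a callback manager field by field collapses into `base_callbacks.merge(these_callbacks)`, delegating to the manager's own `merge` method to combine handlers, tags, and metadata. A hedged sketch of what such a merge amounts to on a plain container type (illustrative class, not the real CallbackManager):

from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class MiniManager:
    handlers: List[str] = field(default_factory=list)
    tags: List[str] = field(default_factory=list)
    metadata: Dict[str, str] = field(default_factory=dict)

    def merge(self, other: "MiniManager") -> "MiniManager":
        # Union-style combination, mirroring what merge_configs now delegates.
        return MiniManager(
            handlers=self.handlers + other.handlers,
            tags=sorted(set(self.tags) | set(other.tags)),
            metadata={**self.metadata, **other.metadata},
        )


a = MiniManager(handlers=["stdout"], tags=["base"], metadata={"run": "1"})
b = MiniManager(handlers=["tracer"], tags=["extra"], metadata={"user": "x"})
print(a.merge(b))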
@@ -20,7 +20,8 @@ from typing import (
|
||||
)
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from langchain_core.runnables.base import Runnable, RunnableSerializable
|
||||
from langchain_core.runnables.config import (
|
||||
RunnableConfig,
|
||||
@@ -58,8 +59,9 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
|
||||
|
||||
config: Optional[RunnableConfig] = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
@@ -373,28 +375,33 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
|
||||
Returns:
|
||||
List[ConfigurableFieldSpec]: The configuration specs.
|
||||
"""
|
||||
return get_unique_config_specs(
|
||||
[
|
||||
(
|
||||
# TODO(0.3): This change removes field_info which isn't needed in pydantic 2
|
||||
config_specs = []
|
||||
|
||||
for field_name, spec in self.fields.items():
|
||||
if isinstance(spec, ConfigurableField):
|
||||
config_specs.append(
|
||||
ConfigurableFieldSpec(
|
||||
id=spec.id,
|
||||
name=spec.name,
|
||||
description=spec.description
|
||||
or self.default.__fields__[field_name].field_info.description,
|
||||
or self.default.model_fields[field_name].description,
|
||||
annotation=spec.annotation
|
||||
or self.default.__fields__[field_name].annotation,
|
||||
or self.default.model_fields[field_name].annotation,
|
||||
default=getattr(self.default, field_name),
|
||||
is_shared=spec.is_shared,
|
||||
)
|
||||
if isinstance(spec, ConfigurableField)
|
||||
else make_options_spec(
|
||||
spec, self.default.__fields__[field_name].field_info.description
|
||||
)
|
||||
else:
|
||||
config_specs.append(
|
||||
make_options_spec(
|
||||
spec, self.default.model_fields[field_name].description
|
||||
)
|
||||
)
|
||||
for field_name, spec in self.fields.items()
|
||||
]
|
||||
+ list(self.default.config_specs)
|
||||
)
|
||||
|
||||
config_specs.extend(self.default.config_specs)
|
||||
|
||||
return get_unique_config_specs(config_specs)
|
||||
|
||||
def configurable_fields(
|
||||
self, **kwargs: AnyConfigurableField
|
||||
@@ -436,7 +443,7 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
|
||||
init_params = {
|
||||
k: v
|
||||
for k, v in self.default.__dict__.items()
|
||||
if k in self.default.__fields__
|
||||
if k in self.default.model_fields
|
||||
}
|
||||
return (
|
||||
self.default.__class__(**{**init_params, **configurable}),
|
||||
|
||||
@@ -18,8 +18,9 @@ from typing import (
|
||||
cast,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from langchain_core.load.dump import dumpd
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.runnables.base import Runnable, RunnableSerializable
|
||||
from langchain_core.runnables.config import (
|
||||
RunnableConfig,
|
||||
@@ -107,8 +108,9 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
will not be passed to fallbacks. If used, the base Runnable and its fallbacks
|
||||
must accept a dictionary as input."""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
@property
|
||||
def InputType(self) -> Type[Input]:
|
||||
@@ -180,6 +182,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
output = context.run(
|
||||
runnable.invoke,
|
||||
input,
|
||||
config,
|
||||
**kwargs,
|
||||
)
|
||||
except self.exceptions_to_handle as e:
|
||||
|
||||
@@ -22,8 +22,9 @@ from typing import (
|
||||
)
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain_core.utils.pydantic import _IgnoreUnserializable, is_basemodel_subclass
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.runnables.base import Runnable as RunnableType
|
||||
@@ -235,7 +236,9 @@ def node_data_json(
|
||||
json = (
|
||||
{
|
||||
"type": "schema",
|
||||
"data": node.data.schema(),
|
||||
"data": node.data.model_json_schema(
|
||||
schema_generator=_IgnoreUnserializable
|
||||
),
|
||||
}
|
||||
if with_schemas
|
||||
else {
|
||||
@@ -537,7 +540,7 @@ class Graph:
|
||||
*,
|
||||
with_styles: bool = True,
|
||||
curve_style: CurveStyle = CurveStyle.LINEAR,
|
||||
node_colors: NodeStyles = NodeStyles(),
|
||||
node_colors: Optional[NodeStyles] = None,
|
||||
wrap_label_n_words: int = 9,
|
||||
) -> str:
|
||||
"""Draw the graph as a Mermaid syntax string.
|
||||
@@ -573,7 +576,7 @@ class Graph:
|
||||
self,
|
||||
*,
|
||||
curve_style: CurveStyle = CurveStyle.LINEAR,
|
||||
node_colors: NodeStyles = NodeStyles(),
|
||||
node_colors: Optional[NodeStyles] = None,
|
||||
wrap_label_n_words: int = 9,
|
||||
output_file_path: Optional[str] = None,
|
||||
draw_method: MermaidDrawMethod = MermaidDrawMethod.API,
|
||||
|
||||
@@ -20,7 +20,7 @@ def draw_mermaid(
|
||||
last_node: Optional[str] = None,
|
||||
with_styles: bool = True,
|
||||
curve_style: CurveStyle = CurveStyle.LINEAR,
|
||||
node_styles: NodeStyles = NodeStyles(),
|
||||
node_styles: Optional[NodeStyles] = None,
|
||||
wrap_label_n_words: int = 9,
|
||||
) -> str:
|
||||
"""Draws a Mermaid graph using the provided graph data.
|
||||
@@ -78,51 +78,82 @@ def draw_mermaid(
|
||||
)
|
||||
mermaid_graph += f"\t{node_label}\n"
|
||||
|
||||
subgraph = ""
|
||||
# Add edges to the graph
|
||||
# Group edges by their common prefixes
|
||||
edge_groups: Dict[str, List[Edge]] = {}
|
||||
for edge in edges:
|
||||
src_prefix = edge.source.split(":")[0] if ":" in edge.source else None
|
||||
tgt_prefix = edge.target.split(":")[0] if ":" in edge.target else None
|
||||
# exit subgraph if source or target is not in the same subgraph
|
||||
if subgraph and (subgraph != src_prefix or subgraph != tgt_prefix):
|
||||
mermaid_graph += "\tend\n"
|
||||
subgraph = ""
|
||||
# enter subgraph if source and target are in the same subgraph
|
||||
if not subgraph and src_prefix and src_prefix == tgt_prefix:
|
||||
mermaid_graph += f"\tsubgraph {src_prefix}\n"
|
||||
subgraph = src_prefix
|
||||
|
||||
source, target = edge.source, edge.target
|
||||
|
||||
# Add BR every wrap_label_n_words words
|
||||
if edge.data is not None:
|
||||
edge_data = edge.data
|
||||
words = str(edge_data).split() # Split the string into words
|
||||
# Group words into chunks of wrap_label_n_words size
|
||||
if len(words) > wrap_label_n_words:
|
||||
edge_data = " <br> ".join(
|
||||
" ".join(words[i : i + wrap_label_n_words])
|
||||
for i in range(0, len(words), wrap_label_n_words)
|
||||
)
|
||||
if edge.conditional:
|
||||
edge_label = f" -.  {edge_data}  .-> "
|
||||
else:
|
||||
edge_label = f" --  {edge_data}  --> "
|
||||
else:
|
||||
if edge.conditional:
|
||||
edge_label = " -.-> "
|
||||
else:
|
||||
edge_label = " --> "
|
||||
mermaid_graph += (
|
||||
f"\t{_escape_node_label(source)}{edge_label}"
|
||||
f"{_escape_node_label(target)};\n"
|
||||
src_parts = edge.source.split(":")
|
||||
tgt_parts = edge.target.split(":")
|
||||
common_prefix = ":".join(
|
||||
src for src, tgt in zip(src_parts, tgt_parts) if src == tgt
|
||||
)
|
||||
if subgraph:
|
||||
mermaid_graph += "end\n"
|
||||
edge_groups.setdefault(common_prefix, []).append(edge)
|
||||
|
||||
seen_subgraphs = set()
|
||||
|
||||
def add_subgraph(edges: List[Edge], prefix: str) -> None:
|
||||
nonlocal mermaid_graph
|
||||
self_loop = len(edges) == 1 and edges[0].source == edges[0].target
|
||||
if prefix and not self_loop:
|
||||
subgraph = prefix.split(":")[-1]
|
||||
if subgraph in seen_subgraphs:
|
||||
raise ValueError(
|
||||
f"Found duplicate subgraph '{subgraph}' -- this likely means that "
|
||||
"you're reusing a subgraph node with the same name. "
|
||||
"Please adjust your graph to have subgraph nodes with unique names."
|
||||
)
|
||||
|
||||
seen_subgraphs.add(subgraph)
|
||||
mermaid_graph += f"\tsubgraph {subgraph}\n"
|
||||
|
||||
for edge in edges:
|
||||
source, target = edge.source, edge.target
|
||||
|
||||
# Add BR every wrap_label_n_words words
|
||||
if edge.data is not None:
|
||||
edge_data = edge.data
|
||||
words = str(edge_data).split() # Split the string into words
|
||||
# Group words into chunks of wrap_label_n_words size
|
||||
if len(words) > wrap_label_n_words:
|
||||
edge_data = " <br> ".join(
|
||||
" ".join(words[i : i + wrap_label_n_words])
|
||||
for i in range(0, len(words), wrap_label_n_words)
|
||||
)
|
||||
if edge.conditional:
|
||||
edge_label = f" -.  {edge_data}  .-> "
|
||||
else:
|
||||
edge_label = f" --  {edge_data}  --> "
|
||||
else:
|
||||
if edge.conditional:
|
||||
edge_label = " -.-> "
|
||||
else:
|
||||
edge_label = " --> "
|
||||
|
||||
mermaid_graph += (
|
||||
f"\t{_escape_node_label(source)}{edge_label}"
|
||||
f"{_escape_node_label(target)};\n"
|
||||
)
|
||||
|
||||
# Recursively add nested subgraphs
|
||||
for nested_prefix in edge_groups.keys():
|
||||
if not nested_prefix.startswith(prefix + ":") or nested_prefix == prefix:
|
||||
continue
|
||||
add_subgraph(edge_groups[nested_prefix], nested_prefix)
|
||||
|
||||
if prefix and not self_loop:
|
||||
mermaid_graph += "\tend\n"
|
||||
|
||||
# Start with the top-level edges (no common prefix)
|
||||
add_subgraph(edge_groups.get("", []), "")
|
||||
|
||||
# Add remaining subgraphs
|
||||
for prefix in edge_groups.keys():
|
||||
if ":" in prefix or prefix == "":
|
||||
continue
|
||||
add_subgraph(edge_groups[prefix], prefix)
|
||||
|
||||
# Add custom styles for nodes
|
||||
if with_styles:
|
||||
mermaid_graph += _generate_mermaid_graph_styles(node_styles)
|
||||
mermaid_graph += _generate_mermaid_graph_styles(node_styles or NodeStyles())
|
||||
return mermaid_graph
|
||||
|
||||
|
||||
|
||||
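The rewritten Mermaid renderer no longer opens and closes a subgraph while scanning edges linearly; it first buckets every edge by the namespace prefix shared between its source and target (built from the ":"-separated parts) and then emits subgraphs recursively, which is what makes nested and duplicate subgraphs detectable. A stripped-down sketch of just the grouping step (hypothetical edge tuples, not the real Edge class):

from typing import Dict, List, Tuple

Edge = Tuple[str, str]  # (source, target), e.g. ("agent:plan", "agent:act")


def group_by_common_prefix(edges: List[Edge]) -> Dict[str, List[Edge]]:
    groups: Dict[str, List[Edge]] = {}
    for source, target in edges:
        src_parts, tgt_parts = source.split(":"), target.split(":")
        # Join the path components where source and target agree;
        # "" means the edge belongs at the top level.
        common = ":".join(s for s, t in zip(src_parts, tgt_parts) if s == t)
        groups.setdefault(common, []).append((source, target))
    return groups


print(group_by_common_prefix([
    ("agent:plan", "agent:act"),    # grouped under the "agent" subgraph
    ("agent:act", "tools:search"),  # no shared prefix -> top level ""
]))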
@@ -13,10 +13,10 @@ from typing import (
|
||||
Union,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain_core.chat_history import BaseChatMessageHistory
|
||||
from langchain_core.load.load import load
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.runnables import RunnableBranch
|
||||
from langchain_core.runnables.base import Runnable, RunnableBindingBase, RunnableLambda
|
||||
from langchain_core.runnables.passthrough import RunnablePassthrough
|
||||
from langchain_core.runnables.utils import (
|
||||
@@ -320,17 +320,22 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
||||
history_chain = RunnablePassthrough.assign(
|
||||
**{messages_key: history_chain}
|
||||
).with_config(run_name="insert_history")
|
||||
|
||||
runnable_sync: Runnable = runnable.with_listeners(on_end=self._exit_history)
|
||||
runnable_async: Runnable = runnable.with_alisteners(on_end=self._aexit_history)
|
||||
|
||||
def _call_runnable_sync(_input: Any) -> Runnable:
|
||||
return runnable_sync
|
||||
|
||||
async def _call_runnable_async(_input: Any) -> Runnable:
|
||||
return runnable_async
|
||||
|
||||
bound: Runnable = (
|
||||
history_chain
|
||||
| RunnableBranch(
|
||||
(
|
||||
RunnableLambda(
|
||||
self._is_not_async, afunc=self._is_async
|
||||
).with_config(run_name="RunnableWithMessageHistoryInAsyncMode"),
|
||||
runnable.with_alisteners(on_end=self._aexit_history),
|
||||
),
|
||||
runnable.with_listeners(on_end=self._exit_history),
|
||||
)
|
||||
| RunnableLambda(
|
||||
_call_runnable_sync,
|
||||
_call_runnable_async,
|
||||
).with_config(run_name="check_sync_or_async")
|
||||
).with_config(run_name="RunnableWithMessageHistory")
|
||||
|
||||
if history_factory_config:
|
||||
@@ -368,28 +373,25 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
||||
def get_input_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
super_schema = super().get_input_schema(config)
|
||||
if super_schema.__custom_root_type__ or not super_schema.schema().get(
|
||||
"properties"
|
||||
):
|
||||
from langchain_core.messages import BaseMessage
|
||||
# TODO(0.3): Verify that this change was correct
|
||||
# Not enough tests and unclear on why the previous implementation was
|
||||
# necessary.
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
fields: Dict = {}
|
||||
if self.input_messages_key and self.history_messages_key:
|
||||
fields[self.input_messages_key] = (
|
||||
Union[str, BaseMessage, Sequence[BaseMessage]],
|
||||
...,
|
||||
)
|
||||
elif self.input_messages_key:
|
||||
fields[self.input_messages_key] = (Sequence[BaseMessage], ...)
|
||||
else:
|
||||
fields["__root__"] = (Sequence[BaseMessage], ...)
|
||||
return create_model( # type: ignore[call-overload]
|
||||
"RunnableWithChatHistoryInput",
|
||||
**fields,
|
||||
fields: Dict = {}
|
||||
if self.input_messages_key and self.history_messages_key:
|
||||
fields[self.input_messages_key] = (
|
||||
Union[str, BaseMessage, Sequence[BaseMessage]],
|
||||
...,
|
||||
)
|
||||
elif self.input_messages_key:
|
||||
fields[self.input_messages_key] = (Sequence[BaseMessage], ...)
|
||||
else:
|
||||
return super_schema
|
||||
fields["__root__"] = (Sequence[BaseMessage], ...)
|
||||
return create_model( # type: ignore[call-overload]
|
||||
"RunnableWithChatHistoryInput",
|
||||
**fields,
|
||||
)
|
||||
|
||||
def _is_not_async(self, *args: Sequence[Any], **kwargs: Dict[str, Any]) -> bool:
|
||||
return False
|
||||
@@ -429,7 +431,9 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
||||
# This occurs for chat models - since we batch inputs
|
||||
if isinstance(input_val[0], list):
|
||||
if len(input_val) != 1:
|
||||
raise ValueError()
|
||||
raise ValueError(
|
||||
f"Expected a single list of messages. Got {input_val}."
|
||||
)
|
||||
return input_val[0]
|
||||
return list(input_val)
|
||||
else:
|
||||
@@ -468,7 +472,10 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
||||
elif isinstance(output_val, (list, tuple)):
|
||||
return list(output_val)
|
||||
else:
|
||||
raise ValueError()
|
||||
raise ValueError(
|
||||
f"Expected str, BaseMessage, List[BaseMessage], or Tuple[BaseMessage]. "
|
||||
f"Got {output_val}."
|
||||
)
|
||||
|
||||
def _enter_history(self, input: Any, config: RunnableConfig) -> List[BaseMessage]:
|
||||
hist: BaseChatMessageHistory = config["configurable"]["message_history"]
|
||||
|
||||
@@ -21,7 +21,8 @@ from typing import (
|
||||
cast,
|
||||
)
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from pydantic import BaseModel, RootModel
|
||||
|
||||
from langchain_core.runnables.base import (
|
||||
Other,
|
||||
Runnable,
|
||||
@@ -216,7 +217,7 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
|
||||
Union[Runnable[Dict[str, Any], Any], Callable[[Dict[str, Any]], Any]],
|
||||
],
|
||||
],
|
||||
) -> "RunnableAssign":
|
||||
) -> RunnableAssign:
|
||||
"""Merge the Dict input with the output produced by the mapping argument.
|
||||
|
||||
Args:
|
||||
@@ -227,7 +228,7 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
|
||||
A Runnable that merges the Dict input with the output produced by the
|
||||
mapping argument.
|
||||
"""
|
||||
return RunnableAssign(RunnableParallel(kwargs))
|
||||
return RunnableAssign(RunnableParallel[Dict[str, Any]](kwargs))
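RunnablePassthrough.assign, shown above, wraps its keyword arguments in a RunnableParallel and returns a RunnableAssign that folds the mapping's output into the incoming dict. A short usage sketch (illustrative, not part of the diff; assumes langchain_core is installed):

from langchain_core.runnables import RunnableLambda, RunnablePassthrough

# The original keys pass through untouched; the computed key is merged in.
chain = RunnablePassthrough.assign(total=RunnableLambda(lambda d: d["a"] + d["b"]))
print(chain.invoke({"a": 1, "b": 2}))  # {'a': 1, 'b': 2, 'total': 3}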
|
||||
|
||||
def invoke(
|
||||
self, input: Other, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
@@ -419,7 +420,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
map_input_schema = self.mapper.get_input_schema(config)
|
||||
if not map_input_schema.__custom_root_type__:
|
||||
if not issubclass(map_input_schema, RootModel):
|
||||
# ie. it's a dict
|
||||
return map_input_schema
|
||||
|
||||
@@ -430,20 +431,22 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||
) -> Type[BaseModel]:
|
||||
map_input_schema = self.mapper.get_input_schema(config)
|
||||
map_output_schema = self.mapper.get_output_schema(config)
|
||||
if (
|
||||
not map_input_schema.__custom_root_type__
|
||||
and not map_output_schema.__custom_root_type__
|
||||
if not issubclass(map_input_schema, RootModel) and not issubclass(
|
||||
map_output_schema, RootModel
|
||||
):
|
||||
# ie. both are dicts
|
||||
fields = {}
|
||||
|
||||
for name, field_info in map_input_schema.model_fields.items():
|
||||
fields[name] = (field_info.annotation, field_info.default)
|
||||
|
||||
for name, field_info in map_output_schema.model_fields.items():
|
||||
fields[name] = (field_info.annotation, field_info.default)
|
||||
|
||||
return create_model( # type: ignore[call-overload]
|
||||
"RunnableAssignOutput",
|
||||
**{
|
||||
k: (v.type_, v.default)
|
||||
for s in (map_input_schema, map_output_schema)
|
||||
for k, v in s.__fields__.items()
|
||||
},
|
||||
**fields,
|
||||
)
|
||||
elif not map_output_schema.__custom_root_type__:
|
||||
elif not issubclass(map_output_schema, RootModel):
|
||||
# ie. only map output is a dict
|
||||
# ie. input type is either unknown or inferred incorrectly
|
||||
return map_output_schema
|
||||
|
||||
@@ -147,7 +147,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
|
||||
retry_state: RetryCallState,
|
||||
) -> RunnableConfig:
|
||||
attempt = retry_state.attempt_number
|
||||
tag = "retry:attempt:{}".format(attempt) if attempt > 1 else None
|
||||
tag = f"retry:attempt:{attempt}" if attempt > 1 else None
|
||||
return patch_config(config, callbacks=run_manager.get_child(tag))
|
||||
|
||||
def _patch_config_list(
|
||||
@@ -218,6 +218,8 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
|
||||
def pending(iterable: List[U]) -> List[U]:
|
||||
return [item for idx, item in enumerate(iterable) if idx not in results_map]
|
||||
|
||||
not_set: List[Output] = []
|
||||
result = not_set
|
||||
try:
|
||||
for attempt in self._sync_retrying():
|
||||
with attempt:
|
||||
@@ -247,9 +249,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
|
||||
):
|
||||
attempt.retry_state.set_result(result)
|
||||
except RetryError as e:
|
||||
try:
|
||||
result
|
||||
except UnboundLocalError:
|
||||
if result is not_set:
|
||||
result = cast(List[Output], [e] * len(inputs))
|
||||
|
||||
outputs: List[Union[Output, Exception]] = []
|
||||
@@ -284,6 +284,8 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
|
||||
def pending(iterable: List[U]) -> List[U]:
|
||||
return [item for idx, item in enumerate(iterable) if idx not in results_map]
|
||||
|
||||
not_set: List[Output] = []
|
||||
result = not_set
|
||||
try:
|
||||
async for attempt in self._async_retrying():
|
||||
with attempt:
|
||||
@@ -313,9 +315,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
|
||||
):
|
||||
attempt.retry_state.set_result(result)
|
||||
except RetryError as e:
|
||||
try:
|
||||
result
|
||||
except UnboundLocalError:
|
||||
if result is not_set:
|
||||
result = cast(List[Output], [e] * len(inputs))
|
||||
|
||||
outputs: List[Union[Output, Exception]] = []
|
||||
|
||||
@@ -12,6 +12,7 @@ from typing import (
|
||||
cast,
|
||||
)
|
||||
|
||||
from pydantic import ConfigDict
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from langchain_core.runnables.base import (
|
||||
@@ -83,8 +84,9 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]):
|
||||
runnables={key: coerce_to_runnable(r) for key, r in runnables.items()}
|
||||
)
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
|
||||
@@ -28,12 +28,18 @@ from typing import (
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, RootModel # pydantic: ignore
|
||||
from pydantic import create_model as _create_model_base # pydantic :ignore
|
||||
from pydantic.json_schema import (
|
||||
DEFAULT_REF_TEMPLATE,
|
||||
GenerateJsonSchema,
|
||||
JsonSchemaMode,
|
||||
)
|
||||
from typing_extensions import TypeGuard
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseConfig, BaseModel
|
||||
from langchain_core.pydantic_v1 import create_model as _create_model_base
|
||||
from langchain_core.runnables.schema import StreamEvent
|
||||
|
||||
Input = TypeVar("Input", contravariant=True)
|
||||
@@ -350,7 +356,7 @@ def get_function_first_arg_dict_keys(func: Callable) -> Optional[List[str]]:
|
||||
tree = ast.parse(textwrap.dedent(code))
|
||||
visitor = IsFunctionArgDict()
|
||||
visitor.visit(tree)
|
||||
return list(visitor.keys) if visitor.keys else None
|
||||
return sorted(visitor.keys) if visitor.keys else None
|
||||
except (SyntaxError, TypeError, OSError, SystemError):
|
||||
return None
|
||||
|
||||
@@ -393,7 +399,9 @@ def get_function_nonlocals(func: Callable) -> List[Any]:
|
||||
visitor = FunctionNonLocals()
|
||||
visitor.visit(tree)
|
||||
values: List[Any] = []
|
||||
for k, v in inspect.getclosurevars(func).nonlocals.items():
|
||||
closure = inspect.getclosurevars(func)
|
||||
candidates = {**closure.globals, **closure.nonlocals}
|
||||
for k, v in candidates.items():
|
||||
if k in visitor.nonlocals:
|
||||
values.append(v)
|
||||
for kk in visitor.nonlocals:
|
||||
@@ -697,9 +705,57 @@ class _RootEventFilter:
|
||||
return include
|
||||
|
||||
|
||||
class _SchemaConfig(BaseConfig):
|
||||
arbitrary_types_allowed = True
|
||||
frozen = True
|
||||
_SchemaConfig = ConfigDict(arbitrary_types_allowed=True, frozen=True)
|
||||
|
||||
NO_DEFAULT = object()
|
||||
|
||||
|
||||
def create_base_class(
|
||||
name: str, type_: Any, default_: object = NO_DEFAULT
|
||||
) -> Type[BaseModel]:
|
||||
"""Create a base class."""
|
||||
|
||||
def schema(
|
||||
cls: Type[BaseModel],
|
||||
by_alias: bool = True,
|
||||
ref_template: str = DEFAULT_REF_TEMPLATE,
|
||||
) -> Dict[str, Any]:
|
||||
# Complains about schema not being defined in superclass
|
||||
schema_ = super(cls, cls).schema( # type: ignore[misc]
|
||||
by_alias=by_alias, ref_template=ref_template
|
||||
)
|
||||
schema_["title"] = name
|
||||
return schema_
|
||||
|
||||
def model_json_schema(
|
||||
cls: Type[BaseModel],
|
||||
by_alias: bool = True,
|
||||
ref_template: str = DEFAULT_REF_TEMPLATE,
|
||||
schema_generator: type[GenerateJsonSchema] = GenerateJsonSchema,
|
||||
mode: JsonSchemaMode = "validation",
|
||||
) -> Dict[str, Any]:
|
||||
# Complains about model_json_schema not being defined in superclass
|
||||
schema_ = super(cls, cls).model_json_schema( # type: ignore[misc]
|
||||
by_alias=by_alias,
|
||||
ref_template=ref_template,
|
||||
schema_generator=schema_generator,
|
||||
mode=mode,
|
||||
)
|
||||
schema_["title"] = name
|
||||
return schema_
|
||||
|
||||
base_class_attributes = {
|
||||
"__annotations__": {"root": type_},
|
||||
"model_config": ConfigDict(arbitrary_types_allowed=True),
|
||||
"schema": classmethod(schema),
|
||||
"model_json_schema": classmethod(model_json_schema),
|
||||
"__module__": "langchain_core.runnables.utils",
|
||||
}
|
||||
|
||||
if default_ is not NO_DEFAULT:
|
||||
base_class_attributes["root"] = default_
|
||||
custom_root_type = type(name, (RootModel,), base_class_attributes)
|
||||
return cast(Type[BaseModel], custom_root_type)
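create_base_class above replaces the old pydantic-v1 __custom_root_type__ mechanism with a dynamically built pydantic v2 RootModel subclass whose schema title is forced to the supplied name. A minimal sketch of the behaviour it aims for, assuming pydantic v2 is installed (the Messages name is purely illustrative):

from typing import List
from pydantic import RootModel

# Stand-in for the dynamically created class: a root model wrapping a list.
Messages = RootModel[List[str]]

m = Messages(root=["hello", "world"])
print(m.root)                                  # ['hello', 'world']
print(Messages.model_json_schema()["type"])    # array
# create_base_class additionally overrides schema()/model_json_schema() so the
# emitted JSON schema carries the supplied name as its "title".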
|
||||
|
||||
|
||||
def create_model(
|
||||
@@ -715,6 +771,21 @@ def create_model(
|
||||
Returns:
|
||||
Type[BaseModel]: The created model.
|
||||
"""
|
||||
|
||||
# Move this to caching path
|
||||
if "__root__" in field_definitions:
|
||||
if len(field_definitions) > 1:
|
||||
raise NotImplementedError(
|
||||
"When specifying __root__ no other "
|
||||
f"fields should be provided. Got {field_definitions}"
|
||||
)
|
||||
|
||||
arg = field_definitions["__root__"]
|
||||
if isinstance(arg, tuple):
|
||||
named_root_model = create_base_class(__model_name, arg[0], arg[1])
|
||||
else:
|
||||
named_root_model = create_base_class(__model_name, arg)
|
||||
return named_root_model
|
||||
try:
|
||||
return _create_model_cached(__model_name, **field_definitions)
|
||||
except TypeError:
|
||||
@@ -748,7 +819,7 @@ def is_async_generator(
|
||||
"""
|
||||
return (
|
||||
inspect.isasyncgenfunction(func)
|
||||
or hasattr(func, "__call__")
|
||||
or hasattr(func, "__call__") # noqa: B004
|
||||
and inspect.isasyncgenfunction(func.__call__)
|
||||
)
|
||||
|
||||
@@ -767,6 +838,6 @@ def is_async_callable(
|
||||
"""
|
||||
return (
|
||||
asyncio.iscoroutinefunction(func)
|
||||
or hasattr(func, "__call__")
|
||||
or hasattr(func, "__call__") # noqa: B004
|
||||
and asyncio.iscoroutinefunction(func.__call__)
|
||||
)
|
||||
|
||||
@@ -6,7 +6,7 @@ from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Any, List, Optional, Sequence, Union
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Visitor(ABC):
|
||||
@@ -127,7 +127,8 @@ class Comparison(FilterDirective):
|
||||
def __init__(
|
||||
self, comparator: Comparator, attribute: str, value: Any, **kwargs: Any
|
||||
) -> None:
|
||||
super().__init__(
|
||||
# super exists from BaseModel
|
||||
super().__init__( # type: ignore[call-arg]
|
||||
comparator=comparator, attribute=attribute, value=value, **kwargs
|
||||
)
|
||||
|
||||
@@ -145,8 +146,11 @@ class Operation(FilterDirective):
|
||||
|
||||
def __init__(
|
||||
self, operator: Operator, arguments: List[FilterDirective], **kwargs: Any
|
||||
):
|
||||
super().__init__(operator=operator, arguments=arguments, **kwargs)
|
||||
) -> None:
|
||||
# super exists from BaseModel
|
||||
super().__init__( # type: ignore[call-arg]
|
||||
operator=operator, arguments=arguments, **kwargs
|
||||
)
|
||||
|
||||
|
||||
class StructuredQuery(Expr):
|
||||
@@ -165,5 +169,8 @@ class StructuredQuery(Expr):
|
||||
filter: Optional[FilterDirective],
|
||||
limit: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
super().__init__(query=query, filter=filter, limit=limit, **kwargs)
|
||||
) -> None:
|
||||
# super exists from BaseModel
|
||||
super().__init__( # type: ignore[call-arg]
|
||||
query=query, filter=filter, limit=limit, **kwargs
|
||||
)
|
||||
|
||||
File diff suppressed because it is too large
@@ -516,15 +516,19 @@ class _TracerCore(ABC):
|
||||
|
||||
def _end_trace(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||
"""End a trace for a run."""
|
||||
return None
|
||||
|
||||
def _on_run_create(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||
"""Process a run upon creation."""
|
||||
return None
|
||||
|
||||
def _on_run_update(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||
"""Process a run upon update."""
|
||||
return None
|
||||
|
||||
def _on_llm_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||
"""Process the LLM Run upon start."""
|
||||
return None
|
||||
|
||||
def _on_llm_new_token(
|
||||
self,
|
||||
@@ -533,39 +537,52 @@ class _TracerCore(ABC):
|
||||
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]],
|
||||
) -> Union[None, Coroutine[Any, Any, None]]:
|
||||
"""Process new LLM token."""
|
||||
return None
|
||||
|
||||
def _on_llm_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||
"""Process the LLM Run."""
|
||||
return None
|
||||
|
||||
def _on_llm_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||
"""Process the LLM Run upon error."""
|
||||
return None
|
||||
|
||||
def _on_chain_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||
"""Process the Chain Run upon start."""
|
||||
return None
|
||||
|
||||
def _on_chain_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||
"""Process the Chain Run."""
|
||||
return None
|
||||
|
||||
def _on_chain_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||
"""Process the Chain Run upon error."""
|
||||
return None
|
||||
|
||||
def _on_tool_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||
"""Process the Tool Run upon start."""
|
||||
return None
|
||||
|
||||
def _on_tool_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||
"""Process the Tool Run."""
|
||||
return None
|
||||
|
||||
def _on_tool_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||
"""Process the Tool Run upon error."""
|
||||
return None
|
||||
|
||||
def _on_chat_model_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||
"""Process the Chat Model Run upon start."""
|
||||
return None
|
||||
|
||||
def _on_retriever_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||
"""Process the Retriever Run upon start."""
|
||||
return None
|
||||
|
||||
def _on_retriever_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||
"""Process the Retriever Run."""
|
||||
return None
|
||||
|
||||
def _on_retriever_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||
"""Process the Retriever Run upon error."""
|
||||
return None
|
||||
|
||||
@@ -144,11 +144,7 @@ class EvaluatorCallbackHandler(BaseTracer):
|
||||
example_id = str(run.reference_example_id)
|
||||
with self.lock:
|
||||
for res in eval_results:
|
||||
run_id = (
|
||||
str(getattr(res, "target_run_id"))
|
||||
if hasattr(res, "target_run_id")
|
||||
else str(run.id)
|
||||
)
|
||||
run_id = str(getattr(res, "target_run_id", run.id))
|
||||
self.logged_eval_results.setdefault((run_id, example_id), []).append(
|
||||
res
|
||||
)
|
||||
@@ -179,11 +175,9 @@ class EvaluatorCallbackHandler(BaseTracer):
|
||||
source_info_: Dict[str, Any] = {}
|
||||
if res.evaluator_info:
|
||||
source_info_ = {**res.evaluator_info, **source_info_}
|
||||
run_id_ = (
|
||||
getattr(res, "target_run_id")
|
||||
if hasattr(res, "target_run_id") and res.target_run_id is not None
|
||||
else run.id
|
||||
)
|
||||
run_id_ = getattr(res, "target_run_id", None)
|
||||
if run_id_ is None:
|
||||
run_id_ = run.id
|
||||
self.client.create_feedback(
|
||||
run_id_,
|
||||
res.key,
|
||||
|
||||
@@ -11,22 +11,22 @@ from langsmith.schemas import RunBase as BaseRunV2
|
||||
from langsmith.schemas import RunTypeEnum as RunTypeEnumDep
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.outputs import LLMResult
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
|
||||
|
||||
|
||||
@deprecated("0.1.0", alternative="Use string instead.", removal="0.3.0")
|
||||
@deprecated("0.1.0", alternative="Use string instead.", removal="1.0")
|
||||
def RunTypeEnum() -> Type[RunTypeEnumDep]:
|
||||
"""RunTypeEnum."""
|
||||
warnings.warn(
|
||||
"RunTypeEnum is deprecated. Please directly use a string instead"
|
||||
" (e.g. 'llm', 'chain', 'tool').",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return RunTypeEnumDep
|
||||
|
||||
|
||||
@deprecated("0.1.0", removal="0.3.0")
|
||||
@deprecated("0.1.0", removal="1.0")
|
||||
class TracerSessionV1Base(BaseModel):
|
||||
"""Base class for TracerSessionV1."""
|
||||
|
||||
@@ -35,33 +35,33 @@ class TracerSessionV1Base(BaseModel):
|
||||
extra: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
@deprecated("0.1.0", removal="0.3.0")
|
||||
@deprecated("0.1.0", removal="1.0")
|
||||
class TracerSessionV1Create(TracerSessionV1Base):
|
||||
"""Create class for TracerSessionV1."""
|
||||
|
||||
|
||||
@deprecated("0.1.0", removal="0.3.0")
|
||||
@deprecated("0.1.0", removal="1.0")
|
||||
class TracerSessionV1(TracerSessionV1Base):
|
||||
"""TracerSessionV1 schema."""
|
||||
|
||||
id: int
|
||||
|
||||
|
||||
@deprecated("0.1.0", removal="0.3.0")
|
||||
@deprecated("0.1.0", removal="1.0")
|
||||
class TracerSessionBase(TracerSessionV1Base):
|
||||
"""Base class for TracerSession."""
|
||||
|
||||
tenant_id: UUID
|
||||
|
||||
|
||||
@deprecated("0.1.0", removal="0.3.0")
|
||||
@deprecated("0.1.0", removal="1.0")
|
||||
class TracerSession(TracerSessionBase):
|
||||
"""TracerSessionV1 schema for the V2 API."""
|
||||
|
||||
id: UUID
|
||||
|
||||
|
||||
@deprecated("0.1.0", alternative="Run", removal="0.3.0")
|
||||
@deprecated("0.1.0", alternative="Run", removal="1.0")
|
||||
class BaseRun(BaseModel):
|
||||
"""Base class for Run."""
|
||||
|
||||
@@ -77,15 +77,16 @@ class BaseRun(BaseModel):
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
@deprecated("0.1.0", alternative="Run", removal="0.3.0")
|
||||
@deprecated("0.1.0", alternative="Run", removal="1.0")
|
||||
class LLMRun(BaseRun):
|
||||
"""Class for LLMRun."""
|
||||
|
||||
prompts: List[str]
|
||||
response: Optional[LLMResult] = None
|
||||
# Temporarily, remove but we will completely remove LLMRun
|
||||
# response: Optional[LLMResult] = None
|
||||
|
||||
|
||||
@deprecated("0.1.0", alternative="Run", removal="0.3.0")
|
||||
@deprecated("0.1.0", alternative="Run", removal="1.0")
|
||||
class ChainRun(BaseRun):
|
||||
"""Class for ChainRun."""
|
||||
|
||||
@@ -96,7 +97,7 @@ class ChainRun(BaseRun):
|
||||
child_tool_runs: List[ToolRun] = Field(default_factory=list)
|
||||
|
||||
|
||||
@deprecated("0.1.0", alternative="Run", removal="0.3.0")
|
||||
@deprecated("0.1.0", alternative="Run", removal="1.0")
|
||||
class ToolRun(BaseRun):
|
||||
"""Class for ToolRun."""
|
||||
|
||||
|
||||
@@ -22,10 +22,12 @@ from langchain_core.utils.utils import (
|
||||
build_extra_kwargs,
|
||||
check_package_version,
|
||||
convert_to_secret_str,
|
||||
from_env,
|
||||
get_pydantic_field_names,
|
||||
guard_import,
|
||||
mock_now,
|
||||
raise_for_status_with_text,
|
||||
secret_from_env,
|
||||
xor_args,
|
||||
)
|
||||
|
||||
@@ -54,4 +56,6 @@ __all__ = [
|
||||
"pre_init",
|
||||
"batch_iterate",
|
||||
"abatch_iterate",
|
||||
"from_env",
|
||||
"secret_from_env",
|
||||
]
|
||||
|
||||
@@ -41,6 +41,19 @@ def merge_dicts(left: Dict[str, Any], *others: Dict[str, Any]) -> Dict[str, Any]
|
||||
" but with a different type."
|
||||
)
|
||||
elif isinstance(merged[right_k], str):
|
||||
# TODO: Add below special handling for 'type' key in 0.3 and remove
|
||||
# merge_lists 'type' logic.
|
||||
#
|
||||
# if right_k == "type":
|
||||
# if merged[right_k] == right_v:
|
||||
# continue
|
||||
# else:
|
||||
# raise ValueError(
|
||||
# "Unable to merge. Two different values seen for special "
|
||||
# f"key 'type': {merged[right_k]} and {right_v}. 'type' "
|
||||
# "should either occur once or have the same value across "
|
||||
# "all dicts."
|
||||
# )
|
||||
merged[right_k] += right_v
|
||||
elif isinstance(merged[right_k], dict):
|
||||
merged[right_k] = merge_dicts(merged[right_k], right_v)
|
||||
@@ -81,10 +94,10 @@ def merge_lists(left: Optional[List], *others: Optional[List]) -> Optional[List]
|
||||
if e_left["index"] == e["index"]
|
||||
]
|
||||
if to_merge:
|
||||
# If a top-level "type" has been set for a chunk, it should no
|
||||
# longer be overridden by the "type" field in future chunks.
|
||||
if "type" in merged[to_merge[0]] and "type" in e:
|
||||
e.pop("type")
|
||||
# TODO: Remove this once merge_dict is updated with special
|
||||
# handling for 'type'.
|
||||
if "type" in e:
|
||||
e = {k: v for k, v in e.items() if k != "type"}
|
||||
merged[to_merge[0]] = merge_dicts(merged[to_merge[0]], e)
|
||||
else:
|
||||
merged.append(e)
|
||||
|
||||
@@ -62,8 +62,8 @@ def py_anext(
|
||||
__anext__ = cast(
|
||||
Callable[[AsyncIterator[T]], Awaitable[T]], type(iterator).__anext__
|
||||
)
|
||||
except AttributeError:
|
||||
raise TypeError(f"{iterator!r} is not an async iterator")
|
||||
except AttributeError as e:
|
||||
raise TypeError(f"{iterator!r} is not an async iterator") from e
|
||||
|
||||
if default is _no_default:
|
||||
return __anext__(iterator)
|
||||
|
||||
@@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
import collections
|
||||
import inspect
|
||||
import logging
|
||||
import types
|
||||
import typing
|
||||
import uuid
|
||||
from typing import (
|
||||
@@ -22,11 +23,11 @@ from typing import (
|
||||
cast,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Annotated, TypedDict, get_args, get_origin, is_typeddict
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, create_model
|
||||
from langchain_core.utils.json_schema import dereference_refs
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
|
||||
@@ -81,10 +82,10 @@ def _rm_titles(kv: dict, prev_key: str = "") -> dict:
|
||||
@deprecated(
|
||||
"0.1.16",
|
||||
alternative="langchain_core.utils.function_calling.convert_to_openai_function()",
|
||||
removal="0.3.0",
|
||||
removal="1.0",
|
||||
)
|
||||
def convert_pydantic_to_openai_function(
|
||||
model: Type[BaseModel],
|
||||
model: Type,
|
||||
*,
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
@@ -108,7 +109,10 @@ def convert_pydantic_to_openai_function(
|
||||
else:
|
||||
schema = model.schema() # Pydantic 1
|
||||
schema = dereference_refs(schema)
|
||||
schema.pop("definitions", None)
|
||||
if "definitions" in schema: # pydantic 1
|
||||
schema.pop("definitions", None)
|
||||
if "$defs" in schema: # pydantic 2
|
||||
schema.pop("$defs", None)
|
||||
title = schema.pop("title", "")
|
||||
default_description = schema.pop("description", "")
|
||||
return {
|
||||
@@ -121,7 +125,7 @@ def convert_pydantic_to_openai_function(
|
||||
@deprecated(
|
||||
"0.1.16",
|
||||
alternative="langchain_core.utils.function_calling.convert_to_openai_tool()",
|
||||
removal="0.3.0",
|
||||
removal="1.0",
|
||||
)
|
||||
def convert_pydantic_to_openai_tool(
|
||||
model: Type[BaseModel],
|
||||
@@ -155,7 +159,7 @@ def _get_python_function_name(function: Callable) -> str:
|
||||
@deprecated(
|
||||
"0.1.16",
|
||||
alternative="langchain_core.utils.function_calling.convert_to_openai_function()",
|
||||
removal="0.3.0",
|
||||
removal="1.0",
|
||||
)
|
||||
def convert_python_function_to_openai_function(
|
||||
function: Callable,
|
||||
@@ -172,10 +176,10 @@ def convert_python_function_to_openai_function(
|
||||
Returns:
|
||||
The OpenAI function description.
|
||||
"""
|
||||
from langchain_core import tools
|
||||
from langchain_core.tools.base import create_schema_from_function
|
||||
|
||||
func_name = _get_python_function_name(function)
|
||||
model = tools.create_schema_from_function(
|
||||
model = create_schema_from_function(
|
||||
func_name,
|
||||
function,
|
||||
filter_args=(),
|
||||
@@ -192,11 +196,13 @@ def convert_python_function_to_openai_function(
|
||||
|
||||
def _convert_typed_dict_to_openai_function(typed_dict: Type) -> FunctionDescription:
|
||||
visited: Dict = {}
|
||||
from pydantic.v1 import BaseModel # pydantic: ignore
|
||||
|
||||
model = cast(
|
||||
Type[BaseModel],
|
||||
_convert_any_typed_dicts_to_pydantic(typed_dict, visited=visited),
|
||||
)
|
||||
return convert_pydantic_to_openai_function(model)
|
||||
return convert_pydantic_to_openai_function(model) # type: ignore
|
||||
|
||||
|
||||
_MAX_TYPED_DICT_RECURSION = 25
|
||||
@@ -208,6 +214,9 @@ def _convert_any_typed_dicts_to_pydantic(
|
||||
visited: Dict,
|
||||
depth: int = 0,
|
||||
) -> Type:
|
||||
from pydantic.v1 import Field as Field_v1 # pydantic: ignore
|
||||
from pydantic.v1 import create_model as create_model_v1 # pydantic: ignore
|
||||
|
||||
if type_ in visited:
|
||||
return visited[type_]
|
||||
elif depth >= _MAX_TYPED_DICT_RECURSION:
|
||||
@@ -241,7 +250,7 @@ def _convert_any_typed_dicts_to_pydantic(
|
||||
field_kwargs["description"] = arg_desc
|
||||
else:
|
||||
pass
|
||||
fields[arg] = (new_arg_type, Field(**field_kwargs))
|
||||
fields[arg] = (new_arg_type, Field_v1(**field_kwargs))
|
||||
else:
|
||||
new_arg_type = _convert_any_typed_dicts_to_pydantic(
|
||||
arg_type, depth=depth + 1, visited=visited
|
||||
@@ -249,8 +258,8 @@ def _convert_any_typed_dicts_to_pydantic(
|
||||
field_kwargs = {"default": ...}
|
||||
if arg_desc := arg_descriptions.get(arg):
|
||||
field_kwargs["description"] = arg_desc
|
||||
fields[arg] = (new_arg_type, Field(**field_kwargs))
|
||||
model = create_model(typed_dict.__name__, **fields)
|
||||
fields[arg] = (new_arg_type, Field_v1(**field_kwargs))
|
||||
model = create_model_v1(typed_dict.__name__, **fields)
|
||||
model.__doc__ = description
|
||||
visited[typed_dict] = model
|
||||
return model
|
||||
@@ -268,7 +277,7 @@ def _convert_any_typed_dicts_to_pydantic(
|
||||
@deprecated(
|
||||
"0.1.16",
|
||||
alternative="langchain_core.utils.function_calling.convert_to_openai_function()",
|
||||
removal="0.3.0",
|
||||
removal="1.0",
|
||||
)
|
||||
def format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
|
||||
"""Format tool into the OpenAI function API.
|
||||
@@ -305,7 +314,7 @@ def format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
|
||||
@deprecated(
|
||||
"0.1.16",
|
||||
alternative="langchain_core.utils.function_calling.convert_to_openai_tool()",
|
||||
removal="0.3.0",
|
||||
removal="1.0",
|
||||
)
|
||||
def format_tool_to_openai_tool(tool: BaseTool) -> ToolDescription:
|
||||
"""Format tool into the OpenAI function API.
|
||||
@@ -327,15 +336,23 @@ def convert_to_openai_function(
|
||||
) -> Dict[str, Any]:
|
||||
"""Convert a raw function/class to an OpenAI function.
|
||||
|
||||
.. versionchanged:: 0.2.29
|
||||
|
||||
``strict`` arg added.
|
||||
|
||||
Args:
|
||||
function: A dictionary, Pydantic BaseModel class, TypedDict class, a LangChain
|
||||
function:
|
||||
A dictionary, Pydantic BaseModel class, TypedDict class, a LangChain
|
||||
Tool object, or a Python function. If a dictionary is passed in, it is
|
||||
assumed to already be a valid OpenAI function or a JSON schema with
|
||||
top-level 'title' and 'description' keys specified.
|
||||
strict: If True, model output is guaranteed to exactly match the JSON Schema
|
||||
strict:
|
||||
If True, model output is guaranteed to exactly match the JSON Schema
|
||||
provided in the function definition. If None, ``strict`` argument will not
|
||||
be included in function definition.
|
||||
|
||||
.. versionadded:: 0.2.29
|
||||
|
||||
Returns:
|
||||
A dict version of the passed in function which is compatible with the OpenAI
|
||||
function-calling API.
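A short usage sketch of the documented behaviour (illustrative, not part of the diff): converting a Pydantic model with strict=True should both set the strict flag and force additionalProperties to False on the generated parameters schema.

from pydantic import BaseModel, Field
from langchain_core.utils.function_calling import convert_to_openai_function

class GetWeather(BaseModel):
    """Return the weather for a city."""
    city: str = Field(..., description="City name")

fn = convert_to_openai_function(GetWeather, strict=True)
print(fn["name"])                                 # GetWeather
print(fn["strict"])                               # True
print(fn["parameters"]["additionalProperties"])   # False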
|
||||
@@ -380,9 +397,13 @@ def convert_to_openai_function(
|
||||
|
||||
if strict is not None:
|
||||
oai_function["strict"] = strict
|
||||
# As of 08/06/24, OpenAI requires that additionalProperties be supplied and set
|
||||
# to False if strict is True.
|
||||
oai_function["parameters"]["additionalProperties"] = False
|
||||
if strict:
|
||||
# As of 08/06/24, OpenAI requires that additionalProperties be supplied and
|
||||
# set to False if strict is True.
|
||||
# All properties layer needs 'additionalProperties=False'
|
||||
oai_function["parameters"] = _recursive_set_additional_properties_false(
|
||||
oai_function["parameters"]
|
||||
)
|
||||
return oai_function
|
||||
|
||||
|
||||
@@ -393,18 +414,26 @@ def convert_to_openai_tool(
|
||||
) -> Dict[str, Any]:
|
||||
"""Convert a raw function/class to an OpenAI tool.
|
||||
|
||||
.. versionchanged:: 0.2.29
|
||||
|
||||
``strict`` arg added.
|
||||
|
||||
Args:
|
||||
tool: Either a dictionary, a pydantic.BaseModel class, Python function, or
|
||||
tool:
|
||||
Either a dictionary, a pydantic.BaseModel class, Python function, or
|
||||
BaseTool. If a dictionary is passed in, it is assumed to already be a valid
|
||||
OpenAI tool, OpenAI function, or a JSON schema with top-level 'title' and
|
||||
'description' keys specified.
|
||||
strict: If True, model output is guaranteed to exactly match the JSON Schema
|
||||
strict:
|
||||
If True, model output is guaranteed to exactly match the JSON Schema
|
||||
provided in the function definition. If None, ``strict`` argument will not
|
||||
be included in tool definition.
|
||||
|
||||
.. versionadded:: 0.2.29
|
||||
|
||||
Returns:
|
||||
A dict version of the passed in tool which is compatible with the
|
||||
OpenAI tool-calling API.
|
||||
OpenAI tool-calling API.
|
||||
"""
|
||||
if isinstance(tool, dict) and tool.get("type") == "function" and "function" in tool:
|
||||
return tool
|
||||
@@ -559,6 +588,10 @@ def _parse_google_docstring(
|
||||
|
||||
|
||||
def _py_38_safe_origin(origin: Type) -> Type:
|
||||
origin_union_type_map: Dict[Type, Any] = (
|
||||
{types.UnionType: Union} if hasattr(types, "UnionType") else {}
|
||||
)
|
||||
|
||||
origin_map: Dict[Type, Any] = {
|
||||
dict: Dict,
|
||||
list: List,
|
||||
@@ -568,5 +601,23 @@ def _py_38_safe_origin(origin: Type) -> Type:
|
||||
collections.abc.Mapping: typing.Mapping,
|
||||
collections.abc.Sequence: typing.Sequence,
|
||||
collections.abc.MutableMapping: typing.MutableMapping,
|
||||
**origin_union_type_map,
|
||||
}
|
||||
return cast(Type, origin_map.get(origin, origin))
|
||||
|
||||
|
||||
def _recursive_set_additional_properties_false(
|
||||
schema: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
if isinstance(schema, dict):
|
||||
# Check if 'required' is a key at the current level
|
||||
if "required" in schema:
|
||||
schema["additionalProperties"] = False
|
||||
# Recursively check 'properties' and 'items' if they exist
|
||||
if "properties" in schema:
|
||||
for value in schema["properties"].values():
|
||||
_recursive_set_additional_properties_false(value)
|
||||
if "items" in schema:
|
||||
_recursive_set_additional_properties_false(schema["items"])
|
||||
|
||||
return schema
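A hedged sketch of the recursion above (the helper is private, so importing it directly is only for illustration and may break in later releases): every object level that declares "required" gains additionalProperties=False, including nested property schemas.

from langchain_core.utils.function_calling import (
    _recursive_set_additional_properties_false,
)

schema = {
    "type": "object",
    "required": ["user"],
    "properties": {
        "user": {
            "type": "object",
            "required": ["name"],
            "properties": {"name": {"type": "string"}},
        }
    },
}
out = _recursive_set_additional_properties_false(schema)
print(out["additionalProperties"])                        # False
print(out["properties"]["user"]["additionalProperties"])  # False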
|
||||
|
||||
@@ -182,7 +182,7 @@ def parse_and_check_json_markdown(text: str, expected_keys: List[str]) -> dict:
|
||||
try:
|
||||
json_obj = parse_json_markdown(text)
|
||||
except json.JSONDecodeError as e:
|
||||
raise OutputParserException(f"Got invalid JSON object. Error: {e}")
|
||||
raise OutputParserException(f"Got invalid JSON object. Error: {e}") from e
|
||||
for key in expected_keys:
|
||||
if key not in json_obj:
|
||||
raise OutputParserException(
|
||||
|
||||
@@ -8,7 +8,7 @@ from langchain_core._api.deprecation import deprecated
|
||||
|
||||
@deprecated(
|
||||
since="0.1.30",
|
||||
removal="0.3",
|
||||
removal="1.0",
|
||||
message=(
|
||||
"Using the hwchase17/langchain-hub "
|
||||
"repo for prompts is deprecated. Please use "
|
||||
@@ -21,7 +21,9 @@ def try_load_from_hub(
|
||||
) -> Any:
|
||||
warnings.warn(
|
||||
"Loading from the deprecated github-based Hub is no longer supported. "
|
||||
"Please use the new LangChain Hub at https://smith.langchain.com/hub instead."
|
||||
"Please use the new LangChain Hub at https://smith.langchain.com/hub instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
# return None, which indicates that we shouldn't load from old hub
|
||||
# and might just be a filepath for e.g. load_chain
|
||||
|
||||
@@ -3,13 +3,17 @@ Adapted from https://github.com/noahmorrison/chevron
|
||||
MIT License
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from types import MappingProxyType
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Literal,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
@@ -22,7 +26,7 @@ from typing_extensions import TypeAlias
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
Scopes: TypeAlias = List[Union[Literal[False, 0], Dict[str, Any]]]
|
||||
Scopes: TypeAlias = List[Union[Literal[False, 0], Mapping[str, Any]]]
|
||||
|
||||
|
||||
# Globals
|
||||
@@ -152,8 +156,8 @@ def parse_tag(template: str, l_del: str, r_del: str) -> Tuple[Tuple[str, str], s
|
||||
# Get the tag
|
||||
try:
|
||||
tag, template = template.split(r_del, 1)
|
||||
except ValueError:
|
||||
raise ChevronError("unclosed tag " "at line {0}".format(_CURRENT_LINE))
|
||||
except ValueError as e:
|
||||
raise ChevronError("unclosed tag " f"at line {_CURRENT_LINE}") from e
|
||||
|
||||
# Find the type meaning of the first character
|
||||
tag_type = tag_types.get(tag[0], "variable")
|
||||
@@ -174,7 +178,7 @@ def parse_tag(template: str, l_del: str, r_del: str) -> Tuple[Tuple[str, str], s
|
||||
# Otherwise we should complain
|
||||
else:
|
||||
raise ChevronError(
|
||||
"unclosed set delimiter tag\n" "at line {0}".format(_CURRENT_LINE)
|
||||
"unclosed set delimiter tag\n" f"at line {_CURRENT_LINE}"
|
||||
)
|
||||
|
||||
# If we might be a no html escape tag
|
||||
@@ -279,18 +283,18 @@ def tokenize(
|
||||
# is the same as us
|
||||
try:
|
||||
last_section = open_sections.pop()
|
||||
except IndexError:
|
||||
except IndexError as e:
|
||||
raise ChevronError(
|
||||
'Trying to close tag "{0}"\n'
|
||||
f'Trying to close tag "{tag_key}"\n'
|
||||
"Looks like it was not opened.\n"
|
||||
"line {1}".format(tag_key, _CURRENT_LINE + 1)
|
||||
)
|
||||
f"line {_CURRENT_LINE + 1}"
|
||||
) from e
|
||||
if tag_key != last_section:
|
||||
# Otherwise we need to complain
|
||||
raise ChevronError(
|
||||
'Trying to close tag "{0}"\n'
|
||||
'last open tag is "{1}"\n'
|
||||
"line {2}".format(tag_key, last_section, _CURRENT_LINE + 1)
|
||||
f'Trying to close tag "{tag_key}"\n'
|
||||
f'last open tag is "{last_section}"\n'
|
||||
f"line {_CURRENT_LINE + 1}"
|
||||
)
|
||||
|
||||
# Do the second check to see if we're a standalone
|
||||
@@ -320,8 +324,8 @@ def tokenize(
|
||||
# Then we need to complain
|
||||
raise ChevronError(
|
||||
"Unexpected EOF\n"
|
||||
'the tag "{0}" was never closed\n'
|
||||
"was opened at line {1}".format(open_sections[-1], _LAST_TAG_LINE)
|
||||
f'the tag "{open_sections[-1]}" was never closed\n'
|
||||
f"was opened at line {_LAST_TAG_LINE}"
|
||||
)
|
||||
|
||||
|
||||
@@ -403,15 +407,15 @@ def _get_key(
|
||||
# We couldn't find the key in any of the scopes
|
||||
|
||||
if warn:
|
||||
logger.warn("Could not find key '%s'" % (key))
|
||||
logger.warn(f"Could not find key '{key}'")
|
||||
|
||||
if keep:
|
||||
return "%s %s %s" % (def_ldel, key, def_rdel)
|
||||
return f"{def_ldel} {key} {def_rdel}"
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
def _get_partial(name: str, partials_dict: Dict[str, str]) -> str:
|
||||
def _get_partial(name: str, partials_dict: Mapping[str, str]) -> str:
|
||||
"""Load a partial"""
|
||||
try:
|
||||
# Maybe the partial is in the dictionary
|
||||
@@ -425,11 +429,13 @@ def _get_partial(name: str, partials_dict: Dict[str, str]) -> str:
|
||||
#
|
||||
g_token_cache: Dict[str, List[Tuple[str, str]]] = {}
|
||||
|
||||
EMPTY_DICT: MappingProxyType[str, str] = MappingProxyType({})
|
||||
|
||||
|
||||
def render(
|
||||
template: Union[str, List[Tuple[str, str]]] = "",
|
||||
data: Dict[str, Any] = {},
|
||||
partials_dict: Dict[str, str] = {},
|
||||
data: Mapping[str, Any] = EMPTY_DICT,
|
||||
partials_dict: Mapping[str, str] = EMPTY_DICT,
|
||||
padding: str = "",
|
||||
def_ldel: str = "{{",
|
||||
def_rdel: str = "}}",
|
||||
@@ -565,9 +571,9 @@ def render(
|
||||
if tag_type == "literal":
|
||||
text += tag_key
|
||||
elif tag_type == "no escape":
|
||||
text += "%s& %s %s" % (def_ldel, tag_key, def_rdel)
|
||||
text += f"{def_ldel}& {tag_key} {def_rdel}"
|
||||
else:
|
||||
text += "%s%s %s%s" % (
|
||||
text += "{}{} {}{}".format(
|
||||
def_ldel,
|
||||
{
|
||||
"comment": "!",
|
||||
|
||||
@@ -5,11 +5,12 @@ from __future__ import annotations
|
||||
import inspect
|
||||
import textwrap
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union, overload
|
||||
|
||||
import pydantic # pydantic: ignore
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel, root_validator
|
||||
from pydantic import BaseModel, root_validator # pydantic: ignore
|
||||
from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue # pydantic: ignore
|
||||
from pydantic_core import core_schema # pydantic: ignore
|
||||
|
||||
|
||||
def get_pydantic_major_version() -> int:
|
||||
@@ -26,9 +27,13 @@ PYDANTIC_MAJOR_VERSION = get_pydantic_major_version()
|
||||
|
||||
|
||||
if PYDANTIC_MAJOR_VERSION == 1:
|
||||
from pydantic.fields import FieldInfo as FieldInfoV1
|
||||
|
||||
PydanticBaseModel = pydantic.BaseModel
|
||||
TypeBaseModel = Type[BaseModel]
|
||||
elif PYDANTIC_MAJOR_VERSION == 2:
|
||||
from pydantic.v1.fields import FieldInfo as FieldInfoV1 # type: ignore[assignment]
|
||||
|
||||
# Union type needs to be last assignment to PydanticBaseModel to make mypy happy.
|
||||
PydanticBaseModel = Union[BaseModel, pydantic.BaseModel] # type: ignore
|
||||
TypeBaseModel = Union[Type[BaseModel], Type[pydantic.BaseModel]] # type: ignore
|
||||
@@ -142,7 +147,7 @@ def pre_init(func: Callable) -> Any:
|
||||
Dict[str, Any]: The values to initialize the model with.
|
||||
"""
|
||||
# Insert default values
|
||||
fields = cls.__fields__
|
||||
fields = cls.model_fields
|
||||
for name, field_info in fields.items():
|
||||
# Check if allow_population_by_field_name is enabled
|
||||
# If yes, then set the field name to the alias
|
||||
@@ -151,9 +156,13 @@ def pre_init(func: Callable) -> Any:
|
||||
if cls.Config.allow_population_by_field_name:
|
||||
if field_info.alias in values:
|
||||
values[name] = values.pop(field_info.alias)
|
||||
if hasattr(cls, "model_config"):
|
||||
if cls.model_config.get("populate_by_name"):
|
||||
if field_info.alias in values:
|
||||
values[name] = values.pop(field_info.alias)
|
||||
|
||||
if name not in values or values[name] is None:
|
||||
if not field_info.required:
|
||||
if not field_info.is_required():
|
||||
if field_info.default_factory is not None:
|
||||
values[name] = field_info.default_factory()
|
||||
else:
|
||||
@@ -165,6 +174,44 @@ def pre_init(func: Callable) -> Any:
|
||||
return wrapper
|
||||
|
||||
|
||||
class _IgnoreUnserializable(GenerateJsonSchema):
|
||||
"""A JSON schema generator that ignores unknown types.
|
||||
|
||||
https://docs.pydantic.dev/latest/concepts/json_schema/#customizing-the-json-schema-generation-process
|
||||
"""
|
||||
|
||||
def handle_invalid_for_json_schema(
|
||||
self, schema: core_schema.CoreSchema, error_info: str
|
||||
) -> JsonSchemaValue:
|
||||
return {}
|
||||
|
||||
|
||||
def v1_repr(obj: BaseModel) -> str:
|
||||
"""Return the schema of the object as a string.
|
||||
|
||||
Get a repr for the pydantic object which is consistent with pydantic.v1.
|
||||
"""
|
||||
if not is_basemodel_instance(obj):
|
||||
raise TypeError(f"Expected a pydantic BaseModel, got {type(obj)}")
|
||||
repr_ = []
|
||||
for name, field in get_fields(obj).items():
|
||||
value = getattr(obj, name)
|
||||
|
||||
if isinstance(value, BaseModel):
|
||||
repr_.append(f"{name}={v1_repr(value)}")
|
||||
else:
|
||||
if not field.is_required():
|
||||
if not value:
|
||||
continue
|
||||
if field.default == value:
|
||||
continue
|
||||
|
||||
repr_.append(f"{name}={repr(value)}")
|
||||
|
||||
args = ", ".join(repr_)
|
||||
return f"{obj.__class__.__name__}({args})"
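Given the implementation above, a v1-style repr is expected to skip fields left at their defaults and recurse into nested models. A sketch of the intended output shape (hypothetical example; v1_repr only exists once the hunk above lands):

from pydantic import BaseModel

class Inner(BaseModel):
    value: int = 0

class Outer(BaseModel):
    name: str
    inner: Inner = Inner()

# v1_repr(Outer(name="x", inner=Inner(value=2)))
# expected: Outer(name='x', inner=Inner(value=2))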
|
||||
|
||||
|
||||
def _create_subset_model_v1(
|
||||
name: str,
|
||||
model: Type[BaseModel],
|
||||
@@ -174,12 +221,20 @@ def _create_subset_model_v1(
|
||||
fn_description: Optional[str] = None,
|
||||
) -> Type[BaseModel]:
|
||||
"""Create a pydantic model with only a subset of model's fields."""
|
||||
from langchain_core.pydantic_v1 import create_model
|
||||
if PYDANTIC_MAJOR_VERSION == 1:
|
||||
from pydantic import create_model # pydantic: ignore
|
||||
elif PYDANTIC_MAJOR_VERSION == 2:
|
||||
from pydantic.v1 import create_model # type: ignore # pydantic: ignore
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Unsupported pydantic version: {PYDANTIC_MAJOR_VERSION}"
|
||||
)
|
||||
|
||||
fields = {}
|
||||
|
||||
for field_name in field_names:
|
||||
field = model.__fields__[field_name]
|
||||
# Using pydantic v1 so can access __fields__ as a dict.
|
||||
field = model.__fields__[field_name] # type: ignore
|
||||
t = (
|
||||
# this isn't perfect but should work for most functions
|
||||
field.outer_type_
|
||||
@@ -266,3 +321,48 @@ def _create_subset_model(
|
||||
raise NotImplementedError(
|
||||
f"Unsupported pydantic version: {PYDANTIC_MAJOR_VERSION}"
|
||||
)
|
||||
|
||||
|
||||
if PYDANTIC_MAJOR_VERSION == 2:
|
||||
from pydantic import BaseModel as BaseModelV2
|
||||
from pydantic.fields import FieldInfo as FieldInfoV2
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
|
||||
@overload
|
||||
def get_fields(model: Type[BaseModelV2]) -> Dict[str, FieldInfoV2]: ...
|
||||
|
||||
@overload
|
||||
def get_fields(model: BaseModelV2) -> Dict[str, FieldInfoV2]: ...
|
||||
|
||||
@overload
|
||||
def get_fields(model: Type[BaseModelV1]) -> Dict[str, FieldInfoV1]: ...
|
||||
|
||||
@overload
|
||||
def get_fields(model: BaseModelV1) -> Dict[str, FieldInfoV1]: ...
|
||||
|
||||
def get_fields(
|
||||
model: Union[
|
||||
BaseModelV2,
|
||||
BaseModelV1,
|
||||
Type[BaseModelV2],
|
||||
Type[BaseModelV1],
|
||||
],
|
||||
) -> Union[Dict[str, FieldInfoV2], Dict[str, FieldInfoV1]]:
|
||||
"""Get the field names of a Pydantic model."""
|
||||
if hasattr(model, "model_fields"):
|
||||
return model.model_fields # type: ignore
|
||||
|
||||
elif hasattr(model, "__fields__"):
|
||||
return model.__fields__ # type: ignore
|
||||
else:
|
||||
raise TypeError(f"Expected a Pydantic model. Got {type(model)}")
|
||||
elif PYDANTIC_MAJOR_VERSION == 1:
|
||||
from pydantic import BaseModel as BaseModelV1_
|
||||
|
||||
def get_fields( # type: ignore[no-redef]
|
||||
model: Union[Type[BaseModelV1_], BaseModelV1_],
|
||||
) -> Dict[str, FieldInfoV1]:
|
||||
"""Get the field names of a Pydantic model."""
|
||||
return model.__fields__ # type: ignore
|
||||
else:
|
||||
raise ValueError(f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}")
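get_fields gives callers one way to read field metadata regardless of whether a class is a native pydantic v2 model (model_fields) or a pydantic.v1 model (__fields__). A hedged usage sketch, assuming pydantic v2 is installed and the helper above ships in langchain_core.utils.pydantic:

from pydantic import BaseModel
from pydantic.v1 import BaseModel as BaseModelV1
from langchain_core.utils.pydantic import get_fields

class A(BaseModel):
    x: int = 1

class B(BaseModelV1):
    y: str = "hi"

print(list(get_fields(A)))   # ['x']  (read from A.model_fields)
print(list(get_fields(B)))   # ['y']  (read from B.__fields__)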
|
||||
|
||||
@@ -4,14 +4,15 @@ import contextlib
|
||||
import datetime
|
||||
import functools
|
||||
import importlib
|
||||
import os
|
||||
import warnings
|
||||
from importlib.metadata import version
|
||||
from typing import Any, Callable, Dict, Optional, Set, Tuple, Union
|
||||
from typing import Any, Callable, Dict, Optional, Sequence, Set, Tuple, Union, overload
|
||||
|
||||
from packaging.version import parse
|
||||
from pydantic import SecretStr
|
||||
from requests import HTTPError, Response
|
||||
|
||||
from langchain_core.pydantic_v1 import SecretStr
|
||||
from langchain_core.utils.pydantic import (
|
||||
is_pydantic_v1_subclass,
|
||||
)
|
||||
@@ -130,12 +131,12 @@ def guard_import(
|
||||
"""
|
||||
try:
|
||||
module = importlib.import_module(module_name, package)
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
except (ImportError, ModuleNotFoundError) as e:
|
||||
pip_name = pip_name or module_name.split(".")[0].replace("_", "-")
|
||||
raise ImportError(
|
||||
f"Could not import {module_name} python package. "
|
||||
f"Please install it with `pip install {pip_name}`."
|
||||
)
|
||||
) from e
|
||||
return module
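A short usage sketch of guard_import (illustrative only): it returns the module when the import succeeds and raises a descriptive ImportError, now chained to the original exception, when it does not.

from langchain_core.utils import guard_import

json_mod = guard_import("json")     # stdlib module, always importable
print(json_mod.dumps({"ok": True}))

# guard_import("some_missing_pkg") would raise:
# ImportError: Could not import some_missing_pkg python package.
#              Please install it with `pip install some-missing-pkg`.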
|
||||
|
||||
|
||||
@@ -234,7 +235,8 @@ def build_extra_kwargs(
|
||||
warnings.warn(
|
||||
f"""WARNING! {field_name} is not default parameter.
|
||||
{field_name} was transferred to model_kwargs.
|
||||
Please confirm that {field_name} is what you intended."""
|
||||
Please confirm that {field_name} is what you intended.""",
|
||||
stacklevel=7,
|
||||
)
|
||||
extra_kwargs[field_name] = values.pop(field_name)
|
||||
|
||||
@@ -260,3 +262,155 @@ def convert_to_secret_str(value: Union[SecretStr, str]) -> SecretStr:
|
||||
if isinstance(value, SecretStr):
|
||||
return value
|
||||
return SecretStr(value)
|
||||
|
||||
|
||||
class _NoDefaultType:
|
||||
"""Type to indicate no default value is provided."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
_NoDefault = _NoDefaultType()
|
||||
|
||||
|
||||
@overload
|
||||
def from_env(key: str, /) -> Callable[[], str]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def from_env(key: str, /, *, default: str) -> Callable[[], str]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def from_env(key: Sequence[str], /, *, default: str) -> Callable[[], str]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def from_env(key: str, /, *, error_message: str) -> Callable[[], str]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def from_env(
|
||||
key: Union[str, Sequence[str]], /, *, default: str, error_message: Optional[str]
|
||||
) -> Callable[[], str]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def from_env(
|
||||
key: str, /, *, default: None, error_message: Optional[str]
|
||||
) -> Callable[[], Optional[str]]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def from_env(
|
||||
key: Union[str, Sequence[str]], /, *, default: None
|
||||
) -> Callable[[], Optional[str]]: ...
|
||||
|
||||
|
||||
def from_env(
|
||||
key: Union[str, Sequence[str]],
|
||||
/,
|
||||
*,
|
||||
default: Union[str, _NoDefaultType, None] = _NoDefault,
|
||||
error_message: Optional[str] = None,
|
||||
) -> Union[Callable[[], str], Callable[[], Optional[str]]]:
|
||||
"""Create a factory method that gets a value from an environment variable.
|
||||
|
||||
Args:
|
||||
key: The environment variable to look up. If a list of keys is provided,
|
||||
the first key found in the environment will be used.
|
||||
If no key is found, the default value will be used if set,
|
||||
otherwise an error will be raised.
|
||||
default: The default value to return if the environment variable is not set.
|
||||
error_message: the error message which will be raised if the key is not found
|
||||
and no default value is provided.
|
||||
This will be raised as a ValueError.
|
||||
"""
|
||||
|
||||
def get_from_env_fn() -> Optional[str]:
|
||||
"""Get a value from an environment variable."""
|
||||
if isinstance(key, (list, tuple)):
|
||||
for k in key:
|
||||
if k in os.environ:
|
||||
return os.environ[k]
|
||||
if isinstance(key, str):
|
||||
if key in os.environ:
|
||||
return os.environ[key]
|
||||
|
||||
if isinstance(default, (str, type(None))):
|
||||
return default
|
||||
else:
|
||||
if error_message:
|
||||
raise ValueError(error_message)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Did not find {key}, please add an environment variable"
|
||||
f" `{key}` which contains it, or pass"
|
||||
f" `{key}` as a named parameter."
|
||||
)
|
||||
|
||||
return get_from_env_fn
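A short usage sketch of the factory defined above (illustrative; the MY_API_BASE variable name is made up): from_env returns a zero-argument callable, which makes it a natural pydantic default_factory.

import os
from langchain_core.utils import from_env

os.environ["MY_API_BASE"] = "https://example.invalid"

get_base = from_env("MY_API_BASE", default="http://localhost:8000")
print(get_base())   # https://example.invalid

# Typical pattern inside a model definition:
#   api_base: str = Field(default_factory=from_env("MY_API_BASE", default="..."))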
|
||||
|
||||
|
||||
@overload
|
||||
def secret_from_env(key: str, /) -> Callable[[], SecretStr]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def secret_from_env(key: str, /, *, default: str) -> Callable[[], SecretStr]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def secret_from_env(
|
||||
key: Union[str, Sequence[str]], /, *, default: None
|
||||
) -> Callable[[], Optional[SecretStr]]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def secret_from_env(key: str, /, *, error_message: str) -> Callable[[], SecretStr]: ...
|
||||
|
||||
|
||||
def secret_from_env(
|
||||
key: Union[str, Sequence[str]],
|
||||
/,
|
||||
*,
|
||||
default: Union[str, _NoDefaultType, None] = _NoDefault,
|
||||
error_message: Optional[str] = None,
|
||||
) -> Union[Callable[[], Optional[SecretStr]], Callable[[], SecretStr]]:
|
||||
"""Secret from env.
|
||||
|
||||
Args:
|
||||
key: The environment variable to look up.
|
||||
default: The default value to return if the environment variable is not set.
|
||||
error_message: the error message which will be raised if the key is not found
|
||||
and no default value is provided.
|
||||
This will be raised as a ValueError.
|
||||
|
||||
Returns:
|
||||
factory method that will look up the secret from the environment.
|
||||
"""
|
||||
|
||||
def get_secret_from_env() -> Optional[SecretStr]:
|
||||
"""Get a value from an environment variable."""
|
||||
if isinstance(key, (list, tuple)):
|
||||
for k in key:
|
||||
if k in os.environ:
|
||||
return SecretStr(os.environ[k])
|
||||
if isinstance(key, str):
|
||||
if key in os.environ:
|
||||
return SecretStr(os.environ[key])
|
||||
if isinstance(default, str):
|
||||
return SecretStr(default)
|
||||
elif isinstance(default, type(None)):
|
||||
return None
|
||||
else:
|
||||
if error_message:
|
||||
raise ValueError(error_message)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Did not find {key}, please add an environment variable"
|
||||
f" `{key}` which contains it, or pass"
|
||||
f" `{key}` as a named parameter."
|
||||
)
|
||||
|
||||
return get_secret_from_env
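secret_from_env behaves like from_env but wraps the value in SecretStr so it stays masked in reprs and logs. A brief sketch (the MY_API_KEY name and value are made up):

import os
from langchain_core.utils import secret_from_env

os.environ["MY_API_KEY"] = "sk-not-a-real-key"

get_key = secret_from_env("MY_API_KEY")
secret = get_key()
print(secret)                      # **********
print(secret.get_secret_value())   # sk-not-a-real-key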
|
||||
|
||||
@@ -29,30 +29,24 @@ from itertools import cycle
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncIterable,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
ClassVar,
|
||||
Collection,
|
||||
Dict,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from langchain_core._api import beta
|
||||
from pydantic import ConfigDict, Field, model_validator
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.pydantic_v1 import Field, root_validator
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langchain_core.retrievers import BaseRetriever, LangSmithRetrieverParams
|
||||
from langchain_core.runnables.config import run_in_executor
|
||||
from langchain_core.utils.aiter import abatch_iterate
|
||||
from langchain_core.utils.iter import batch_iterate
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.callbacks.manager import (
|
||||
@@ -60,7 +54,6 @@ if TYPE_CHECKING:
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.indexing import UpsertResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -96,7 +89,7 @@ class VectorStore(ABC):
|
||||
ValueError: If the number of metadatas does not match the number of texts.
|
||||
ValueError: If the number of ids does not match the number of texts.
|
||||
"""
|
||||
if type(self).upsert != VectorStore.upsert:
|
||||
if type(self).add_documents != VectorStore.add_documents:
|
||||
# Import document in local scope to avoid circular imports
|
||||
from langchain_core.documents import Document
|
||||
|
||||
@@ -109,190 +102,19 @@ class VectorStore(ABC):
|
||||
if metadatas and len(metadatas) != len(texts_):
|
||||
raise ValueError(
|
||||
"The number of metadatas must match the number of texts."
|
||||
"Got {len(metadatas)} metadatas and {len(texts_)} texts."
|
||||
f"Got {len(metadatas)} metadatas and {len(texts_)} texts."
|
||||
)
|
||||
|
||||
if "ids" in kwargs:
|
||||
ids = kwargs.pop("ids")
|
||||
if ids and len(ids) != len(texts_):
|
||||
raise ValueError(
|
||||
"The number of ids must match the number of texts."
|
||||
"Got {len(ids)} ids and {len(texts_)} texts."
|
||||
)
|
||||
else:
|
||||
ids = None
|
||||
|
||||
metadatas_ = iter(metadatas) if metadatas else cycle([{}])
|
||||
ids_: Iterable[Union[str, None]] = ids if ids is not None else cycle([None])
|
||||
docs = [
|
||||
Document(page_content=text, metadata=metadata_, id=id_)
|
||||
for text, metadata_, id_ in zip(texts, metadatas_, ids_)
|
||||
Document(page_content=text, metadata=metadata_)
|
||||
for text, metadata_ in zip(texts, metadatas_)
|
||||
]
|
||||
upsert_response = self.upsert(docs, **kwargs)
|
||||
return upsert_response["succeeded"]
|
||||
|
||||
return self.add_documents(docs, **kwargs)
|
||||
raise NotImplementedError(
|
||||
f"`add_texts` has not been implemented for {self.__class__.__name__} "
|
||||
)
|
||||
|
||||
    # Developer guidelines:
    # Do not override streaming_upsert!
    @beta(message="Added in 0.2.11. The API is subject to change.")
    def streaming_upsert(
        self, items: Iterable[Document], /, batch_size: int, **kwargs: Any
    ) -> Iterator[UpsertResponse]:
        """Upsert documents in a streaming fashion.

        Args:
            items: Iterable of Documents to add to the vectorstore.
            batch_size: The size of each batch to upsert.
            kwargs: Additional keyword arguments.
                kwargs should only include parameters that are common to all
                documents. (e.g., timeout for indexing, retry policy, etc.)
                kwargs should not include ids to avoid ambiguous semantics.
                Instead, the ID should be provided as part of the Document object.

        Yields:
            UpsertResponse: A response object that contains the list of IDs that were
            successfully added or updated in the vectorstore and the list of IDs that
            failed to be added or updated.

        .. versionadded:: 0.2.11
        """
        # The default implementation of this method breaks the input into
        # batches of size `batch_size` and calls the `upsert` method on each batch.
        # Subclasses can override this method to provide a more efficient
        # implementation.
        for item_batch in batch_iterate(batch_size, items):
            yield self.upsert(item_batch, **kwargs)
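For orientation, a minimal sketch of how this streaming API is meant to be consumed. It assumes the `InMemoryVectorStore` added elsewhere in this changeset and the `DeterministicFakeEmbedding` helper shipped with langchain-core (any store whose `upsert` is implemented would work the same way):

from langchain_core.documents import Document
from langchain_core.embeddings import DeterministicFakeEmbedding  # assumed helper, illustrative
from langchain_core.vectorstores import InMemoryVectorStore

store = InMemoryVectorStore(embedding=DeterministicFakeEmbedding(size=8))
docs = (Document(page_content=f"doc {i}", id=str(i)) for i in range(250))

# streaming_upsert slices the iterable into batches of `batch_size` and calls
# `upsert` once per batch; each yielded UpsertResponse carries "succeeded" and
# "failed" ID lists.
for response in store.streaming_upsert(docs, batch_size=100):
    print(f"{len(response['succeeded'])} upserted, {len(response['failed'])} failed")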
|
||||
|
||||
# Please note that we've added a new method `upsert` instead of re-using the
|
||||
# existing `add_documents` method.
|
||||
# This was done to resolve potential ambiguities around the behavior of **kwargs
|
||||
# in existing add_documents / add_texts methods which could include per document
|
||||
# information (e.g., the `ids` parameter).
|
||||
# Over time the `add_documents` could be denoted as legacy and deprecated
|
||||
# in favor of the `upsert` method.
|
||||
@beta(message="Added in 0.2.11. The API is subject to change.")
|
||||
def upsert(self, items: Sequence[Document], /, **kwargs: Any) -> UpsertResponse:
|
||||
"""Add or update documents in the vectorstore.
|
||||
|
||||
The upsert functionality should utilize the ID field of the Document object
|
||||
if it is provided. If the ID is not provided, the upsert method is free
|
||||
to generate an ID for the document.
|
||||
|
||||
When an ID is specified and the document already exists in the vectorstore,
|
||||
the upsert method should update the document with the new data. If the document
|
||||
does not exist, the upsert method should add the document to the vectorstore.
|
||||
|
||||
Args:
|
||||
items: Sequence of Documents to add to the vectorstore.
|
||||
kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
UpsertResponse: A response object that contains the list of IDs that were
|
||||
successfully added or updated in the vectorstore and the list of IDs that
|
||||
failed to be added or updated.
|
||||
|
||||
.. versionadded:: 0.2.11
|
||||
"""
|
||||
        # Developer guidelines:
        #
        # Vectorstores implementations are free to extend `upsert` implementation
        # to take in additional data per document.
        #
        # This data **SHOULD NOT** be part of the **kwargs** parameter, instead
        # sub-classes can use a Union type on `documents` to include additional
        # supported formats for the input data stream.
        #
        # For example,
        #
        # .. code-block:: python
        #     from typing import TypedDict
        #
        #     class DocumentWithVector(TypedDict):
        #         document: Document
        #         vector: List[float]
        #
        #     def upsert(
        #         self,
        #         documents: Union[Iterable[Document], Iterable[DocumentWithVector]],
        #         /,
        #         **kwargs
        #     ) -> UpsertResponse:
        #         \"\"\"Add or update documents in the vectorstore.\"\"\"
        #         # Implementation should check if documents is an
        #         # iterable of DocumentWithVector or Document
        #         pass
        #
        # Implementations that override upsert should include a new doc-string
        # that explains the semantics of upsert and includes in code
        # examples of how to insert using the alternate data formats.

        # The implementation does not delegate to the `add_texts` method or
        # the `add_documents` method by default since those implementations
        raise NotImplementedError(
            f"upsert has not been implemented for {self.__class__.__name__}"
        )
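To make the contract above concrete, a minimal sketch of what an `upsert` override could look like in a subclass. The class name is illustrative and the remaining abstract methods (`similarity_search`, `from_texts`, ...) are omitted, so this is a sketch of the method shape, not an instantiable store:

import uuid
from typing import Any, Dict, Sequence

from langchain_core.documents import Document
from langchain_core.indexing import UpsertResponse
from langchain_core.vectorstores import VectorStore


class ToyDictStore(VectorStore):  # illustrative name, not part of this change
    def __init__(self) -> None:
        self.docs: Dict[str, Document] = {}

    def upsert(self, items: Sequence[Document], /, **kwargs: Any) -> UpsertResponse:
        succeeded = []
        for item in items:
            # Honor a caller-provided Document.id, otherwise generate one.
            doc_id = item.id or str(uuid.uuid4())
            self.docs[doc_id] = item
            succeeded.append(doc_id)
        return {"succeeded": succeeded, "failed": []}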
|
||||
|
||||
@beta(message="Added in 0.2.11. The API is subject to change.")
|
||||
async def astreaming_upsert(
|
||||
self,
|
||||
items: AsyncIterable[Document],
|
||||
/,
|
||||
batch_size: int,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[UpsertResponse]:
|
||||
"""Upsert documents in a streaming fashion. Async version of streaming_upsert.
|
||||
|
||||
Args:
|
||||
items: Iterable of Documents to add to the vectorstore.
|
||||
batch_size: The size of each batch to upsert.
|
||||
kwargs: Additional keyword arguments.
|
||||
kwargs should only include parameters that are common to all
|
||||
documents. (e.g., timeout for indexing, retry policy, etc.)
|
||||
kwargs should not include ids to avoid ambiguous semantics.
|
||||
Instead the ID should be provided as part of the Document object.
|
||||
|
||||
Yields:
|
||||
UpsertResponse: A response object that contains the list of IDs that were
|
||||
successfully added or updated in the vectorstore and the list of IDs that
|
||||
failed to be added or updated.
|
||||
|
||||
.. versionadded:: 0.2.11
|
||||
"""
|
||||
async for batch in abatch_iterate(batch_size, items):
|
||||
yield await self.aupsert(batch, **kwargs)
|
||||
|
||||
@beta(message="Added in 0.2.11. The API is subject to change.")
|
||||
async def aupsert(
|
||||
self, items: Sequence[Document], /, **kwargs: Any
|
||||
) -> UpsertResponse:
|
||||
"""Add or update documents in the vectorstore. Async version of upsert.
|
||||
|
||||
The upsert functionality should utilize the ID field of the Document object
|
||||
if it is provided. If the ID is not provided, the upsert method is free
|
||||
to generate an ID for the document.
|
||||
|
||||
When an ID is specified and the document already exists in the vectorstore,
|
||||
the upsert method should update the document with the new data. If the document
|
||||
does not exist, the upsert method should add the document to the vectorstore.
|
||||
|
||||
Args:
|
||||
items: Sequence of Documents to add to the vectorstore.
|
||||
kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
UpsertResponse: A response object that contains the list of IDs that were
|
||||
successfully added or updated in the vectorstore and the list of IDs that
|
||||
failed to be added or updated.
|
||||
|
||||
.. versionadded:: 0.2.11
|
||||
"""
|
||||
# Developer guidelines: See guidelines for the `upsert` method.
|
||||
# The implementation does not delegate to the `add_texts` method or
|
||||
# the `add_documents` method by default since those implementations
|
||||
return await run_in_executor(None, self.upsert, items, **kwargs)
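The async defaults shown here batch via `abatch_iterate` and fall back to running `upsert` in an executor. A small sketch of the async side in use, assuming the same in-memory store and fake embedding helper as in the earlier sketch:

import asyncio
from typing import AsyncIterator

from langchain_core.documents import Document
from langchain_core.embeddings import DeterministicFakeEmbedding  # assumed helper, illustrative
from langchain_core.vectorstores import InMemoryVectorStore


async def _docs() -> AsyncIterator[Document]:
    for i in range(250):
        yield Document(page_content=f"doc {i}", id=str(i))


async def main() -> None:
    store = InMemoryVectorStore(embedding=DeterministicFakeEmbedding(size=8))
    # astreaming_upsert groups the async iterable into batches of 100 and
    # awaits aupsert per batch, yielding one UpsertResponse each time.
    async for response in store.astreaming_upsert(_docs(), batch_size=100):
        print(len(response["succeeded"]), "upserted")


asyncio.run(main())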
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Optional[Embeddings]:
|
||||
"""Access the query embedding object if available."""
|
||||
@@ -407,7 +229,7 @@ class VectorStore(ABC):
|
||||
ValueError: If the number of metadatas does not match the number of texts.
|
||||
ValueError: If the number of ids does not match the number of texts.
|
||||
"""
|
||||
if type(self).aupsert != VectorStore.aupsert:
|
||||
if type(self).aadd_documents != VectorStore.aadd_documents:
|
||||
# Import document in local scope to avoid circular imports
|
||||
from langchain_core.documents import Document
|
||||
|
||||
@@ -420,27 +242,16 @@ class VectorStore(ABC):
|
||||
if metadatas and len(metadatas) != len(texts_):
|
||||
raise ValueError(
|
||||
"The number of metadatas must match the number of texts."
|
||||
"Got {len(metadatas)} metadatas and {len(texts_)} texts."
|
||||
f"Got {len(metadatas)} metadatas and {len(texts_)} texts."
|
||||
)
|
||||
|
||||
if "ids" in kwargs:
|
||||
ids = kwargs.pop("ids")
|
||||
if ids and len(ids) != len(texts_):
|
||||
raise ValueError(
|
||||
"The number of ids must match the number of texts."
|
||||
"Got {len(ids)} ids and {len(texts_)} texts."
|
||||
)
|
||||
else:
|
||||
ids = None
|
||||
|
||||
metadatas_ = iter(metadatas) if metadatas else cycle([{}])
|
||||
ids_: Iterable[Union[str, None]] = ids if ids is not None else cycle([None])
|
||||
|
||||
docs = [
|
||||
Document(page_content=text, metadata=metadata_, id=id_)
|
||||
for text, metadata_, id_ in zip(texts, metadatas_, ids_)
|
||||
Document(page_content=text, metadata=metadata_)
|
||||
for text, metadata_ in zip(texts, metadatas_)
|
||||
]
|
||||
upsert_response = await self.aupsert(docs, **kwargs)
|
||||
return upsert_response["succeeded"]
|
||||
|
||||
return await self.aadd_documents(docs, **kwargs)
|
||||
return await run_in_executor(None, self.add_texts, texts, metadatas, **kwargs)
|
||||
|
||||
def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
|
||||
@@ -458,37 +269,22 @@ class VectorStore(ABC):
|
||||
Raises:
|
||||
ValueError: If the number of ids does not match the number of documents.
|
||||
"""
|
||||
if type(self).upsert != VectorStore.upsert:
|
||||
from langchain_core.documents import Document
|
||||
if type(self).add_texts != VectorStore.add_texts:
|
||||
if "ids" not in kwargs:
|
||||
ids = [doc.id for doc in documents]
|
||||
|
||||
if "ids" in kwargs:
|
||||
ids = kwargs.pop("ids")
|
||||
if ids and len(ids) != len(documents):
|
||||
raise ValueError(
|
||||
"The number of ids must match the number of documents. "
|
||||
"Got {len(ids)} ids and {len(documents)} documents."
|
||||
)
|
||||
# If there's at least one valid ID, we'll assume that IDs
|
||||
# should be used.
|
||||
if any(ids):
|
||||
kwargs["ids"] = ids
|
||||
|
||||
documents_ = []
|
||||
|
||||
for id_, document in zip(ids, documents):
|
||||
doc_with_id = Document(
|
||||
page_content=document.page_content,
|
||||
metadata=document.metadata,
|
||||
id=id_,
|
||||
)
|
||||
documents_.append(doc_with_id)
|
||||
else:
|
||||
documents_ = documents
|
||||
|
||||
# If upsert has been implemented, we can use it to add documents
|
||||
return self.upsert(documents_, **kwargs)["succeeded"]
|
||||
|
||||
# Code path that delegates to add_text for backwards compatibility
|
||||
# TODO: Handle the case where the user doesn't provide ids on the Collection
|
||||
texts = [doc.page_content for doc in documents]
|
||||
metadatas = [doc.metadata for doc in documents]
|
||||
return self.add_texts(texts, metadatas, **kwargs)
|
||||
texts = [doc.page_content for doc in documents]
|
||||
metadatas = [doc.metadata for doc in documents]
|
||||
return self.add_texts(texts, metadatas, **kwargs)
|
||||
raise NotImplementedError(
|
||||
f"`add_documents` and `add_texts` has not been implemented "
|
||||
f"for {self.__class__.__name__} "
|
||||
)
|
||||
|
||||
async def aadd_documents(
|
||||
self, documents: List[Document], **kwargs: Any
|
||||
@@ -506,41 +302,21 @@ class VectorStore(ABC):
|
||||
Raises:
|
||||
ValueError: If the number of IDs does not match the number of documents.
|
||||
"""
|
||||
# If either upsert or aupsert has been implemented, we delegate to them!
|
||||
if (
|
||||
type(self).aupsert != VectorStore.aupsert
|
||||
or type(self).upsert != VectorStore.upsert
|
||||
):
|
||||
# If aupsert has been implemented, we can use it to add documents
|
||||
from langchain_core.documents import Document
|
||||
# If the async method has been overridden, we'll use that.
|
||||
if type(self).aadd_texts != VectorStore.aadd_texts:
|
||||
if "ids" not in kwargs:
|
||||
ids = [doc.id for doc in documents]
|
||||
|
||||
if "ids" in kwargs:
|
||||
ids = kwargs.pop("ids")
|
||||
if ids and len(ids) != len(documents):
|
||||
raise ValueError(
|
||||
"The number of ids must match the number of documents."
|
||||
"Got {len(ids)} ids and {len(documents)} documents."
|
||||
)
|
||||
# If there's at least one valid ID, we'll assume that IDs
|
||||
# should be used.
|
||||
if any(ids):
|
||||
kwargs["ids"] = ids
|
||||
|
||||
documents_ = []
|
||||
texts = [doc.page_content for doc in documents]
|
||||
metadatas = [doc.metadata for doc in documents]
|
||||
return await self.aadd_texts(texts, metadatas, **kwargs)
|
||||
|
||||
for id_, document in zip(ids, documents):
|
||||
doc_with_id = Document(
|
||||
page_content=document.page_content,
|
||||
metadata=document.metadata,
|
||||
id=id_,
|
||||
)
|
||||
documents_.append(doc_with_id)
|
||||
else:
|
||||
documents_ = documents
|
||||
|
||||
# The default implementation of aupsert delegates to upsert.
|
||||
upsert_response = await self.aupsert(documents_, **kwargs)
|
||||
return upsert_response["succeeded"]
|
||||
|
||||
texts = [doc.page_content for doc in documents]
|
||||
metadatas = [doc.metadata for doc in documents]
|
||||
return await self.aadd_texts(texts, metadatas, **kwargs)
|
||||
return await run_in_executor(None, self.add_documents, documents, **kwargs)
|
||||
|
||||
def search(self, query: str, search_type: str, **kwargs: Any) -> List[Document]:
|
||||
"""Return docs most similar to query using a specified search type.
|
||||
@@ -783,7 +559,8 @@ class VectorStore(ABC):
|
||||
):
|
||||
warnings.warn(
|
||||
"Relevance scores must be between"
|
||||
f" 0 and 1, got {docs_and_similarities}"
|
||||
f" 0 and 1, got {docs_and_similarities}",
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
if score_threshold is not None:
|
||||
@@ -793,7 +570,7 @@ class VectorStore(ABC):
|
||||
if similarity >= score_threshold
|
||||
]
|
||||
if len(docs_and_similarities) == 0:
|
||||
warnings.warn(
|
||||
logger.warning(
|
||||
"No relevant docs were retrieved using the relevance score"
|
||||
f" threshold {score_threshold}"
|
||||
)
|
||||
@@ -830,7 +607,8 @@ class VectorStore(ABC):
|
||||
):
|
||||
warnings.warn(
|
||||
"Relevance scores must be between"
|
||||
f" 0 and 1, got {docs_and_similarities}"
|
||||
f" 0 and 1, got {docs_and_similarities}",
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
if score_threshold is not None:
|
||||
@@ -840,7 +618,7 @@ class VectorStore(ABC):
|
||||
if similarity >= score_threshold
|
||||
]
|
||||
if len(docs_and_similarities) == 0:
|
||||
warnings.warn(
|
||||
logger.warning(
|
||||
"No relevant docs were retrieved using the relevance score"
|
||||
f" threshold {score_threshold}"
|
||||
)
|
||||
@@ -1207,11 +985,13 @@ class VectorStoreRetriever(BaseRetriever):
        "mmr",
    )

    class Config:
        arbitrary_types_allowed = True
    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )

    @root_validator(pre=True)
    def validate_search_type(cls, values: Dict) -> Dict:
    @model_validator(mode="before")
    @classmethod
    def validate_search_type(cls, values: Dict) -> Any:
        """Validate search type.

        Args:
@@ -1239,6 +1019,25 @@ class VectorStoreRetriever(BaseRetriever):
        )
        return values
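The hunk above is part of the pydantic v1 → v2 migration that recurs throughout this changeset: `class Config` becomes `model_config = ConfigDict(...)` and `@root_validator(pre=True)` becomes `@model_validator(mode="before")` stacked with `@classmethod`. A small self-contained sketch of the resulting shape (class and field names are illustrative):

from typing import Any

from pydantic import BaseModel, ConfigDict, model_validator


class Example(BaseModel):
    # replaces the old `class Config: arbitrary_types_allowed = True`
    model_config = ConfigDict(arbitrary_types_allowed=True)

    search_type: str = "similarity"

    # replaces the old `@root_validator(pre=True)`; note the added @classmethod
    @model_validator(mode="before")
    @classmethod
    def validate_search_type(cls, values: dict) -> Any:
        if values.get("search_type") not in ("similarity", "mmr"):
            raise ValueError("search_type must be 'similarity' or 'mmr'")
        return values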
|
||||
|
||||
def _get_ls_params(self, **kwargs: Any) -> LangSmithRetrieverParams:
|
||||
"""Get standard params for tracing."""
|
||||
|
||||
ls_params = super()._get_ls_params(**kwargs)
|
||||
ls_params["ls_vector_store_provider"] = self.vectorstore.__class__.__name__
|
||||
|
||||
if self.vectorstore.embeddings:
|
||||
ls_params["ls_embedding_provider"] = (
|
||||
self.vectorstore.embeddings.__class__.__name__
|
||||
)
|
||||
elif hasattr(self.vectorstore, "embedding") and isinstance(
|
||||
self.vectorstore.embedding, Embeddings
|
||||
):
|
||||
ls_params["ls_embedding_provider"] = (
|
||||
self.vectorstore.embedding.__class__.__name__
|
||||
)
|
||||
|
||||
return ls_params
|
||||
|
||||
def _get_relevant_documents(
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
|
||||
@@ -8,30 +8,142 @@ from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
)
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.load import dumpd, load
|
||||
from langchain_core.vectorstores import VectorStore
|
||||
from langchain_core.vectorstores.utils import _cosine_similarity as cosine_similarity
|
||||
from langchain_core.vectorstores.utils import (
|
||||
_maximal_marginal_relevance as maximal_marginal_relevance,
|
||||
)
|
||||
from langchain_core.vectorstores.utils import maximal_marginal_relevance
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.indexing import UpsertResponse
|
||||
|
||||
|
||||
class InMemoryVectorStore(VectorStore):
|
||||
"""In-memory implementation of VectorStore using a dictionary.
|
||||
"""In-memory vector store implementation.
|
||||
|
||||
Uses numpy to compute cosine similarity for search.
|
||||
"""
|
||||
Uses a dictionary, and computes cosine similarity for search using numpy.
|
||||
|
||||
Setup:
|
||||
Install ``langchain-core``.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install -U langchain-core
|
||||
|
||||
Key init args — indexing params:
|
||||
embedding_function: Embeddings
|
||||
Embedding function to use.
|
||||
|
||||
Instantiate:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_core.vectorstores import InMemoryVectorStore
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
|
||||
vector_store = InMemoryVectorStore(OpenAIEmbeddings())
|
||||
|
||||
Add Documents:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_core.documents import Document
|
||||
|
||||
document_1 = Document(id="1", page_content="foo", metadata={"baz": "bar"})
|
||||
document_2 = Document(id="2", page_content="thud", metadata={"bar": "baz"})
|
||||
document_3 = Document(id="3", page_content="i will be deleted :(")
|
||||
|
||||
documents = [document_1, document_2, document_3]
|
||||
vector_store.add_documents(documents=documents)
|
||||
|
||||
Delete Documents:
|
||||
.. code-block:: python
|
||||
|
||||
vector_store.delete(ids=["3"])
|
||||
|
||||
Search:
|
||||
.. code-block:: python
|
||||
|
||||
results = vector_store.similarity_search(query="thud",k=1)
|
||||
for doc in results:
|
||||
print(f"* {doc.page_content} [{doc.metadata}]")
|
||||
|
||||
.. code-block:: none
|
||||
|
||||
* thud [{'bar': 'baz'}]
|
||||
|
||||
Search with filter:
|
||||
.. code-block:: python
|
||||
|
||||
def _filter_function(doc: Document) -> bool:
|
||||
return doc.metadata.get("bar") == "baz"
|
||||
|
||||
results = vector_store.similarity_search(
|
||||
query="thud", k=1, filter=_filter_function
|
||||
)
|
||||
for doc in results:
|
||||
print(f"* {doc.page_content} [{doc.metadata}]")
|
||||
|
||||
.. code-block:: none
|
||||
|
||||
* thud [{'bar': 'baz'}]
|
||||
|
||||
|
||||
Search with score:
|
||||
.. code-block:: python
|
||||
|
||||
results = vector_store.similarity_search_with_score(
|
||||
query="qux", k=1
|
||||
)
|
||||
for doc, score in results:
|
||||
print(f"* [SIM={score:3f}] {doc.page_content} [{doc.metadata}]")
|
||||
|
||||
.. code-block:: none
|
||||
|
||||
* [SIM=0.832268] foo [{'baz': 'bar'}]
|
||||
|
||||
Async:
|
||||
.. code-block:: python
|
||||
|
||||
# add documents
|
||||
# await vector_store.aadd_documents(documents=documents)
|
||||
|
||||
# delete documents
|
||||
# await vector_store.adelete(ids=["3"])
|
||||
|
||||
# search
|
||||
# results = vector_store.asimilarity_search(query="thud", k=1)
|
||||
|
||||
# search with score
|
||||
results = await vector_store.asimilarity_search_with_score(query="qux", k=1)
|
||||
for doc,score in results:
|
||||
print(f"* [SIM={score:3f}] {doc.page_content} [{doc.metadata}]")
|
||||
|
||||
.. code-block:: none
|
||||
|
||||
* [SIM=0.832268] foo [{'baz': 'bar'}]
|
||||
|
||||
Use as Retriever:
|
||||
.. code-block:: python
|
||||
|
||||
retriever = vector_store.as_retriever(
|
||||
search_type="mmr",
|
||||
search_kwargs={"k": 1, "fetch_k": 2, "lambda_mult": 0.5},
|
||||
)
|
||||
retriever.invoke("thud")
|
||||
|
||||
.. code-block:: none
|
||||
|
||||
[Document(id='2', metadata={'bar': 'baz'}, page_content='thud')]
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
def __init__(self, embedding: Embeddings) -> None:
|
||||
"""Initialize with the given embedding function.
|
||||
@@ -56,43 +168,71 @@ class InMemoryVectorStore(VectorStore):
|
||||
async def adelete(self, ids: Optional[Sequence[str]] = None, **kwargs: Any) -> None:
|
||||
self.delete(ids)
|
||||
|
||||
def upsert(self, items: Sequence[Document], /, **kwargs: Any) -> UpsertResponse:
|
||||
vectors = self.embedding.embed_documents([item.page_content for item in items])
|
||||
ids = []
|
||||
for item, vector in zip(items, vectors):
|
||||
doc_id = item.id if item.id else str(uuid.uuid4())
|
||||
ids.append(doc_id)
|
||||
self.store[doc_id] = {
|
||||
"id": doc_id,
|
||||
"vector": vector,
|
||||
"text": item.page_content,
|
||||
"metadata": item.metadata,
|
||||
}
|
||||
return {
|
||||
"succeeded": ids,
|
||||
"failed": [],
|
||||
}
|
||||
def add_documents(
|
||||
self,
|
||||
documents: List[Document],
|
||||
ids: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[str]:
|
||||
"""Add documents to the store."""
|
||||
texts = [doc.page_content for doc in documents]
|
||||
vectors = self.embedding.embed_documents(texts)
|
||||
|
||||
async def aupsert(
|
||||
self, items: Sequence[Document], /, **kwargs: Any
|
||||
) -> UpsertResponse:
|
||||
vectors = await self.embedding.aembed_documents(
|
||||
[item.page_content for item in items]
|
||||
if ids and len(ids) != len(texts):
|
||||
raise ValueError(
|
||||
f"ids must be the same length as texts. "
|
||||
f"Got {len(ids)} ids and {len(texts)} texts."
|
||||
)
|
||||
|
||||
id_iterator: Iterator[Optional[str]] = (
|
||||
iter(ids) if ids else iter(doc.id for doc in documents)
|
||||
)
|
||||
ids = []
|
||||
for item, vector in zip(items, vectors):
|
||||
doc_id = item.id if item.id else str(uuid.uuid4())
|
||||
ids.append(doc_id)
|
||||
self.store[doc_id] = {
|
||||
"id": doc_id,
|
||||
|
||||
ids_ = []
|
||||
|
||||
for doc, vector in zip(documents, vectors):
|
||||
doc_id = next(id_iterator)
|
||||
doc_id_ = doc_id if doc_id else str(uuid.uuid4())
|
||||
ids_.append(doc_id_)
|
||||
self.store[doc_id_] = {
|
||||
"id": doc_id_,
|
||||
"vector": vector,
|
||||
"text": item.page_content,
|
||||
"metadata": item.metadata,
|
||||
"text": doc.page_content,
|
||||
"metadata": doc.metadata,
|
||||
}
|
||||
return {
|
||||
"succeeded": ids,
|
||||
"failed": [],
|
||||
}
|
||||
|
||||
return ids_
|
||||
|
||||
async def aadd_documents(
|
||||
self, documents: List[Document], ids: Optional[List[str]] = None, **kwargs: Any
|
||||
) -> List[str]:
|
||||
"""Add documents to the store."""
|
||||
texts = [doc.page_content for doc in documents]
|
||||
vectors = await self.embedding.aembed_documents(texts)
|
||||
|
||||
if ids and len(ids) != len(texts):
|
||||
raise ValueError(
|
||||
f"ids must be the same length as texts. "
|
||||
f"Got {len(ids)} ids and {len(texts)} texts."
|
||||
)
|
||||
|
||||
id_iterator: Iterator[Optional[str]] = (
|
||||
iter(ids) if ids else iter(doc.id for doc in documents)
|
||||
)
|
||||
ids_: List[str] = []
|
||||
|
||||
for doc, vector in zip(documents, vectors):
|
||||
doc_id = next(id_iterator)
|
||||
doc_id_ = doc_id if doc_id else str(uuid.uuid4())
|
||||
ids_.append(doc_id_)
|
||||
self.store[doc_id_] = {
|
||||
"id": doc_id_,
|
||||
"vector": vector,
|
||||
"text": doc.page_content,
|
||||
"metadata": doc.metadata,
|
||||
}
|
||||
|
||||
return ids_
|
||||
|
||||
def get_by_ids(self, ids: Sequence[str], /) -> List[Document]:
|
||||
"""Get documents by their ids.
|
||||
@@ -117,6 +257,62 @@ class InMemoryVectorStore(VectorStore):
|
||||
)
|
||||
return documents
|
||||
|
||||
@deprecated(
|
||||
alternative="VectorStore.add_documents",
|
||||
message=(
|
||||
"This was a beta API that was added in 0.2.11. "
|
||||
"It'll be removed in 0.3.0."
|
||||
),
|
||||
since="0.2.29",
|
||||
removal="1.0",
|
||||
)
|
||||
def upsert(self, items: Sequence[Document], /, **kwargs: Any) -> UpsertResponse:
|
||||
vectors = self.embedding.embed_documents([item.page_content for item in items])
|
||||
ids = []
|
||||
for item, vector in zip(items, vectors):
|
||||
doc_id = item.id if item.id else str(uuid.uuid4())
|
||||
ids.append(doc_id)
|
||||
self.store[doc_id] = {
|
||||
"id": doc_id,
|
||||
"vector": vector,
|
||||
"text": item.page_content,
|
||||
"metadata": item.metadata,
|
||||
}
|
||||
return {
|
||||
"succeeded": ids,
|
||||
"failed": [],
|
||||
}
|
||||
|
||||
@deprecated(
|
||||
alternative="VectorStore.aadd_documents",
|
||||
message=(
|
||||
"This was a beta API that was added in 0.2.11. "
|
||||
"It'll be removed in 0.3.0."
|
||||
),
|
||||
since="0.2.29",
|
||||
removal="1.0",
|
||||
)
|
||||
async def aupsert(
|
||||
self, items: Sequence[Document], /, **kwargs: Any
|
||||
) -> UpsertResponse:
|
||||
vectors = await self.embedding.aembed_documents(
|
||||
[item.page_content for item in items]
|
||||
)
|
||||
ids = []
|
||||
for item, vector in zip(items, vectors):
|
||||
doc_id = item.id if item.id else str(uuid.uuid4())
|
||||
ids.append(doc_id)
|
||||
self.store[doc_id] = {
|
||||
"id": doc_id,
|
||||
"vector": vector,
|
||||
"text": item.page_content,
|
||||
"metadata": item.metadata,
|
||||
}
|
||||
return {
|
||||
"succeeded": ids,
|
||||
"failed": [],
|
||||
}
|
||||
|
||||
async def aget_by_ids(self, ids: Sequence[str], /) -> List[Document]:
|
||||
"""Async get documents by their ids.
|
||||
|
||||
@@ -239,11 +435,11 @@ class InMemoryVectorStore(VectorStore):
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
except ImportError:
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"numpy must be installed to use max_marginal_relevance_search "
|
||||
"pip install numpy"
|
||||
)
|
||||
) from e
|
||||
|
||||
mmr_chosen_indices = maximal_marginal_relevance(
|
||||
np.array(embedding, dtype=np.float32),
|
||||
@@ -294,7 +490,7 @@ class InMemoryVectorStore(VectorStore):
|
||||
embedding: Embeddings,
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
**kwargs: Any,
|
||||
) -> "InMemoryVectorStore":
|
||||
) -> InMemoryVectorStore:
|
||||
store = cls(
|
||||
embedding=embedding,
|
||||
)
|
||||
@@ -308,7 +504,7 @@ class InMemoryVectorStore(VectorStore):
|
||||
embedding: Embeddings,
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
**kwargs: Any,
|
||||
) -> "InMemoryVectorStore":
|
||||
) -> InMemoryVectorStore:
|
||||
store = cls(
|
||||
embedding=embedding,
|
||||
)
|
||||
@@ -318,7 +514,7 @@ class InMemoryVectorStore(VectorStore):
|
||||
@classmethod
|
||||
def load(
|
||||
cls, path: str, embedding: Embeddings, **kwargs: Any
|
||||
) -> "InMemoryVectorStore":
|
||||
) -> InMemoryVectorStore:
|
||||
"""Load a vector store from a file.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -34,11 +34,11 @@ def _cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
|
||||
"""
|
||||
try:
|
||||
import numpy as np
|
||||
except ImportError:
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"cosine_similarity requires numpy to be installed. "
|
||||
"Please install numpy with `pip install numpy`."
|
||||
)
|
||||
) from e
|
||||
|
||||
if len(X) == 0 or len(Y) == 0:
|
||||
return np.array([])
|
||||
@@ -51,7 +51,7 @@ def _cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
|
||||
f"and Y has shape {Y.shape}."
|
||||
)
|
||||
try:
|
||||
import simsimd as simd # type: ignore
|
||||
import simsimd as simd # type: ignore[import-not-found]
|
||||
|
||||
X = np.array(X, dtype=np.float32)
|
||||
Y = np.array(Y, dtype=np.float32)
|
||||
@@ -71,7 +71,7 @@ def _cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
|
||||
return similarity
|
||||
|
||||
|
||||
def _maximal_marginal_relevance(
|
||||
def maximal_marginal_relevance(
|
||||
query_embedding: np.ndarray,
|
||||
embedding_list: list,
|
||||
lambda_mult: float = 0.5,
|
||||
@@ -93,11 +93,11 @@ def _maximal_marginal_relevance(
    """
    try:
        import numpy as np
    except ImportError:
    except ImportError as e:
        raise ImportError(
            "maximal_marginal_relevance requires numpy to be installed. "
            "Please install numpy with `pip install numpy`."
        )
        ) from e

    if min(k, len(embedding_list)) <= 0:
        return []
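Since this helper is being made public here (renamed from `_maximal_marginal_relevance` and imported as `maximal_marginal_relevance` from `langchain_core.vectorstores.utils`), a small usage sketch; the `k` keyword is assumed from the full signature, which this hunk truncates:

import numpy as np

from langchain_core.vectorstores.utils import maximal_marginal_relevance

query = np.array([1.0, 0.0], dtype=np.float32)
candidates = [
    np.array([0.9, 0.1], dtype=np.float32),    # close to the query
    np.array([0.89, 0.11], dtype=np.float32),  # near-duplicate of the first
    np.array([0.1, 0.9], dtype=np.float32),    # dissimilar / diverse
]
# Returns indices into `candidates`; a lower lambda_mult weights diversity more heavily.
picked = maximal_marginal_relevance(query, candidates, lambda_mult=0.1, k=2)
print(picked)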
|
||||
|
||||
libs/core/poetry.lock (generated, 1168 lines changed): file diff suppressed because it is too large.
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
|
||||
|
||||
[tool.poetry]
|
||||
name = "langchain-core"
|
||||
version = "0.2.29rc1"
|
||||
version = "0.2.37"
|
||||
description = "Building applications with LLMs through composability"
|
||||
authors = []
|
||||
license = "MIT"
|
||||
@@ -41,7 +41,8 @@ python = ">=3.12.4"
|
||||
[tool.poetry.extras]
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = [ "E", "F", "I", "T201",]
|
||||
select = [ "B", "E", "F", "I", "T201", "UP",]
|
||||
ignore = [ "UP006", "UP007",]
|
||||
|
||||
[tool.coverage.run]
|
||||
omit = [ "tests/*",]
|
||||
@@ -79,6 +80,7 @@ mypy = ">=1.10,<1.11"
|
||||
types-pyyaml = "^6.0.12.2"
|
||||
types-requests = "^2.28.11.5"
|
||||
types-jinja2 = "^2.11.9"
|
||||
simsimd = "^5.0.0"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
jupyter = "^1.0.0"
|
||||
|
||||
@@ -3,9 +3,9 @@ import warnings
|
||||
from typing import Any, Dict
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain_core._api.beta_decorator import beta, warn_beta
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
@@ -3,9 +3,13 @@ import warnings
|
||||
from typing import Any, Dict
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain_core._api.deprecation import deprecated, warn_deprecated
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core._api.deprecation import (
|
||||
deprecated,
|
||||
rename_parameter,
|
||||
warn_deprecated,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -412,3 +416,84 @@ def test_raise_error_for_bad_decorator() -> None:
|
||||
def deprecated_function() -> str:
|
||||
"""original doc"""
|
||||
return "This is a deprecated function."
|
||||
|
||||
|
||||
def test_rename_parameter() -> None:
|
||||
"""Test rename parameter."""
|
||||
|
||||
@rename_parameter(since="2.0.0", removal="3.0.0", old="old_name", new="new_name")
|
||||
def foo(new_name: str) -> str:
|
||||
"""original doc"""
|
||||
return new_name
|
||||
|
||||
with warnings.catch_warnings(record=True) as warning_list:
|
||||
warnings.simplefilter("always")
|
||||
assert foo(old_name="hello") == "hello" # type: ignore[call-arg]
|
||||
assert len(warning_list) == 1
|
||||
|
||||
assert foo(new_name="hello") == "hello"
|
||||
assert foo("hello") == "hello"
|
||||
assert foo.__doc__ == "original doc"
|
||||
with pytest.raises(TypeError):
|
||||
foo(meow="hello") # type: ignore[call-arg]
|
||||
with pytest.raises(TypeError):
|
||||
assert foo("hello", old_name="hello") # type: ignore[call-arg]
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
assert foo(old_name="goodbye", new_name="hello") # type: ignore[call-arg]
|
||||
|
||||
|
||||
async def test_rename_parameter_for_async_func() -> None:
|
||||
"""Test rename parameter."""
|
||||
|
||||
@rename_parameter(since="2.0.0", removal="3.0.0", old="old_name", new="new_name")
|
||||
async def foo(new_name: str) -> str:
|
||||
"""original doc"""
|
||||
return new_name
|
||||
|
||||
with warnings.catch_warnings(record=True) as warning_list:
|
||||
warnings.simplefilter("always")
|
||||
assert await foo(old_name="hello") == "hello" # type: ignore[call-arg]
|
||||
assert len(warning_list) == 1
|
||||
assert await foo(new_name="hello") == "hello"
|
||||
assert await foo("hello") == "hello"
|
||||
assert foo.__doc__ == "original doc"
|
||||
with pytest.raises(TypeError):
|
||||
await foo(meow="hello") # type: ignore[call-arg]
|
||||
with pytest.raises(TypeError):
|
||||
assert await foo("hello", old_name="hello") # type: ignore[call-arg]
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
assert await foo(old_name="a", new_name="hello") # type: ignore[call-arg]
|
||||
|
||||
|
||||
def test_rename_parameter_method() -> None:
|
||||
"""Test that it works for a method."""
|
||||
|
||||
class Foo:
|
||||
@rename_parameter(
|
||||
since="2.0.0", removal="3.0.0", old="old_name", new="new_name"
|
||||
)
|
||||
def a(self, new_name: str) -> str:
|
||||
return new_name
|
||||
|
||||
foo = Foo()
|
||||
|
||||
with warnings.catch_warnings(record=True) as warning_list:
|
||||
warnings.simplefilter("always")
|
||||
assert foo.a(old_name="hello") == "hello" # type: ignore[call-arg]
|
||||
assert len(warning_list) == 1
|
||||
assert str(warning_list[0].message) == (
|
||||
"The parameter `old_name` of `a` was deprecated in 2.0.0 and will be "
|
||||
"removed "
|
||||
"in 3.0.0 Use `new_name` instead."
|
||||
)
|
||||
|
||||
assert foo.a(new_name="hello") == "hello"
|
||||
assert foo.a("hello") == "hello"
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
foo.a(meow="hello") # type: ignore[call-arg]
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
assert foo.a("hello", old_name="hello") # type: ignore[call-arg]
|
||||
|
||||
@@ -4,9 +4,10 @@ from itertools import chain
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
|
||||
|
||||
class BaseFakeCallbackHandler(BaseModel):
|
||||
@@ -256,7 +257,8 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
||||
) -> Any:
|
||||
self.on_retriever_error_common()
|
||||
|
||||
def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler":
|
||||
# Overriding since BaseModel has __deepcopy__ method as well
|
||||
def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler": # type: ignore
|
||||
return self
|
||||
|
||||
|
||||
@@ -390,5 +392,6 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
||||
) -> None:
|
||||
self.on_text_common()
|
||||
|
||||
def __deepcopy__(self, memo: dict) -> "FakeAsyncCallbackHandler":
|
||||
# Overriding since BaseModel has __deepcopy__ method as well
|
||||
def __deepcopy__(self, memo: dict) -> "FakeAsyncCallbackHandler": # type: ignore
|
||||
return self
|
||||
|
||||
@@ -9,7 +9,6 @@ from langchain_core.language_models import GenericFakeChatModel, ParrotFakeChatM
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
|
||||
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
|
||||
from tests.unit_tests.stubs import (
|
||||
AnyStr,
|
||||
_AnyIdAIMessage,
|
||||
_AnyIdAIMessageChunk,
|
||||
_AnyIdHumanMessage,
|
||||
@@ -70,8 +69,8 @@ async def test_generic_fake_chat_model_stream() -> None:
|
||||
model = GenericFakeChatModel(messages=cycle([message]))
|
||||
chunks = [chunk async for chunk in model.astream("meow")]
|
||||
assert chunks == [
|
||||
AIMessageChunk(content="", additional_kwargs={"foo": 42}, id=AnyStr()),
|
||||
AIMessageChunk(content="", additional_kwargs={"bar": 24}, id=AnyStr()),
|
||||
_AnyIdAIMessageChunk(content="", additional_kwargs={"foo": 42}),
|
||||
_AnyIdAIMessageChunk(content="", additional_kwargs={"bar": 24}),
|
||||
]
|
||||
assert len({chunk.id for chunk in chunks}) == 1
|
||||
|
||||
@@ -89,29 +88,23 @@ async def test_generic_fake_chat_model_stream() -> None:
|
||||
chunks = [chunk async for chunk in model.astream("meow")]
|
||||
|
||||
assert chunks == [
|
||||
AIMessageChunk(
|
||||
content="",
|
||||
additional_kwargs={"function_call": {"name": "move_file"}},
|
||||
id=AnyStr(),
|
||||
_AnyIdAIMessageChunk(
|
||||
content="", additional_kwargs={"function_call": {"name": "move_file"}}
|
||||
),
|
||||
AIMessageChunk(
|
||||
_AnyIdAIMessageChunk(
|
||||
content="",
|
||||
additional_kwargs={
|
||||
"function_call": {"arguments": '{\n "source_path": "foo"'},
|
||||
},
|
||||
id=AnyStr(),
|
||||
),
|
||||
AIMessageChunk(
|
||||
content="",
|
||||
additional_kwargs={"function_call": {"arguments": ","}},
|
||||
id=AnyStr(),
|
||||
_AnyIdAIMessageChunk(
|
||||
content="", additional_kwargs={"function_call": {"arguments": ","}}
|
||||
),
|
||||
AIMessageChunk(
|
||||
_AnyIdAIMessageChunk(
|
||||
content="",
|
||||
additional_kwargs={
|
||||
"function_call": {"arguments": '\n "destination_path": "bar"\n}'},
|
||||
},
|
||||
id=AnyStr(),
|
||||
),
|
||||
]
|
||||
assert len({chunk.id for chunk in chunks}) == 1
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
"""Test base chat model."""
|
||||
|
||||
import uuid
|
||||
from typing import Any, AsyncIterator, Iterator, List, Optional
|
||||
from typing import Any, AsyncIterator, Iterator, List, Literal, Optional, Union
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models import BaseChatModel, FakeListChatModel
|
||||
from langchain_core.language_models.fake_chat_models import FakeListChatModelError
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
@@ -18,6 +19,7 @@ from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResu
|
||||
from langchain_core.outputs.llm_result import LLMResult
|
||||
from langchain_core.tracers.base import BaseTracer
|
||||
from langchain_core.tracers.context import collect_runs
|
||||
from langchain_core.tracers.event_stream import _AstreamEventsCallbackHandler
|
||||
from langchain_core.tracers.schemas import Run
|
||||
from tests.unit_tests.fake.callbacks import (
|
||||
BaseFakeCallbackHandler,
|
||||
@@ -109,7 +111,7 @@ async def test_stream_error_callback() -> None:
|
||||
responses=[message],
|
||||
error_on_chunk_number=i,
|
||||
)
|
||||
with pytest.raises(Exception):
|
||||
with pytest.raises(FakeListChatModelError):
|
||||
cb_async = FakeAsyncCallbackHandler()
|
||||
async for _ in llm.astream("Dummy message", callbacks=[cb_async]):
|
||||
pass
|
||||
@@ -272,3 +274,96 @@ async def test_async_pass_run_id() -> None:
|
||||
uid3 = uuid.uuid4()
|
||||
await llm.abatch([["Dummy message"]], {"callbacks": [cb], "run_id": uid3})
|
||||
assert cb.traced_run_ids == [uid1, uid2, uid3]
|
||||
|
||||
|
||||
class NoStreamingModel(BaseChatModel):
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
return ChatResult(generations=[ChatGeneration(message=AIMessage("invoke"))])
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "model1"
|
||||
|
||||
|
||||
class StreamingModel(NoStreamingModel):
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
yield ChatGenerationChunk(message=AIMessageChunk(content="stream"))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("disable_streaming", [True, False, "tool_calling"])
|
||||
async def test_disable_streaming(
|
||||
disable_streaming: Union[bool, Literal["tool_calling"]],
|
||||
) -> None:
|
||||
model = StreamingModel(disable_streaming=disable_streaming)
|
||||
assert model.invoke([]).content == "invoke"
|
||||
assert (await model.ainvoke([])).content == "invoke"
|
||||
|
||||
expected = "invoke" if disable_streaming is True else "stream"
|
||||
assert next(model.stream([])).content == expected
|
||||
async for c in model.astream([]):
|
||||
assert c.content == expected
|
||||
break
|
||||
assert (
|
||||
model.invoke(
|
||||
[], config={"callbacks": [_AstreamEventsCallbackHandler()]}
|
||||
).content
|
||||
== expected
|
||||
)
|
||||
assert (
|
||||
await model.ainvoke([], config={"callbacks": [_AstreamEventsCallbackHandler()]})
|
||||
).content == expected
|
||||
|
||||
expected = "invoke" if disable_streaming in ("tool_calling", True) else "stream"
|
||||
assert next(model.stream([], tools=[{"type": "function"}])).content == expected
|
||||
async for c in model.astream([], tools=[{}]):
|
||||
assert c.content == expected
|
||||
break
|
||||
assert (
|
||||
model.invoke(
|
||||
[], config={"callbacks": [_AstreamEventsCallbackHandler()]}, tools=[{}]
|
||||
).content
|
||||
== expected
|
||||
)
|
||||
assert (
|
||||
await model.ainvoke(
|
||||
[], config={"callbacks": [_AstreamEventsCallbackHandler()]}, tools=[{}]
|
||||
)
|
||||
).content == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize("disable_streaming", [True, False, "tool_calling"])
|
||||
async def test_disable_streaming_no_streaming_model(
|
||||
disable_streaming: Union[bool, Literal["tool_calling"]],
|
||||
) -> None:
|
||||
model = NoStreamingModel(disable_streaming=disable_streaming)
|
||||
assert model.invoke([]).content == "invoke"
|
||||
assert (await model.ainvoke([])).content == "invoke"
|
||||
assert next(model.stream([])).content == "invoke"
|
||||
async for c in model.astream([]):
|
||||
assert c.content == "invoke"
|
||||
break
|
||||
assert (
|
||||
model.invoke(
|
||||
[], config={"callbacks": [_AstreamEventsCallbackHandler()]}
|
||||
).content
|
||||
== "invoke"
|
||||
)
|
||||
assert (
|
||||
await model.ainvoke([], config={"callbacks": [_AstreamEventsCallbackHandler()]})
|
||||
).content == "invoke"
|
||||
assert next(model.stream([], tools=[{}])).content == "invoke"
|
||||
async for c in model.astream([], tools=[{}]):
|
||||
assert c.content == "invoke"
|
||||
break
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import time
|
||||
from typing import Optional as Optional
|
||||
|
||||
from langchain_core.caches import InMemoryCache
|
||||
from langchain_core.language_models import GenericFakeChatModel
|
||||
@@ -220,6 +221,9 @@ class SerializableModel(GenericFakeChatModel):
|
||||
return True
|
||||
|
||||
|
||||
SerializableModel.model_rebuild()
|
||||
|
||||
|
||||
def test_serialization_with_rate_limiter() -> None:
|
||||
"""Test model serialization with rate limiter."""
|
||||
from langchain_core.load import dumps
|
||||
|
||||
@@ -7,6 +7,7 @@ from langchain_core.callbacks import (
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models import BaseLLM, FakeListLLM, FakeStreamingListLLM
|
||||
from langchain_core.language_models.fake import FakeListLLMError
|
||||
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
|
||||
from langchain_core.tracers.context import collect_runs
|
||||
from tests.unit_tests.fake.callbacks import (
|
||||
@@ -108,7 +109,7 @@ async def test_stream_error_callback() -> None:
|
||||
responses=[message],
|
||||
error_on_chunk_number=i,
|
||||
)
|
||||
with pytest.raises(Exception):
|
||||
with pytest.raises(FakeListLLMError):
|
||||
cb_async = FakeAsyncCallbackHandler()
|
||||
async for _ in llm.astream("Dummy message", callbacks=[cb_async]):
|
||||
pass
|
||||
|
||||
@@ -6,6 +6,7 @@ EXPECTED_ALL = [
|
||||
"SimpleChatModel",
|
||||
"BaseLLM",
|
||||
"LLM",
|
||||
"LangSmithParams",
|
||||
"LanguageModelInput",
|
||||
"LanguageModelOutput",
|
||||
"LanguageModelLike",
|
||||
|
||||
Some files were not shown because too many files have changed in this diff.