From e8a82860127ebd572841f9beaab003c3284d9e9a Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Tue, 18 Jun 2024 17:15:18 -0700 Subject: [PATCH] core[patch]: runnablewithchathistory from core.runnables (#23136) --- libs/core/langchain_core/messages/utils.py | 55 ++----------------- .../core/langchain_core/runnables/__init__.py | 2 + libs/core/langchain_core/runnables/history.py | 4 +- .../unit_tests/runnables/test_imports.py | 1 + 4 files changed, 9 insertions(+), 53 deletions(-) diff --git a/libs/core/langchain_core/messages/utils.py b/libs/core/langchain_core/messages/utils.py index dac0ec679f4..74139e501ce 100644 --- a/libs/core/langchain_core/messages/utils.py +++ b/libs/core/langchain_core/messages/utils.py @@ -25,12 +25,12 @@ from langchain_core.messages.function import FunctionMessage, FunctionMessageChu from langchain_core.messages.human import HumanMessage, HumanMessageChunk from langchain_core.messages.system import SystemMessage, SystemMessageChunk from langchain_core.messages.tool import ToolMessage, ToolMessageChunk -from langchain_core.runnables import Runnable, RunnableLambda if TYPE_CHECKING: from langchain_text_splitters import TextSplitter from langchain_core.language_models import BaseLanguageModel + from langchain_core.runnables.base import Runnable AnyMessage = Union[ AIMessage, HumanMessage, ChatMessage, SystemMessage, FunctionMessage, ToolMessage @@ -279,6 +279,8 @@ def _runnable_support(func: Callable) -> Callable: List[BaseMessage], Runnable[Sequence[MessageLikeRepresentation], List[BaseMessage]], ]: + from langchain_core.runnables.base import RunnableLambda + if messages is not None: return func(messages, **kwargs) else: @@ -486,9 +488,7 @@ def trim_messages( ] = None, include_system: bool = False, text_splitter: Optional[Union[Callable[[str], List[str]], TextSplitter]] = None, -) -> Union[ - List[BaseMessage], Runnable[Sequence[MessageLikeRepresentation], List[BaseMessage]] -]: +) -> List[BaseMessage]: """Trim messages to be below a token count. Args: @@ -734,53 +734,6 @@ def trim_messages( """ # noqa: E501 - if messages is not None: - return _trim_messages_helper( - messages, - max_tokens=max_tokens, - token_counter=token_counter, - strategy=strategy, - allow_partial=allow_partial, - end_on=end_on, - start_on=start_on, - include_system=include_system, - text_splitter=text_splitter, - ) - else: - trimmer = partial( - _trim_messages_helper, - max_tokens=max_tokens, - token_counter=token_counter, - strategy=strategy, - allow_partial=allow_partial, - end_on=end_on, - start_on=start_on, - include_system=include_system, - text_splitter=text_splitter, - ) - return RunnableLambda(trimmer) - - -def _trim_messages_helper( - messages: Sequence[MessageLikeRepresentation], - *, - max_tokens: int, - token_counter: Union[ - Callable[[List[BaseMessage]], int], - Callable[[BaseMessage], int], - BaseLanguageModel, - ], - strategy: Literal["first", "last"] = "last", - allow_partial: bool = False, - end_on: Optional[ - Union[str, Type[BaseMessage], Sequence[Union[str, Type[BaseMessage]]]] - ] = None, - start_on: Optional[ - Union[str, Type[BaseMessage], Sequence[Union[str, Type[BaseMessage]]]] - ] = None, - include_system: bool = False, - text_splitter: Optional[Union[Callable[[str], List[str]], TextSplitter]] = None, -) -> List[BaseMessage]: from langchain_core.language_models import BaseLanguageModel if start_on and strategy == "first": diff --git a/libs/core/langchain_core/runnables/__init__.py b/libs/core/langchain_core/runnables/__init__.py index 9c4362c9c86..518d5c38c94 100644 --- a/libs/core/langchain_core/runnables/__init__.py +++ b/libs/core/langchain_core/runnables/__init__.py @@ -36,6 +36,7 @@ from langchain_core.runnables.config import ( run_in_executor, ) from langchain_core.runnables.fallbacks import RunnableWithFallbacks +from langchain_core.runnables.history import RunnableWithMessageHistory from langchain_core.runnables.passthrough import ( RunnableAssign, RunnablePassthrough, @@ -78,6 +79,7 @@ __all__ = [ "RunnablePick", "RunnableSequence", "RunnableWithFallbacks", + "RunnableWithMessageHistory", "get_config_list", "aadd", "add", diff --git a/libs/core/langchain_core/runnables/history.py b/libs/core/langchain_core/runnables/history.py index fdc40e55835..326f941263c 100644 --- a/libs/core/langchain_core/runnables/history.py +++ b/libs/core/langchain_core/runnables/history.py @@ -25,8 +25,8 @@ from langchain_core.runnables.utils import ( ) if TYPE_CHECKING: - from langchain_core.language_models import LanguageModelLike - from langchain_core.messages import BaseMessage + from langchain_core.language_models.base import LanguageModelLike + from langchain_core.messages.base import BaseMessage from langchain_core.runnables.config import RunnableConfig from langchain_core.tracers.schemas import Run diff --git a/libs/core/tests/unit_tests/runnables/test_imports.py b/libs/core/tests/unit_tests/runnables/test_imports.py index 48098aa6db3..12b1a80d1bf 100644 --- a/libs/core/tests/unit_tests/runnables/test_imports.py +++ b/libs/core/tests/unit_tests/runnables/test_imports.py @@ -26,6 +26,7 @@ EXPECTED_ALL = [ "RunnablePick", "RunnableSequence", "RunnableWithFallbacks", + "RunnableWithMessageHistory", "get_config_list", "aadd", "add",