diff --git a/docs/docs/how_to/merge_message_runs.ipynb b/docs/docs/how_to/merge_message_runs.ipynb
index 61dd3e49a8a..8ef019d8fe0 100644
--- a/docs/docs/how_to/merge_message_runs.ipynb
+++ b/docs/docs/how_to/merge_message_runs.ipynb
@@ -63,6 +63,38 @@
     "Notice that if the contents of one of the messages to merge is a list of content blocks then the merged message will have a list of content blocks. And if both messages to merge have string contents then those are concatenated with a newline character."
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "11f7e8d3",
+   "metadata": {},
+   "source": [
+    "The `merge_message_runs` utility also works with messages composed together using the overloaded `+` operation:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b51855c5",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "messages = (\n",
+    "    SystemMessage(\"you're a good assistant.\")\n",
+    "    + SystemMessage(\"you always respond with a joke.\")\n",
+    "    + HumanMessage([{\"type\": \"text\", \"text\": \"i wonder why it's called langchain\"}])\n",
+    "    + HumanMessage(\"and who is harrison chasing anyways\")\n",
+    "    + AIMessage(\n",
+    "        'Well, I guess they thought \"WordRope\" and \"SentenceString\" just didn\\'t have the same ring to it!'\n",
+    "    )\n",
+    "    + AIMessage(\n",
+    "        \"Why, he's probably chasing after the last cup of coffee in the office!\"\n",
+    "    )\n",
+    ")\n",
+    "\n",
+    "merged = merge_message_runs(messages)\n",
+    "print(\"\\n\\n\".join([repr(x) for x in merged]))"
+   ]
+  },
  {
   "cell_type": "markdown",
   "id": "1b2eee74-71c8-4168-b968-bca580c25d18",
diff --git a/libs/core/langchain_core/messages/utils.py b/libs/core/langchain_core/messages/utils.py
index f48075b8b08..32207673747 100644
--- a/libs/core/langchain_core/messages/utils.py
+++ b/libs/core/langchain_core/messages/utils.py
@@ -16,6 +16,7 @@ from typing import (
     Any,
     Callable,
     Dict,
+    Iterable,
     List,
     Literal,
     Optional,
@@ -40,6 +41,7 @@ if TYPE_CHECKING:
     from langchain_text_splitters import TextSplitter

     from langchain_core.language_models import BaseLanguageModel
+    from langchain_core.prompt_values import PromptValue
     from langchain_core.runnables.base import Runnable

 AnyMessage = Union[
@@ -284,7 +286,7 @@ def _convert_to_message(message: MessageLikeRepresentation) -> BaseMessage:


 def convert_to_messages(
-    messages: Sequence[MessageLikeRepresentation],
+    messages: Union[Iterable[MessageLikeRepresentation], PromptValue],
 ) -> List[BaseMessage]:
     """Convert a sequence of messages to a list of messages.

@@ -294,6 +296,11 @@ def convert_to_messages(
     Returns:
         List of messages (BaseMessages).
     """
+    # Import here to avoid circular imports
+    from langchain_core.prompt_values import PromptValue
+
+    if isinstance(messages, PromptValue):
+        return messages.to_messages()
     return [_convert_to_message(m) for m in messages]


@@ -329,7 +336,7 @@ def _runnable_support(func: Callable) -> Callable:

 @_runnable_support
 def filter_messages(
-    messages: Sequence[MessageLikeRepresentation],
+    messages: Union[Iterable[MessageLikeRepresentation], PromptValue],
     *,
     include_names: Optional[Sequence[str]] = None,
     exclude_names: Optional[Sequence[str]] = None,
@@ -417,7 +424,7 @@ def filter_messages(

 @_runnable_support
 def merge_message_runs(
-    messages: Sequence[MessageLikeRepresentation],
+    messages: Union[Iterable[MessageLikeRepresentation], PromptValue],
 ) -> List[BaseMessage]:
     """Merge consecutive Messages of the same type.

@@ -506,7 +513,7 @@ def merge_message_runs(

 @_runnable_support
 def trim_messages(
-    messages: Sequence[MessageLikeRepresentation],
+    messages: Union[Iterable[MessageLikeRepresentation], PromptValue],
     *,
     max_tokens: int,
     token_counter: Union[
diff --git a/libs/core/tests/unit_tests/messages/test_utils.py b/libs/core/tests/unit_tests/messages/test_utils.py
index 3959d7df287..d3160cb47ed 100644
--- a/libs/core/tests/unit_tests/messages/test_utils.py
+++ b/libs/core/tests/unit_tests/messages/test_utils.py
@@ -127,7 +127,6 @@ _MESSAGES_TO_TRIM = [
     HumanMessage("This is a 4 token text.", id="third"),
     AIMessage("This is a 4 token text.", id="fourth"),
 ]
-
 _MESSAGES_TO_TRIM_COPY = [m.copy(deep=True) for m in _MESSAGES_TO_TRIM]
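Taken together, these signature changes let `convert_to_messages`, `filter_messages`, `merge_message_runs`, and `trim_messages` accept a `PromptValue` (or any iterable of message-like objects) rather than only a `Sequence`. The following is a minimal sketch of how the new branch might be exercised, assuming the diff above is applied and that `merge_message_runs` normalizes its input through `convert_to_messages`; `ChatPromptTemplate` is used here only as a convenient way to produce a `PromptValue`:

from langchain_core.messages import merge_message_runs
from langchain_core.prompts import ChatPromptTemplate

# Formatting a chat prompt yields a PromptValue rather than a list of messages.
prompt_value = ChatPromptTemplate.from_messages(
    [
        ("system", "you're a good assistant."),
        ("system", "you always respond with a joke."),
        ("human", "{question}"),
    ]
).invoke({"question": "i wonder why it's called langchain"})

# With the PromptValue branch added to convert_to_messages, the utility can
# consume the prompt value directly; the two consecutive system messages are
# merged into a single SystemMessage whose contents are joined by a newline.
merged = merge_message_runs(prompt_value)
for message in merged:
    print(repr(message))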