From 7d13a2f958235774ec513fcd9c7ac421e5cbf29f Mon Sep 17 00:00:00 2001 From: CastaChick Date: Fri, 23 Aug 2024 04:46:25 +0900 Subject: [PATCH] core[patch]: add option to specify the chunk separator in `merge_message_runs` (#24783) **Description:** LLM will stop generating text even in the middle of a sentence if `finish_reason` is `length` (for OpenAI) or `stop_reason` is `max_tokens` (for Anthropic). To obtain longer outputs from LLM, we should call the message generation API multiple times and merge the results into the text to circumvent the API's output token limit. The extra line breaks forced by the `merge_message_runs` function when seamlessly merging messages can be annoying, so I added the option to specify the chunk separator. **Issue:** No corresponding issues. **Dependencies:** No dependencies required. **Twitter handle:** @hanama_chem https://x.com/hanama_chem --------- Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com> Co-authored-by: Bagatur --- .../document_loaders/langsmith.py | 2 +- libs/core/langchain_core/messages/utils.py | 17 ++++++++----- .../tests/unit_tests/messages/test_utils.py | 24 +++++++++++++++++++ 3 files changed, 36 insertions(+), 7 deletions(-) diff --git a/libs/core/langchain_core/document_loaders/langsmith.py b/libs/core/langchain_core/document_loaders/langsmith.py index 232da98ccf7..9da48851d06 100644 --- a/libs/core/langchain_core/document_loaders/langsmith.py +++ b/libs/core/langchain_core/document_loaders/langsmith.py @@ -73,7 +73,7 @@ class LangSmithLoader(BaseLoader): inline_s3_urls: Whether to inline S3 URLs. Defaults to True. offset: The offset to start from. Defaults to 0. limit: The maximum number of examples to return. - filter: A structured fileter string to apply to the examples. + filter: A structured filter string to apply to the examples. client: LangSmith Client. If not provided will be initialized from below args. client_kwargs: Keyword args to pass to LangSmith client init. Should only be specified if ``client`` isn't. diff --git a/libs/core/langchain_core/messages/utils.py b/libs/core/langchain_core/messages/utils.py index c7d4d58a149..6f88fcbcf79 100644 --- a/libs/core/langchain_core/messages/utils.py +++ b/libs/core/langchain_core/messages/utils.py @@ -443,6 +443,8 @@ def filter_messages( @_runnable_support def merge_message_runs( messages: Union[Iterable[MessageLikeRepresentation], PromptValue], + *, + chunk_separator: str = "\n", ) -> List[BaseMessage]: """Merge consecutive Messages of the same type. @@ -451,13 +453,16 @@ def merge_message_runs( Args: messages: Sequence Message-like objects to merge. + chunk_separator: Specify the string to be inserted between message chunks. + Default is "\n". Returns: List of BaseMessages with consecutive runs of message types merged into single - messages. If two messages being merged both have string contents, the merged - content is a concatenation of the two strings with a new-line separator. If at - least one of the messages has a list of content blocks, the merged content is a - list of content blocks. + messages. By default, if two messages being merged both have string contents, + the merged content is a concatenation of the two strings with a new-line separator. + The separator inserted between message chunks can be controlled by specifying + any string with ``chunk_separator``. If at least one of the messages has a list of + content blocks, the merged content is a list of content blocks. Example: .. code-block:: python @@ -527,7 +532,7 @@ def merge_message_runs( and last_chunk.content and curr_chunk.content ): - last_chunk.content += "\n" + last_chunk.content += chunk_separator merged.append(_chunk_to_msg(last_chunk + curr_chunk)) return merged @@ -799,7 +804,7 @@ def trim_messages( list_token_counter = token_counter # type: ignore[assignment] else: raise ValueError( - f"'token_counter' expected ot be a model that implements " + f"'token_counter' expected to be a model that implements " f"'get_num_tokens_from_messages()' or a function. Received object of type " f"{type(token_counter)}." ) diff --git a/libs/core/tests/unit_tests/messages/test_utils.py b/libs/core/tests/unit_tests/messages/test_utils.py index 3f25e02fb23..56b8c0df7be 100644 --- a/libs/core/tests/unit_tests/messages/test_utils.py +++ b/libs/core/tests/unit_tests/messages/test_utils.py @@ -30,6 +30,30 @@ def test_merge_message_runs_str(msg_cls: Type[BaseMessage]) -> None: assert messages == messages_copy +@pytest.mark.parametrize("msg_cls", [HumanMessage, AIMessage, SystemMessage]) +def test_merge_message_runs_str_with_specified_separator( + msg_cls: Type[BaseMessage], +) -> None: + messages = [msg_cls("foo"), msg_cls("bar"), msg_cls("baz")] + messages_copy = [m.copy(deep=True) for m in messages] + expected = [msg_cls("foobarbaz")] + actual = merge_message_runs(messages, chunk_separator="") + assert actual == expected + assert messages == messages_copy + + +@pytest.mark.parametrize("msg_cls", [HumanMessage, AIMessage, SystemMessage]) +def test_merge_message_runs_str_without_separator( + msg_cls: Type[BaseMessage], +) -> None: + messages = [msg_cls("foo"), msg_cls("bar"), msg_cls("baz")] + messages_copy = [m.copy(deep=True) for m in messages] + expected = [msg_cls("foobarbaz")] + actual = merge_message_runs(messages, chunk_separator="") + assert actual == expected + assert messages == messages_copy + + def test_merge_message_runs_content() -> None: messages = [ AIMessage("foo", id="1"),