From 263c215112ecfdff8b98f3844345da403677b5e4 Mon Sep 17 00:00:00 2001
From: Sydney Runkle <54324534+sydney-runkle@users.noreply.github.com>
Date: Wed, 14 May 2025 08:13:05 -0700
Subject: [PATCH] perf[core]: remove generations summation from hot loop
(#31231)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

1. Removes summation of `ChatGenerationChunk` objects from the hot loops in
   `stream` and `astream`
2. Removes run ID string generation from the loop as well (minor impact)
As before, benchmarked by streaming ~200k chunks (a poem about broccoli).

Before: ~4.2s (the blue circle in the profile screenshot marks the time spent
summing generation chunks)

After: ~2.3s (the blue circle marks the remaining time spent merging chunks,
which can be reduced in a future PR by optimizing the `merge_content`,
`merge_dicts`, and `merge_lists` utilities)
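
In rough form, the change swaps the per-iteration `generation += chunk`
accumulation for appending to a list plus a single end-of-stream merge via the
new `merge_chat_generation_chunks` helper. A minimal self-contained sketch of
the two patterns (synthetic 1,000-chunk stream for illustration; not the exact
`BaseChatModel` code):

```python
# A minimal sketch of the pattern change, assuming the langchain_core APIs
# below; the 1,000-chunk synthetic stream is illustrative, not the ~200k-chunk
# benchmark from the description above.
from langchain_core.messages import AIMessageChunk
from langchain_core.outputs import ChatGenerationChunk
from langchain_core.outputs.chat_generation import merge_chat_generation_chunks

stream = [
    ChatGenerationChunk(message=AIMessageChunk(content=f"tok{i} "))
    for i in range(1000)
]

# Before: summing inside the loop re-merges all content seen so far on every
# iteration, so total work grows roughly quadratically with chunk count.
generation = None
for chunk in stream:
    generation = chunk if generation is None else generation + chunk

# After: the hot loop only appends; a single merge happens once streaming ends.
chunks: list[ChatGenerationChunk] = []
for chunk in stream:
    chunks.append(chunk)
merged = merge_chat_generation_chunks(chunks)  # returns None for an empty stream

assert generation is not None and merged is not None
assert merged.text == generation.text
```

The single-call merge relies on `ChatGenerationChunk.__add__` accepting a list
of chunks (`chunks[0] + chunks[1:]` in the new helper below), so the remaining
cost is concentrated in the `merge_content`/`merge_dicts`/`merge_lists`
utilities mentioned above.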
---
.../language_models/chat_models.py | 45 +++++++++++--------
.../langchain_core/outputs/chat_generation.py | 19 ++++++--
2 files changed, 42 insertions(+), 22 deletions(-)
diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py
index 712271778ae..322b2a4f677 100644
--- a/libs/core/langchain_core/language_models/chat_models.py
+++ b/libs/core/langchain_core/language_models/chat_models.py
@@ -66,6 +66,7 @@ from langchain_core.outputs import (
LLMResult,
RunInfo,
)
+from langchain_core.outputs.chat_generation import merge_chat_generation_chunks
from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
from langchain_core.rate_limiters import BaseRateLimiter
from langchain_core.runnables import RunnableMap, RunnablePassthrough
@@ -485,34 +486,41 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
run_id=config.pop("run_id", None),
batch_size=1,
)
- generation: Optional[ChatGenerationChunk] = None
+
+ chunks: list[ChatGenerationChunk] = []
if self.rate_limiter:
self.rate_limiter.acquire(blocking=True)
try:
input_messages = _normalize_messages(messages)
+ run_id = "-".join((_LC_ID_PREFIX, str(run_manager.run_id)))
for chunk in self._stream(input_messages, stop=stop, **kwargs):
if chunk.message.id is None:
- chunk.message.id = f"{_LC_ID_PREFIX}-{run_manager.run_id}"
+ chunk.message.id = run_id
chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
run_manager.on_llm_new_token(
cast("str", chunk.message.content), chunk=chunk
)
+ chunks.append(chunk)
yield chunk.message
- if generation is None:
- generation = chunk
- else:
- generation += chunk
except BaseException as e:
generations_with_error_metadata = _generate_response_from_error(e)
- if generation:
- generations = [[generation], generations_with_error_metadata]
+ chat_generation_chunk = merge_chat_generation_chunks(chunks)
+ if chat_generation_chunk:
+ generations = [
+ [chat_generation_chunk],
+ generations_with_error_metadata,
+ ]
else:
generations = [generations_with_error_metadata]
- run_manager.on_llm_error(e, response=LLMResult(generations=generations)) # type: ignore[arg-type]
+ run_manager.on_llm_error(
+ e,
+ response=LLMResult(generations=generations), # type: ignore[arg-type]
+ )
raise
+ generation = merge_chat_generation_chunks(chunks)
if generation is None:
err = ValueError("No generation chunks were returned")
run_manager.on_llm_error(err, response=LLMResult(generations=[]))
@@ -575,29 +583,29 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
if self.rate_limiter:
await self.rate_limiter.aacquire(blocking=True)
- generation: Optional[ChatGenerationChunk] = None
+ chunks: list[ChatGenerationChunk] = []
+
try:
input_messages = _normalize_messages(messages)
+ run_id = "-".join((_LC_ID_PREFIX, str(run_manager.run_id)))
async for chunk in self._astream(
input_messages,
stop=stop,
**kwargs,
):
if chunk.message.id is None:
- chunk.message.id = f"{_LC_ID_PREFIX}-{run_manager.run_id}"
+ chunk.message.id = run_id
chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
await run_manager.on_llm_new_token(
cast("str", chunk.message.content), chunk=chunk
)
+ chunks.append(chunk)
yield chunk.message
- if generation is None:
- generation = chunk
- else:
- generation += chunk
except BaseException as e:
generations_with_error_metadata = _generate_response_from_error(e)
- if generation:
- generations = [[generation], generations_with_error_metadata]
+ chat_generation_chunk = merge_chat_generation_chunks(chunks)
+ if chat_generation_chunk:
+ generations = [[chat_generation_chunk], generations_with_error_metadata]
else:
generations = [generations_with_error_metadata]
await run_manager.on_llm_error(
@@ -606,7 +614,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
)
raise
- if generation is None:
+ generation = merge_chat_generation_chunks(chunks)
+ if not generation:
err = ValueError("No generation chunks were returned")
await run_manager.on_llm_error(err, response=LLMResult(generations=[]))
raise err
diff --git a/libs/core/langchain_core/outputs/chat_generation.py b/libs/core/langchain_core/outputs/chat_generation.py
index d01ae516942..25ea5684e67 100644
--- a/libs/core/langchain_core/outputs/chat_generation.py
+++ b/libs/core/langchain_core/outputs/chat_generation.py
@@ -2,17 +2,15 @@
from __future__ import annotations
-from typing import TYPE_CHECKING, Literal, Union
+from typing import Literal, Union
from pydantic import model_validator
+from typing_extensions import Self
from langchain_core.messages import BaseMessage, BaseMessageChunk
from langchain_core.outputs.generation import Generation
from langchain_core.utils._merge import merge_dicts
-if TYPE_CHECKING:
- from typing_extensions import Self
-
class ChatGeneration(Generation):
"""A single chat generation output.
@@ -115,3 +113,16 @@ class ChatGenerationChunk(ChatGeneration):
)
msg = f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"
raise TypeError(msg)
+
+
+def merge_chat_generation_chunks(
+ chunks: list[ChatGenerationChunk],
+) -> Union[ChatGenerationChunk, None]:
+ """Merge a list of ChatGenerationChunks into a single ChatGenerationChunk."""
+ if not chunks:
+ return None
+
+ if len(chunks) == 1:
+ return chunks[0]
+
+ return chunks[0] + chunks[1:]