From 263c215112ecfdff8b98f3844345da403677b5e4 Mon Sep 17 00:00:00 2001
From: Sydney Runkle <54324534+sydney-runkle@users.noreply.github.com>
Date: Wed, 14 May 2025 08:13:05 -0700
Subject: [PATCH] perf[core]: remove generations summation from hot loop
 (#31231)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

1. Removes summation of `ChatGenerationChunk` from the hot loops in `stream`
   and `astream`: chunks are now appended to a list and merged once at the end.
2. Removes run id string generation from the loop as well (minor impact).

As before, benchmarking covers processing of ~200k chunks (a poem about
broccoli).

Before: ~4.2s. The blue circle marks all the time spent adding up generation
chunks.
[Screenshot 2025-05-14 at 7 48 33 AM]

After: ~2.3s. The blue circle marks the remaining time spent adding chunks,
which a future PR can reduce further by optimizing the `merge_content`,
`merge_dicts`, and `merge_lists` utilities.
[Screenshot 2025-05-14 at 7 50 08 AM]
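To make the win concrete, here is a minimal, self-contained sketch (toy tuples
standing in for `ChatGenerationChunk`, not the library code) of why merging
inside the hot loop degrades quadratically while append-then-merge-once stays
roughly linear:

```python
# Toy model of the old vs. new accumulation strategy. Tuples stand in for
# ChatGenerationChunk: like chunk merging, `old + new` copies everything
# accumulated so far, so doing it per chunk costs O(N^2) total work.
from timeit import timeit

N = 20_000
stream = [("tok",)] * N  # stand-in for N streamed chunks

def sum_in_loop() -> tuple:
    """Old approach: `generation += chunk` on every iteration."""
    acc: tuple = ()
    for chunk in stream:
        acc = acc + chunk  # re-copies the whole accumulated result each time
    return acc

def merge_once() -> tuple:
    """New approach: append cheaply, merge a single time at the end."""
    collected = []
    for chunk in stream:
        collected.append(chunk)  # O(1) per chunk, mirrors chunks.append(chunk)
    return tuple(tok for chunk in collected for tok in chunk)

print("sum in loop:", timeit(sum_in_loop, number=1))  # grows ~N^2
print("merge once: ", timeit(merge_once, number=1))   # stays ~linear in N
```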
---
 .../language_models/chat_models.py            | 45 +++++++++++--------
 .../langchain_core/outputs/chat_generation.py | 19 ++++++--
 2 files changed, 42 insertions(+), 22 deletions(-)

diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py
index 712271778ae..322b2a4f677 100644
--- a/libs/core/langchain_core/language_models/chat_models.py
+++ b/libs/core/langchain_core/language_models/chat_models.py
@@ -66,6 +66,7 @@ from langchain_core.outputs import (
     LLMResult,
     RunInfo,
 )
+from langchain_core.outputs.chat_generation import merge_chat_generation_chunks
 from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
 from langchain_core.rate_limiters import BaseRateLimiter
 from langchain_core.runnables import RunnableMap, RunnablePassthrough
@@ -485,34 +486,41 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             run_id=config.pop("run_id", None),
             batch_size=1,
         )
-        generation: Optional[ChatGenerationChunk] = None
+
+        chunks: list[ChatGenerationChunk] = []
 
         if self.rate_limiter:
             self.rate_limiter.acquire(blocking=True)
 
         try:
             input_messages = _normalize_messages(messages)
+            run_id = "-".join((_LC_ID_PREFIX, str(run_manager.run_id)))
             for chunk in self._stream(input_messages, stop=stop, **kwargs):
                 if chunk.message.id is None:
-                    chunk.message.id = f"{_LC_ID_PREFIX}-{run_manager.run_id}"
+                    chunk.message.id = run_id
                 chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
                 run_manager.on_llm_new_token(
                     cast("str", chunk.message.content), chunk=chunk
                 )
+                chunks.append(chunk)
                 yield chunk.message
-                if generation is None:
-                    generation = chunk
-                else:
-                    generation += chunk
         except BaseException as e:
             generations_with_error_metadata = _generate_response_from_error(e)
-            if generation:
-                generations = [[generation], generations_with_error_metadata]
+            chat_generation_chunk = merge_chat_generation_chunks(chunks)
+            if chat_generation_chunk:
+                generations = [
+                    [chat_generation_chunk],
+                    generations_with_error_metadata,
+                ]
             else:
                 generations = [generations_with_error_metadata]
-            run_manager.on_llm_error(e, response=LLMResult(generations=generations))  # type: ignore[arg-type]
+            run_manager.on_llm_error(
+                e,
+                response=LLMResult(generations=generations),  # type: ignore[arg-type]
+            )
             raise
 
+        generation = merge_chat_generation_chunks(chunks)
         if generation is None:
             err = ValueError("No generation chunks were returned")
             run_manager.on_llm_error(err, response=LLMResult(generations=[]))
@@ -575,29 +583,29 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         if self.rate_limiter:
             await self.rate_limiter.aacquire(blocking=True)
 
-        generation: Optional[ChatGenerationChunk] = None
+        chunks: list[ChatGenerationChunk] = []
+
         try:
             input_messages = _normalize_messages(messages)
+            run_id = "-".join((_LC_ID_PREFIX, str(run_manager.run_id)))
             async for chunk in self._astream(
                 input_messages,
                 stop=stop,
                 **kwargs,
             ):
                 if chunk.message.id is None:
-                    chunk.message.id = f"{_LC_ID_PREFIX}-{run_manager.run_id}"
+                    chunk.message.id = run_id
                 chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
                 await run_manager.on_llm_new_token(
                     cast("str", chunk.message.content), chunk=chunk
                 )
+                chunks.append(chunk)
                 yield chunk.message
-                if generation is None:
-                    generation = chunk
-                else:
-                    generation += chunk
         except BaseException as e:
             generations_with_error_metadata = _generate_response_from_error(e)
-            if generation:
-                generations = [[generation], generations_with_error_metadata]
+            chat_generation_chunk = merge_chat_generation_chunks(chunks)
+            if chat_generation_chunk:
+                generations = [[chat_generation_chunk], generations_with_error_metadata]
             else:
                 generations = [generations_with_error_metadata]
             await run_manager.on_llm_error(
@@ -606,7 +614,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             )
             raise
 
-        if generation is None:
+        generation = merge_chat_generation_chunks(chunks)
+        if not generation:
             err = ValueError("No generation chunks were returned")
             await run_manager.on_llm_error(err, response=LLMResult(generations=[]))
             raise err
diff --git a/libs/core/langchain_core/outputs/chat_generation.py b/libs/core/langchain_core/outputs/chat_generation.py
index d01ae516942..25ea5684e67 100644
--- a/libs/core/langchain_core/outputs/chat_generation.py
+++ b/libs/core/langchain_core/outputs/chat_generation.py
@@ -2,17 +2,15 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Literal, Union
+from typing import Literal, Union
 
 from pydantic import model_validator
+from typing_extensions import Self
 
 from langchain_core.messages import BaseMessage, BaseMessageChunk
 from langchain_core.outputs.generation import Generation
 from langchain_core.utils._merge import merge_dicts
 
-if TYPE_CHECKING:
-    from typing_extensions import Self
-
 
 class ChatGeneration(Generation):
     """A single chat generation output.
@@ -115,3 +113,16 @@
             )
         msg = f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"
         raise TypeError(msg)
+
+
+def merge_chat_generation_chunks(
+    chunks: list[ChatGenerationChunk],
+) -> Union[ChatGenerationChunk, None]:
+    """Merge a list of ChatGenerationChunks into a single ChatGenerationChunk."""
+    if not chunks:
+        return None
+
+    if len(chunks) == 1:
+        return chunks[0]
+
+    return chunks[0] + chunks[1:]
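For reviewers, a small usage sketch of the new helper. This assumes the patch
is applied; note that `chunks[0] + chunks[1:]` relies on
`ChatGenerationChunk.__add__` accepting a list of chunks, which is present in
the surrounding library but not shown in this diff:

```python
from langchain_core.messages import AIMessageChunk
from langchain_core.outputs import ChatGenerationChunk
from langchain_core.outputs.chat_generation import merge_chat_generation_chunks

chunks = [
    ChatGenerationChunk(message=AIMessageChunk(content="Broc")),
    ChatGenerationChunk(message=AIMessageChunk(content="coli")),
]

# Edge cases mirrored from the implementation above:
assert merge_chat_generation_chunks([]) is None  # empty list -> None
assert merge_chat_generation_chunks(chunks[:1]) is chunks[0]  # single chunk as-is

# Two or more chunks are merged with a single list-addition.
merged = merge_chat_generation_chunks(chunks)
assert merged is not None
print(merged.message.content)  # expected: "Broccoli"
```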