From 2687eb10db5f6251dac6150df4d317456df8ac0b Mon Sep 17 00:00:00 2001
From: Sydney Runkle
Date: Tue, 13 May 2025 14:27:41 -0700
Subject: [PATCH] add helper function

---
 .../language_models/chat_models.py            | 57 +++++++++----------
 .../langchain_core/outputs/chat_generation.py | 19 +++++--
 2 files changed, 41 insertions(+), 35 deletions(-)

diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py
index 6e82fc671df..f1c01a7712a 100644
--- a/libs/core/langchain_core/language_models/chat_models.py
+++ b/libs/core/langchain_core/language_models/chat_models.py
@@ -66,6 +66,7 @@ from langchain_core.outputs import (
     LLMResult,
     RunInfo,
 )
+from langchain_core.outputs.chat_generation import merge_chat_generation_chunks
 from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
 from langchain_core.rate_limiters import BaseRateLimiter
 from langchain_core.runnables import RunnableMap, RunnablePassthrough
@@ -485,7 +486,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             run_id=config.pop("run_id", None),
             batch_size=1,
         )
-        generation: Optional[ChatGenerationChunk] = None
+
+        chunks: list[ChatGenerationChunk] = []
 
         if self.rate_limiter:
             self.rate_limiter.acquire(blocking=True)
@@ -499,20 +501,25 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
                 run_manager.on_llm_new_token(
                     cast("str", chunk.message.content), chunk=chunk
                 )
+                chunks.append(chunk)
                 yield chunk.message
-                if generation is None:
-                    generation = chunk
-                else:
-                    generation += chunk
         except BaseException as e:
             generations_with_error_metadata = _generate_response_from_error(e)
-            if generation:
-                generations = [[generation], generations_with_error_metadata]
+            chat_generation_chunk = merge_chat_generation_chunks(chunks)
+            if chat_generation_chunk:
+                generations = [
+                    [chat_generation_chunk],
+                    generations_with_error_metadata,
+                ]
             else:
                 generations = [generations_with_error_metadata]
-            run_manager.on_llm_error(e, response=LLMResult(generations=generations))  # type: ignore[arg-type]
+            run_manager.on_llm_error(
+                e,
+                response=LLMResult(generations=generations),  # type: ignore[arg-type]
+            )
             raise
 
+        generation = merge_chat_generation_chunks(chunks)
         if generation is None:
             err = ValueError("No generation chunks were returned")
             run_manager.on_llm_error(err, response=LLMResult(generations=[]))
@@ -575,7 +582,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         if self.rate_limiter:
             await self.rate_limiter.aacquire(blocking=True)
 
-        generations: list[ChatGenerationChunk] = []
+        chunks: list[ChatGenerationChunk] = []
 
         try:
             input_messages = _normalize_messages(messages)
@@ -585,46 +592,34 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
                 **kwargs,
             ):
                 if chunk.message.id is None:
-                    chunk.message.id = _LC_ID_PREFIX + "-" + str(run_manager.run_id)
+                    chunk.message.id = f"{_LC_ID_PREFIX}-{run_manager.run_id}"
                 chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
                 await run_manager.on_llm_new_token(
                     cast("str", chunk.message.content), chunk=chunk
                 )
-                generations.append(chunk)
+                chunks.append(chunk)
                 yield chunk.message
         except BaseException as e:
             generations_with_error_metadata = _generate_response_from_error(e)
-            if generations:
-                if len(generations) > 1:
-                    aggregate_generation = [
-                        [generations[0] + generations[1:]],
-                        generations_with_error_metadata,
-                    ]
-                else:
-                    aggregate_generation = [
-                        [generations[0]],
-                        generations_with_error_metadata,
-                    ]
+            chat_generation_chunk = merge_chat_generation_chunks(chunks)
+            if chat_generation_chunk:
+                generations = [[chat_generation_chunk], generations_with_error_metadata]
             else:
-                aggregate_generation = [generations_with_error_metadata]
+                generations = [generations_with_error_metadata]
             await run_manager.on_llm_error(
                 e,
-                response=LLMResult(generations=aggregate_generation),  # type: ignore[arg-type]
+                response=LLMResult(generations=generations),  # type: ignore[arg-type]
            )
             raise
 
-        if not generations:
+        generation = merge_chat_generation_chunks(chunks)
+        if not generation:
             err = ValueError("No generation chunks were returned")
             await run_manager.on_llm_error(err, response=LLMResult(generations=[]))
             raise err
 
-        if len(generations) > 1:
-            aggregate_generation = generations[0] + generations[1:]
-        else:
-            aggregate_generation = generations[0]
-
         await run_manager.on_llm_end(
-            LLMResult(generations=[[aggregate_generation]]),
+            LLMResult(generations=[[generation]]),
         )
 
     # --- Custom methods ---

diff --git a/libs/core/langchain_core/outputs/chat_generation.py b/libs/core/langchain_core/outputs/chat_generation.py
index d01ae516942..25ea5684e67 100644
--- a/libs/core/langchain_core/outputs/chat_generation.py
+++ b/libs/core/langchain_core/outputs/chat_generation.py
@@ -2,17 +2,15 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Literal, Union
+from typing import Literal, Union
 
 from pydantic import model_validator
+from typing_extensions import Self
 
 from langchain_core.messages import BaseMessage, BaseMessageChunk
 from langchain_core.outputs.generation import Generation
 from langchain_core.utils._merge import merge_dicts
 
-if TYPE_CHECKING:
-    from typing_extensions import Self
-
 
 class ChatGeneration(Generation):
     """A single chat generation output.
@@ -115,3 +113,16 @@
         )
         msg = f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"
         raise TypeError(msg)
+
+
+def merge_chat_generation_chunks(
+    chunks: list[ChatGenerationChunk],
+) -> Union[ChatGenerationChunk, None]:
+    """Merge a list of ChatGenerationChunks into a single ChatGenerationChunk."""
+    if not chunks:
+        return None
+
+    if len(chunks) == 1:
+        return chunks[0]
+
+    return chunks[0] + chunks[1:]
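
Note (not part of the patch): a minimal sketch of how the new merge_chat_generation_chunks helper behaves, assuming langchain_core with this change is installed. The message contents below are made up for illustration only.

from langchain_core.messages import AIMessageChunk
from langchain_core.outputs.chat_generation import (
    ChatGenerationChunk,
    merge_chat_generation_chunks,
)

# Two streamed chunks with illustrative content.
chunks = [
    ChatGenerationChunk(message=AIMessageChunk(content="Hello")),
    ChatGenerationChunk(message=AIMessageChunk(content=", world")),
]

merged = merge_chat_generation_chunks(chunks)
assert merged is not None
assert merged.message.content == "Hello, world"  # chunks[0] + chunks[1:]

assert merge_chat_generation_chunks([]) is None  # empty input yields None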