diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py
index 712271778ae..322b2a4f677 100644
--- a/libs/core/langchain_core/language_models/chat_models.py
+++ b/libs/core/langchain_core/language_models/chat_models.py
@@ -66,6 +66,7 @@ from langchain_core.outputs import (
     LLMResult,
     RunInfo,
 )
+from langchain_core.outputs.chat_generation import merge_chat_generation_chunks
 from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
 from langchain_core.rate_limiters import BaseRateLimiter
 from langchain_core.runnables import RunnableMap, RunnablePassthrough
@@ -485,34 +486,41 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             run_id=config.pop("run_id", None),
             batch_size=1,
         )
-        generation: Optional[ChatGenerationChunk] = None
+
+        chunks: list[ChatGenerationChunk] = []
 
         if self.rate_limiter:
             self.rate_limiter.acquire(blocking=True)
 
         try:
             input_messages = _normalize_messages(messages)
+            run_id = "-".join((_LC_ID_PREFIX, str(run_manager.run_id)))
             for chunk in self._stream(input_messages, stop=stop, **kwargs):
                 if chunk.message.id is None:
-                    chunk.message.id = f"{_LC_ID_PREFIX}-{run_manager.run_id}"
+                    chunk.message.id = run_id
                 chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
                 run_manager.on_llm_new_token(
                     cast("str", chunk.message.content), chunk=chunk
                 )
+                chunks.append(chunk)
                 yield chunk.message
-                if generation is None:
-                    generation = chunk
-                else:
-                    generation += chunk
         except BaseException as e:
             generations_with_error_metadata = _generate_response_from_error(e)
-            if generation:
-                generations = [[generation], generations_with_error_metadata]
+            chat_generation_chunk = merge_chat_generation_chunks(chunks)
+            if chat_generation_chunk:
+                generations = [
+                    [chat_generation_chunk],
+                    generations_with_error_metadata,
+                ]
             else:
                 generations = [generations_with_error_metadata]
-            run_manager.on_llm_error(e, response=LLMResult(generations=generations))  # type: ignore[arg-type]
+            run_manager.on_llm_error(
+                e,
+                response=LLMResult(generations=generations),  # type: ignore[arg-type]
+            )
             raise
 
+        generation = merge_chat_generation_chunks(chunks)
         if generation is None:
             err = ValueError("No generation chunks were returned")
             run_manager.on_llm_error(err, response=LLMResult(generations=[]))
@@ -575,29 +583,29 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         if self.rate_limiter:
             await self.rate_limiter.aacquire(blocking=True)
 
-        generation: Optional[ChatGenerationChunk] = None
+        chunks: list[ChatGenerationChunk] = []
+
         try:
             input_messages = _normalize_messages(messages)
+            run_id = "-".join((_LC_ID_PREFIX, str(run_manager.run_id)))
             async for chunk in self._astream(
                 input_messages,
                 stop=stop,
                 **kwargs,
             ):
                 if chunk.message.id is None:
-                    chunk.message.id = f"{_LC_ID_PREFIX}-{run_manager.run_id}"
+                    chunk.message.id = run_id
                 chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
                 await run_manager.on_llm_new_token(
                     cast("str", chunk.message.content), chunk=chunk
                 )
+                chunks.append(chunk)
                 yield chunk.message
-                if generation is None:
-                    generation = chunk
-                else:
-                    generation += chunk
         except BaseException as e:
             generations_with_error_metadata = _generate_response_from_error(e)
-            if generation:
-                generations = [[generation], generations_with_error_metadata]
+            chat_generation_chunk = merge_chat_generation_chunks(chunks)
+            if chat_generation_chunk:
+                generations = [[chat_generation_chunk], generations_with_error_metadata]
             else:
                 generations = [generations_with_error_metadata]
             await run_manager.on_llm_error(
@@ -606,7 +614,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             )
             raise
 
-        if generation is None:
+        generation = merge_chat_generation_chunks(chunks)
+        if not generation:
             err = ValueError("No generation chunks were returned")
             await run_manager.on_llm_error(err, response=LLMResult(generations=[]))
             raise err
diff --git a/libs/core/langchain_core/outputs/chat_generation.py b/libs/core/langchain_core/outputs/chat_generation.py
index d01ae516942..25ea5684e67 100644
--- a/libs/core/langchain_core/outputs/chat_generation.py
+++ b/libs/core/langchain_core/outputs/chat_generation.py
@@ -2,17 +2,15 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Literal, Union
+from typing import Literal, Union
 
 from pydantic import model_validator
+from typing_extensions import Self
 
 from langchain_core.messages import BaseMessage, BaseMessageChunk
 from langchain_core.outputs.generation import Generation
 from langchain_core.utils._merge import merge_dicts
 
-if TYPE_CHECKING:
-    from typing_extensions import Self
-
 
 class ChatGeneration(Generation):
     """A single chat generation output.
@@ -115,3 +113,16 @@ class ChatGenerationChunk(ChatGeneration):
         )
         msg = f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"
         raise TypeError(msg)
+
+
+def merge_chat_generation_chunks(
+    chunks: list[ChatGenerationChunk],
+) -> Union[ChatGenerationChunk, None]:
+    """Merge a list of ChatGenerationChunks into a single ChatGenerationChunk."""
+    if not chunks:
+        return None
+
+    if len(chunks) == 1:
+        return chunks[0]
+
+    return chunks[0] + chunks[1:]
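For reference, a minimal usage sketch of the new `merge_chat_generation_chunks` helper introduced by this diff. The chunk contents below are illustrative (not taken from the change itself); the merge relies on `ChatGenerationChunk.__add__` accepting a list of chunks, as the helper's own `chunks[0] + chunks[1:]` does.

```python
from langchain_core.messages import AIMessageChunk
from langchain_core.outputs.chat_generation import (
    ChatGenerationChunk,
    merge_chat_generation_chunks,
)

# Illustrative chunks, similar to what a model's _stream() might yield.
chunks = [
    ChatGenerationChunk(message=AIMessageChunk(content="Hello")),
    ChatGenerationChunk(message=AIMessageChunk(content=", ")),
    ChatGenerationChunk(message=AIMessageChunk(content="world")),
]

# Fold the streamed chunks back into a single generation.
merged = merge_chat_generation_chunks(chunks)
assert merged is not None
print(merged.message.content)  # "Hello, world"

# Edge cases handled by the helper:
assert merge_chat_generation_chunks([]) is None               # empty stream -> None
assert merge_chat_generation_chunks(chunks[:1]) is chunks[0]  # single chunk returned as-is
```

This mirrors how `stream()` and `astream()` now accumulate chunks in a plain list and merge them once, instead of folding each chunk into a running `generation` inside the loop.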