From 94bd4bf31390cf8ae708a38127894c387c7ab4bf Mon Sep 17 00:00:00 2001
From: Sydney Runkle
Date: Tue, 13 May 2025 10:49:24 -0700
Subject: [PATCH] do generation aggregation at the end

---
 .../language_models/chat_models.py | 34 ++++++++++++-------
 1 file changed, 23 insertions(+), 11 deletions(-)

diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py
index 712271778ae..6e82fc671df 100644
--- a/libs/core/langchain_core/language_models/chat_models.py
+++ b/libs/core/langchain_core/language_models/chat_models.py
@@ -575,7 +575,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         if self.rate_limiter:
             await self.rate_limiter.aacquire(blocking=True)
 
-        generation: Optional[ChatGenerationChunk] = None
+        generations: list[ChatGenerationChunk] = []
+
         try:
             input_messages = _normalize_messages(messages)
             async for chunk in self._astream(
@@ -584,35 +585,46 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
                 **kwargs,
             ):
                 if chunk.message.id is None:
                     chunk.message.id = f"{_LC_ID_PREFIX}-{run_manager.run_id}"
                 chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
                 await run_manager.on_llm_new_token(
                     cast("str", chunk.message.content), chunk=chunk
                 )
+                generations.append(chunk)
                 yield chunk.message
-                if generation is None:
-                    generation = chunk
-                else:
-                    generation += chunk
         except BaseException as e:
             generations_with_error_metadata = _generate_response_from_error(e)
-            if generation:
-                generations = [[generation], generations_with_error_metadata]
+            if generations:
+                if len(generations) > 1:
+                    aggregate_generation = [
+                        [generations[0] + generations[1:]],
+                        generations_with_error_metadata,
+                    ]
+                else:
+                    aggregate_generation = [
+                        [generations[0]],
+                        generations_with_error_metadata,
+                    ]
             else:
-                generations = [generations_with_error_metadata]
+                aggregate_generation = [generations_with_error_metadata]
             await run_manager.on_llm_error(
                 e,
-                response=LLMResult(generations=generations),  # type: ignore[arg-type]
+                response=LLMResult(generations=aggregate_generation),  # type: ignore[arg-type]
             )
             raise
 
-        if generation is None:
+        if not generations:
             err = ValueError("No generation chunks were returned")
             await run_manager.on_llm_error(err, response=LLMResult(generations=[]))
             raise err
 
+        if len(generations) > 1:
+            aggregate_generation = generations[0] + generations[1:]
+        else:
+            aggregate_generation = generations[0]
+
         await run_manager.on_llm_end(
-            LLMResult(generations=[[aggregate_generation]]),
+            LLMResult(generations=[[aggregate_generation]]),
         )
 
     # --- Custom methods ---
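
Note (reviewer sketch, not part of the patch): the change collects every
streamed chunk in a list and merges them once at the end, instead of folding
each chunk into a running `generation` with `+=` on every iteration. The
single merge assumes ChatGenerationChunk.__add__ accepts a list of chunks,
which the `generations[0] + generations[1:]` expression requires; the same
aggregation feeds the error path, so on_llm_error still receives one merged
generation alongside the error metadata. A minimal, self-contained
illustration of why merging once is cheaper on long streams follows; the
`Chunk` class is a hypothetical stand-in, not the LangChain type:

    # Toy model of the pattern in the patch; `Chunk` is hypothetical, not
    # ChatGenerationChunk, but it concatenates text content the same way.
    class Chunk:
        def __init__(self, text: str) -> None:
            self.text = text

        def __add__(self, other: "Chunk | list[Chunk]") -> "Chunk":
            # List overload: join every piece in a single pass, mirroring
            # the `generations[0] + generations[1:]` call in the patch.
            if isinstance(other, list):
                return Chunk("".join([self.text, *(c.text for c in other)]))
            return Chunk(self.text + other.text)

    chunks = [Chunk(str(i)) for i in range(1000)]

    # Old approach: merge on every iteration. Each `+` copies the entire
    # accumulated text again, so total work grows quadratically with the
    # number of streamed chunks.
    running = chunks[0]
    for chunk in chunks[1:]:
        running = running + chunk

    # New approach: append while streaming, aggregate once at the end.
    # A single join is linear in the total content length.
    final = chunks[0] + chunks[1:]

    assert running.text == final.text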