do generation aggregation at the end

This commit is contained in:
Sydney Runkle 2025-05-13 10:49:24 -07:00
parent 31ba2844d3
commit 94bd4bf313

View File

@ -575,7 +575,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
if self.rate_limiter: if self.rate_limiter:
await self.rate_limiter.aacquire(blocking=True) await self.rate_limiter.aacquire(blocking=True)
generation: Optional[ChatGenerationChunk] = None generations: list[ChatGenerationChunk] = []
try: try:
input_messages = _normalize_messages(messages) input_messages = _normalize_messages(messages)
async for chunk in self._astream( async for chunk in self._astream(
@ -584,35 +585,46 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
**kwargs, **kwargs,
): ):
if chunk.message.id is None: if chunk.message.id is None:
chunk.message.id = f"{_LC_ID_PREFIX}-{run_manager.run_id}" chunk.message.id = _LC_ID_PREFIX + "-" + str(run_manager.run_id)
chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk) chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
await run_manager.on_llm_new_token( await run_manager.on_llm_new_token(
cast("str", chunk.message.content), chunk=chunk cast("str", chunk.message.content), chunk=chunk
) )
generations.append(chunk)
yield chunk.message yield chunk.message
if generation is None:
generation = chunk
else:
generation += chunk
except BaseException as e: except BaseException as e:
generations_with_error_metadata = _generate_response_from_error(e) generations_with_error_metadata = _generate_response_from_error(e)
if generation: if generations:
generations = [[generation], generations_with_error_metadata] if len(generations) > 1:
aggregate_generation = [
[generations[0] + generations[1:]],
generations_with_error_metadata,
]
else: else:
generations = [generations_with_error_metadata] aggregate_generation = [
[generations[0]],
generations_with_error_metadata,
]
else:
aggregate_generation = [generations_with_error_metadata]
await run_manager.on_llm_error( await run_manager.on_llm_error(
e, e,
response=LLMResult(generations=generations), # type: ignore[arg-type] response=LLMResult(generations=aggregate_generation), # type: ignore[arg-type]
) )
raise raise
if generation is None: if not generations:
err = ValueError("No generation chunks were returned") err = ValueError("No generation chunks were returned")
await run_manager.on_llm_error(err, response=LLMResult(generations=[])) await run_manager.on_llm_error(err, response=LLMResult(generations=[]))
raise err raise err
if len(generations) > 1:
aggregate_generation = generations[0] + generations[1:]
else:
aggregate_generation = generations[0]
await run_manager.on_llm_end( await run_manager.on_llm_end(
LLMResult(generations=[[generation]]), LLMResult(generations=[[aggregate_generation]]),
) )
# --- Custom methods --- # --- Custom methods ---