do generation aggregation at the end

This commit is contained in:
Sydney Runkle 2025-05-13 10:49:24 -07:00
parent 31ba2844d3
commit 94bd4bf313

View File

@ -575,7 +575,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
if self.rate_limiter:
await self.rate_limiter.aacquire(blocking=True)
generation: Optional[ChatGenerationChunk] = None
generations: list[ChatGenerationChunk] = []
try:
input_messages = _normalize_messages(messages)
async for chunk in self._astream(
@ -584,35 +585,46 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
**kwargs,
):
if chunk.message.id is None:
chunk.message.id = f"{_LC_ID_PREFIX}-{run_manager.run_id}"
chunk.message.id = _LC_ID_PREFIX + "-" + str(run_manager.run_id)
chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
await run_manager.on_llm_new_token(
cast("str", chunk.message.content), chunk=chunk
)
generations.append(chunk)
yield chunk.message
if generation is None:
generation = chunk
else:
generation += chunk
except BaseException as e:
generations_with_error_metadata = _generate_response_from_error(e)
if generation:
generations = [[generation], generations_with_error_metadata]
if generations:
if len(generations) > 1:
aggregate_generation = [
[generations[0] + generations[1:]],
generations_with_error_metadata,
]
else:
aggregate_generation = [
[generations[0]],
generations_with_error_metadata,
]
else:
generations = [generations_with_error_metadata]
aggregate_generation = [generations_with_error_metadata]
await run_manager.on_llm_error(
e,
response=LLMResult(generations=generations), # type: ignore[arg-type]
response=LLMResult(generations=aggregate_generation), # type: ignore[arg-type]
)
raise
if generation is None:
if not generations:
err = ValueError("No generation chunks were returned")
await run_manager.on_llm_error(err, response=LLMResult(generations=[]))
raise err
if len(generations) > 1:
aggregate_generation = generations[0] + generations[1:]
else:
aggregate_generation = generations[0]
await run_manager.on_llm_end(
LLMResult(generations=[[generation]]),
LLMResult(generations=[[aggregate_generation]]),
)
# --- Custom methods ---