mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-06 19:48:26 +00:00
do generation aggregation at the end
This commit is contained in:
parent
31ba2844d3
commit
94bd4bf313
@ -575,7 +575,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
if self.rate_limiter:
|
||||
await self.rate_limiter.aacquire(blocking=True)
|
||||
|
||||
generation: Optional[ChatGenerationChunk] = None
|
||||
generations: list[ChatGenerationChunk] = []
|
||||
|
||||
try:
|
||||
input_messages = _normalize_messages(messages)
|
||||
async for chunk in self._astream(
|
||||
@ -584,35 +585,46 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
**kwargs,
|
||||
):
|
||||
if chunk.message.id is None:
|
||||
chunk.message.id = f"{_LC_ID_PREFIX}-{run_manager.run_id}"
|
||||
chunk.message.id = _LC_ID_PREFIX + "-" + str(run_manager.run_id)
|
||||
chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
|
||||
await run_manager.on_llm_new_token(
|
||||
cast("str", chunk.message.content), chunk=chunk
|
||||
)
|
||||
generations.append(chunk)
|
||||
yield chunk.message
|
||||
if generation is None:
|
||||
generation = chunk
|
||||
else:
|
||||
generation += chunk
|
||||
except BaseException as e:
|
||||
generations_with_error_metadata = _generate_response_from_error(e)
|
||||
if generation:
|
||||
generations = [[generation], generations_with_error_metadata]
|
||||
if generations:
|
||||
if len(generations) > 1:
|
||||
aggregate_generation = [
|
||||
[generations[0] + generations[1:]],
|
||||
generations_with_error_metadata,
|
||||
]
|
||||
else:
|
||||
aggregate_generation = [
|
||||
[generations[0]],
|
||||
generations_with_error_metadata,
|
||||
]
|
||||
else:
|
||||
generations = [generations_with_error_metadata]
|
||||
aggregate_generation = [generations_with_error_metadata]
|
||||
await run_manager.on_llm_error(
|
||||
e,
|
||||
response=LLMResult(generations=generations), # type: ignore[arg-type]
|
||||
response=LLMResult(generations=aggregate_generation), # type: ignore[arg-type]
|
||||
)
|
||||
raise
|
||||
|
||||
if generation is None:
|
||||
if not generations:
|
||||
err = ValueError("No generation chunks were returned")
|
||||
await run_manager.on_llm_error(err, response=LLMResult(generations=[]))
|
||||
raise err
|
||||
|
||||
if len(generations) > 1:
|
||||
aggregate_generation = generations[0] + generations[1:]
|
||||
else:
|
||||
aggregate_generation = generations[0]
|
||||
|
||||
await run_manager.on_llm_end(
|
||||
LLMResult(generations=[[generation]]),
|
||||
LLMResult(generations=[[aggregate_generation]]),
|
||||
)
|
||||
|
||||
# --- Custom methods ---
|
||||
|
Loading…
Reference in New Issue
Block a user