mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-08 04:25:46 +00:00
do generation aggregation at the end
This commit is contained in:
parent
31ba2844d3
commit
94bd4bf313
@ -575,7 +575,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
|||||||
if self.rate_limiter:
|
if self.rate_limiter:
|
||||||
await self.rate_limiter.aacquire(blocking=True)
|
await self.rate_limiter.aacquire(blocking=True)
|
||||||
|
|
||||||
generation: Optional[ChatGenerationChunk] = None
|
generations: list[ChatGenerationChunk] = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
input_messages = _normalize_messages(messages)
|
input_messages = _normalize_messages(messages)
|
||||||
async for chunk in self._astream(
|
async for chunk in self._astream(
|
||||||
@ -584,35 +585,46 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
if chunk.message.id is None:
|
if chunk.message.id is None:
|
||||||
chunk.message.id = f"{_LC_ID_PREFIX}-{run_manager.run_id}"
|
chunk.message.id = _LC_ID_PREFIX + "-" + str(run_manager.run_id)
|
||||||
chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
|
chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
|
||||||
await run_manager.on_llm_new_token(
|
await run_manager.on_llm_new_token(
|
||||||
cast("str", chunk.message.content), chunk=chunk
|
cast("str", chunk.message.content), chunk=chunk
|
||||||
)
|
)
|
||||||
|
generations.append(chunk)
|
||||||
yield chunk.message
|
yield chunk.message
|
||||||
if generation is None:
|
|
||||||
generation = chunk
|
|
||||||
else:
|
|
||||||
generation += chunk
|
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
generations_with_error_metadata = _generate_response_from_error(e)
|
generations_with_error_metadata = _generate_response_from_error(e)
|
||||||
if generation:
|
if generations:
|
||||||
generations = [[generation], generations_with_error_metadata]
|
if len(generations) > 1:
|
||||||
|
aggregate_generation = [
|
||||||
|
[generations[0] + generations[1:]],
|
||||||
|
generations_with_error_metadata,
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
aggregate_generation = [
|
||||||
|
[generations[0]],
|
||||||
|
generations_with_error_metadata,
|
||||||
|
]
|
||||||
else:
|
else:
|
||||||
generations = [generations_with_error_metadata]
|
aggregate_generation = [generations_with_error_metadata]
|
||||||
await run_manager.on_llm_error(
|
await run_manager.on_llm_error(
|
||||||
e,
|
e,
|
||||||
response=LLMResult(generations=generations), # type: ignore[arg-type]
|
response=LLMResult(generations=aggregate_generation), # type: ignore[arg-type]
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
if generation is None:
|
if not generations:
|
||||||
err = ValueError("No generation chunks were returned")
|
err = ValueError("No generation chunks were returned")
|
||||||
await run_manager.on_llm_error(err, response=LLMResult(generations=[]))
|
await run_manager.on_llm_error(err, response=LLMResult(generations=[]))
|
||||||
raise err
|
raise err
|
||||||
|
|
||||||
|
if len(generations) > 1:
|
||||||
|
aggregate_generation = generations[0] + generations[1:]
|
||||||
|
else:
|
||||||
|
aggregate_generation = generations[0]
|
||||||
|
|
||||||
await run_manager.on_llm_end(
|
await run_manager.on_llm_end(
|
||||||
LLMResult(generations=[[generation]]),
|
LLMResult(generations=[[aggregate_generation]]),
|
||||||
)
|
)
|
||||||
|
|
||||||
# --- Custom methods ---
|
# --- Custom methods ---
|
||||||
|
Loading…
Reference in New Issue
Block a user