Mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-23 07:09:31 +00:00)
perf[core]: remove generations summation from hot loop (#31231)
1. Removes summation of `ChatGenerationChunk` from the hot loops in `stream` and `astream`.
2. Removes run id generation from the loop as well (minor impact).

Again, benchmarking on processing ~200k chunks (a poem about broccoli).

Before: ~4.2s. The blue circle is all the time spent adding up generation chunks.
<img width="1345" alt="Screenshot 2025-05-14 at 7 48 33 AM" src="https://github.com/user-attachments/assets/08a59d78-134d-4cd3-9d54-214de689df51" />

After: ~2.3s. The blue circle is the remaining time spent adding chunks, which can be minimized in a future PR by optimizing the `merge_content`, `merge_dicts`, and `merge_lists` utilities.
<img width="1353" alt="Screenshot 2025-05-14 at 7 50 08 AM" src="https://github.com/user-attachments/assets/df6b3506-929e-4b6d-b198-7c4e992c6d34" />
This commit is contained in:
parent 17b799860f
commit 263c215112
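The core of the change can be illustrated with a standalone micro-benchmark. This is not the PR's benchmark harness; the chunk count, contents, and printed labels below are invented for illustration, but the two loops contrast the old pattern (summing `ChatGenerationChunk`s inside the hot loop) with the new one (append during the loop, merge once at the end). `merge_chat_generation_chunks` is the helper added by this PR (see the last hunk below).

```python
import time

from langchain_core.messages import AIMessageChunk
from langchain_core.outputs import ChatGenerationChunk
from langchain_core.outputs.chat_generation import merge_chat_generation_chunks

# Toy stream: 10k small chunks (the real benchmark processed ~200k).
chunks = [
    ChatGenerationChunk(message=AIMessageChunk(content="broccoli "))
    for _ in range(10_000)
]

# Old pattern: a running sum inside the hot loop. Every `+=` re-merges the
# entire prefix (content, metadata, usage), so work grows roughly
# quadratically with the number of chunks.
start = time.perf_counter()
generation = None
for chunk in chunks:
    generation = chunk if generation is None else generation + chunk
print(f"sum inside loop: {time.perf_counter() - start:.2f}s")

# New pattern: the loop only appends; everything is merged once after streaming.
start = time.perf_counter()
collected = []
for chunk in chunks:
    collected.append(chunk)
merged = merge_chat_generation_chunks(collected)
print(f"merge at end:    {time.perf_counter() - start:.2f}s")

# Both paths produce the same merged content.
assert generation is not None and merged is not None
assert generation.message.content == merged.message.content
```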
@@ -66,6 +66,7 @@ from langchain_core.outputs import (
     LLMResult,
     RunInfo,
 )
+from langchain_core.outputs.chat_generation import merge_chat_generation_chunks
 from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
 from langchain_core.rate_limiters import BaseRateLimiter
 from langchain_core.runnables import RunnableMap, RunnablePassthrough

@@ -485,34 +486,41 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             run_id=config.pop("run_id", None),
             batch_size=1,
         )
-        generation: Optional[ChatGenerationChunk] = None
+        chunks: list[ChatGenerationChunk] = []

         if self.rate_limiter:
             self.rate_limiter.acquire(blocking=True)

         try:
             input_messages = _normalize_messages(messages)
+            run_id = "-".join((_LC_ID_PREFIX, str(run_manager.run_id)))
             for chunk in self._stream(input_messages, stop=stop, **kwargs):
                 if chunk.message.id is None:
-                    chunk.message.id = f"{_LC_ID_PREFIX}-{run_manager.run_id}"
+                    chunk.message.id = run_id
                 chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
                 run_manager.on_llm_new_token(
                     cast("str", chunk.message.content), chunk=chunk
                 )
+                chunks.append(chunk)
                 yield chunk.message
-                if generation is None:
-                    generation = chunk
-                else:
-                    generation += chunk
         except BaseException as e:
             generations_with_error_metadata = _generate_response_from_error(e)
-            if generation:
-                generations = [[generation], generations_with_error_metadata]
+            chat_generation_chunk = merge_chat_generation_chunks(chunks)
+            if chat_generation_chunk:
+                generations = [
+                    [chat_generation_chunk],
+                    generations_with_error_metadata,
+                ]
             else:
                 generations = [generations_with_error_metadata]
-            run_manager.on_llm_error(e, response=LLMResult(generations=generations))  # type: ignore[arg-type]
+            run_manager.on_llm_error(
+                e,
+                response=LLMResult(generations=generations),  # type: ignore[arg-type]
+            )
             raise

+        generation = merge_chat_generation_chunks(chunks)
         if generation is None:
             err = ValueError("No generation chunks were returned")
             run_manager.on_llm_error(err, response=LLMResult(generations=[]))

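One detail worth calling out in the hunk above: on failure, the `except` branch no longer has an incrementally built `generation` to report, so it merges whatever chunks were collected before the exception and hands that partial generation to `on_llm_error`, preserving the old behavior. A standalone sketch of the same pattern follows; the `stream_tokens` generator and its contents are invented for illustration and stand in for `self._stream(...)`.

```python
from langchain_core.messages import AIMessageChunk
from langchain_core.outputs import ChatGenerationChunk, LLMResult
from langchain_core.outputs.chat_generation import merge_chat_generation_chunks


def stream_tokens():
    # Toy token source that fails partway through.
    yield ChatGenerationChunk(message=AIMessageChunk(content="Brocc"))
    yield ChatGenerationChunk(message=AIMessageChunk(content="oli"))
    raise RuntimeError("connection dropped")


chunks: list[ChatGenerationChunk] = []
try:
    for chunk in stream_tokens():
        chunks.append(chunk)  # the hot loop only appends; no merging here
except RuntimeError:
    # Merge whatever arrived before the failure so error reporting still
    # sees the partial generation, mirroring the except branch above.
    partial = merge_chat_generation_chunks(chunks)
    generations = [[partial]] if partial else []
    result = LLMResult(generations=generations)
    print(result.generations[0][0].text)  # -> Broccoli
```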
@@ -575,29 +583,29 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         if self.rate_limiter:
             await self.rate_limiter.aacquire(blocking=True)

-        generation: Optional[ChatGenerationChunk] = None
+        chunks: list[ChatGenerationChunk] = []

         try:
             input_messages = _normalize_messages(messages)
+            run_id = "-".join((_LC_ID_PREFIX, str(run_manager.run_id)))
             async for chunk in self._astream(
                 input_messages,
                 stop=stop,
                 **kwargs,
             ):
                 if chunk.message.id is None:
-                    chunk.message.id = f"{_LC_ID_PREFIX}-{run_manager.run_id}"
+                    chunk.message.id = run_id
                 chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
                 await run_manager.on_llm_new_token(
                     cast("str", chunk.message.content), chunk=chunk
                 )
+                chunks.append(chunk)
                 yield chunk.message
-                if generation is None:
-                    generation = chunk
-                else:
-                    generation += chunk
         except BaseException as e:
             generations_with_error_metadata = _generate_response_from_error(e)
-            if generation:
-                generations = [[generation], generations_with_error_metadata]
+            chat_generation_chunk = merge_chat_generation_chunks(chunks)
+            if chat_generation_chunk:
+                generations = [[chat_generation_chunk], generations_with_error_metadata]
             else:
                 generations = [generations_with_error_metadata]
             await run_manager.on_llm_error(

@@ -606,7 +614,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             )
             raise

-        if generation is None:
+        generation = merge_chat_generation_chunks(chunks)
+        if not generation:
             err = ValueError("No generation chunks were returned")
             await run_manager.on_llm_error(err, response=LLMResult(generations=[]))
             raise err

@@ -2,17 +2,15 @@

 from __future__ import annotations

-from typing import TYPE_CHECKING, Literal, Union
+from typing import Literal, Union

 from pydantic import model_validator
+from typing_extensions import Self

 from langchain_core.messages import BaseMessage, BaseMessageChunk
 from langchain_core.outputs.generation import Generation
 from langchain_core.utils._merge import merge_dicts

-if TYPE_CHECKING:
-    from typing_extensions import Self
-

 class ChatGeneration(Generation):
     """A single chat generation output.

@@ -115,3 +113,16 @@ class ChatGenerationChunk(ChatGeneration):
             )
         msg = f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"
         raise TypeError(msg)
+
+
+def merge_chat_generation_chunks(
+    chunks: list[ChatGenerationChunk],
+) -> Union[ChatGenerationChunk, None]:
+    """Merge a list of ChatGenerationChunks into a single ChatGenerationChunk."""
+    if not chunks:
+        return None
+
+    if len(chunks) == 1:
+        return chunks[0]
+
+    return chunks[0] + chunks[1:]
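For reference, a short usage sketch of the new helper, with invented chunk contents. The helper passes the remaining chunks as a list to `ChatGenerationChunk.__add__` (`chunks[0] + chunks[1:]`), so the collected chunks are merged in one call, and it handles the empty and single-chunk cases explicitly.

```python
from langchain_core.messages import AIMessageChunk
from langchain_core.outputs import ChatGenerationChunk
from langchain_core.outputs.chat_generation import merge_chat_generation_chunks

chunks = [
    ChatGenerationChunk(message=AIMessageChunk(content="Hello, ")),
    ChatGenerationChunk(
        message=AIMessageChunk(content="world"),
        generation_info={"finish_reason": "stop"},
    ),
]

merged = merge_chat_generation_chunks(chunks)
assert merged is not None
print(merged.text)             # Hello, world
print(merged.generation_info)  # {'finish_reason': 'stop'}

# Edge cases handled explicitly by the helper:
assert merge_chat_generation_chunks([]) is None               # nothing streamed
assert merge_chat_generation_chunks(chunks[:1]) is chunks[0]  # single chunk returned as-is
```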