add helper function

Sydney Runkle, 2025-05-13 14:27:41 -07:00
parent b226701b58
commit 2687eb10db
2 changed files with 41 additions and 35 deletions

langchain_core/language_models/chat_models.py

@@ -66,6 +66,7 @@ from langchain_core.outputs import (
     LLMResult,
     RunInfo,
 )
+from langchain_core.outputs.chat_generation import merge_chat_generation_chunks
 from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
 from langchain_core.rate_limiters import BaseRateLimiter
 from langchain_core.runnables import RunnableMap, RunnablePassthrough
@@ -485,7 +486,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             run_id=config.pop("run_id", None),
             batch_size=1,
         )
-        generation: Optional[ChatGenerationChunk] = None
+        chunks: list[ChatGenerationChunk] = []
+
         if self.rate_limiter:
             self.rate_limiter.acquire(blocking=True)
@@ -499,20 +501,25 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
                 run_manager.on_llm_new_token(
                     cast("str", chunk.message.content), chunk=chunk
                 )
+                chunks.append(chunk)
                 yield chunk.message
-                if generation is None:
-                    generation = chunk
-                else:
-                    generation += chunk
         except BaseException as e:
             generations_with_error_metadata = _generate_response_from_error(e)
-            if generation:
-                generations = [[generation], generations_with_error_metadata]
+            chat_generation_chunk = merge_chat_generation_chunks(chunks)
+            if chat_generation_chunk:
+                generations = [
+                    [chat_generation_chunk],
+                    generations_with_error_metadata,
+                ]
             else:
                 generations = [generations_with_error_metadata]
-            run_manager.on_llm_error(e, response=LLMResult(generations=generations))  # type: ignore[arg-type]
+            run_manager.on_llm_error(
+                e,
+                response=LLMResult(generations=generations),  # type: ignore[arg-type]
+            )
             raise

+        generation = merge_chat_generation_chunks(chunks)
         if generation is None:
             err = ValueError("No generation chunks were returned")
             run_manager.on_llm_error(err, response=LLMResult(generations=[]))
@@ -575,7 +582,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         if self.rate_limiter:
             await self.rate_limiter.aacquire(blocking=True)

-        generations: list[ChatGenerationChunk] = []
+        chunks: list[ChatGenerationChunk] = []
         try:
             input_messages = _normalize_messages(messages)
@@ -585,46 +592,34 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
                 **kwargs,
             ):
                 if chunk.message.id is None:
-                    chunk.message.id = _LC_ID_PREFIX + "-" + str(run_manager.run_id)
+                    chunk.message.id = f"{_LC_ID_PREFIX}-{run_manager.run_id}"
                 chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
                 await run_manager.on_llm_new_token(
                     cast("str", chunk.message.content), chunk=chunk
                 )
-                generations.append(chunk)
+                chunks.append(chunk)
                 yield chunk.message
         except BaseException as e:
             generations_with_error_metadata = _generate_response_from_error(e)
-            if generations:
-                if len(generations) > 1:
-                    aggregate_generation = [
-                        [generations[0] + generations[1:]],
-                        generations_with_error_metadata,
-                    ]
-                else:
-                    aggregate_generation = [
-                        [generations[0]],
-                        generations_with_error_metadata,
-                    ]
+            chat_generation_chunk = merge_chat_generation_chunks(chunks)
+            if chat_generation_chunk:
+                generations = [[chat_generation_chunk], generations_with_error_metadata]
             else:
-                aggregate_generation = [generations_with_error_metadata]
+                generations = [generations_with_error_metadata]
             await run_manager.on_llm_error(
                 e,
-                response=LLMResult(generations=aggregate_generation),  # type: ignore[arg-type]
+                response=LLMResult(generations=generations),  # type: ignore[arg-type]
             )
             raise

-        if not generations:
+        generation = merge_chat_generation_chunks(chunks)
+        if not generation:
             err = ValueError("No generation chunks were returned")
             await run_manager.on_llm_error(err, response=LLMResult(generations=[]))
             raise err
-        if len(generations) > 1:
-            aggregate_generation = generations[0] + generations[1:]
-        else:
-            aggregate_generation = generations[0]
         await run_manager.on_llm_end(
-            LLMResult(generations=[[aggregate_generation]]),
+            LLMResult(generations=[[generation]]),
         )

     # --- Custom methods ---
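
Note: both streaming hunks above replace the same inline fold, where each incoming chunk was merged into a running ChatGenerationChunk via __add__, with a collect-then-merge call to the new helper. A minimal sketch of the equivalence, using real langchain_core types but hypothetical chunk texts (assumes a langchain_core build that includes this commit):

# Sketch only: illustrates the merge semantics the refactor relies on.
from langchain_core.messages import AIMessageChunk
from langchain_core.outputs import ChatGenerationChunk
from langchain_core.outputs.chat_generation import merge_chat_generation_chunks

# Hypothetical stream of three partial chunks.
chunks = [
    ChatGenerationChunk(message=AIMessageChunk(content=text))
    for text in ("Hel", "lo ", "world")
]

# Old inline pattern: fold each chunk into a running total as it arrives.
generation = None
for chunk in chunks:
    generation = chunk if generation is None else generation + chunk

# New pattern: collect everything, then merge once after the stream ends.
merged = merge_chat_generation_chunks(chunks)
assert merged is not None and generation is not None
assert merged.text == generation.text == "Hello world"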

langchain_core/outputs/chat_generation.py

@@ -2,17 +2,15 @@
 from __future__ import annotations

-from typing import TYPE_CHECKING, Literal, Union
+from typing import Literal, Union

 from pydantic import model_validator
+from typing_extensions import Self

 from langchain_core.messages import BaseMessage, BaseMessageChunk
 from langchain_core.outputs.generation import Generation
 from langchain_core.utils._merge import merge_dicts

-if TYPE_CHECKING:
-    from typing_extensions import Self
-

 class ChatGeneration(Generation):
     """A single chat generation output.
@@ -115,3 +113,16 @@ class ChatGenerationChunk(ChatGeneration):
             )
         msg = f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"
         raise TypeError(msg)
+
+
+def merge_chat_generation_chunks(
+    chunks: list[ChatGenerationChunk],
+) -> Union[ChatGenerationChunk, None]:
+    """Merge a list of ChatGenerationChunks into a single ChatGenerationChunk."""
+    if not chunks:
+        return None
+
+    if len(chunks) == 1:
+        return chunks[0]
+
+    return chunks[0] + chunks[1:]
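
For reference, a quick sketch of the helper's two edge cases (illustrative message text; assumes a langchain_core build that includes this commit):

from langchain_core.messages import AIMessageChunk
from langchain_core.outputs import ChatGenerationChunk
from langchain_core.outputs.chat_generation import merge_chat_generation_chunks

single = ChatGenerationChunk(message=AIMessageChunk(content="only chunk"))

assert merge_chat_generation_chunks([]) is None          # empty stream -> None, no exception
assert merge_chat_generation_chunks([single]) is single  # single chunk returned as-is, unmerged

Returning None for an empty list is what lets both streaming paths keep their existing "No generation chunks were returned" error handling after the try/except.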