fix(core): lint standard outputs branch (#32311)

ccurme 2025-07-29 16:38:45 -03:00 committed by GitHub
parent 9507d0f21c
commit 8cf97e838c
26 changed files with 384 additions and 178 deletions

View File

@ -66,7 +66,7 @@ class LLMManagerMixin:
def on_llm_new_token(
self,
token: Union[str, AIMessageChunk],
token: str,
*,
chunk: Optional[
Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
@ -265,7 +265,7 @@ class CallbackManagerMixin:
def on_chat_model_start(
self,
serialized: dict[str, Any],
messages: Union[list[list[BaseMessage]], list[list[MessageV1]]],
messages: Union[list[list[BaseMessage]], list[MessageV1]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
@ -516,7 +516,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
async def on_chat_model_start(
self,
serialized: dict[str, Any],
messages: Union[list[list[BaseMessage]], list[list[MessageV1]]],
messages: Union[list[list[BaseMessage]], list[MessageV1]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
@ -545,7 +545,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
async def on_llm_new_token(
self,
token: Union[str, AIMessageChunk],
token: str,
*,
chunk: Optional[
Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]

View File

@ -11,7 +11,6 @@ from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager, contextmanager
from contextvars import copy_context
from dataclasses import is_dataclass
from typing import (
TYPE_CHECKING,
Any,
@ -38,7 +37,13 @@ from langchain_core.callbacks.base import (
)
from langchain_core.callbacks.stdout import StdOutCallbackHandler
from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_core.messages.v1 import AIMessage, AIMessageChunk
from langchain_core.messages.utils import convert_from_v1_message
from langchain_core.messages.v1 import (
AIMessage,
AIMessageChunk,
MessageV1,
MessageV1Types,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, LLMResult
from langchain_core.tracers.schemas import Run
from langchain_core.utils.env import env_var_is_set
@ -249,17 +254,31 @@ def shielded(func: Func) -> Func:
def _convert_llm_events(
event_name: str, args: tuple[Any, ...], kwargs: dict[str, Any]
) -> None:
if event_name == "on_chat_model_start" and isinstance(args[1], list):
for idx, item in enumerate(args[1]):
if is_dataclass(item):
args[1][idx] = item # convert to old message
elif event_name == "on_llm_new_token" and is_dataclass(args[0]):
kwargs["chunk"] = ChatGenerationChunk(text=args[0].text, message=args[0])
args[0] = args[0].text
elif event_name == "on_llm_end" and is_dataclass(args[0]):
args[0] = LLMResult(
if (
event_name == "on_chat_model_start"
and isinstance(args[1], list)
and args[1]
and isinstance(args[1][0], MessageV1Types)
):
batch = [
convert_from_v1_message(item)
for item in args[1]
if isinstance(item, MessageV1Types)
]
args[1] = [batch] # type: ignore[index]
elif (
event_name == "on_llm_new_token"
and "chunk" in kwargs
and isinstance(kwargs["chunk"], MessageV1Types)
):
chunk = kwargs["chunk"]
kwargs["chunk"] = ChatGenerationChunk(text=chunk.text, message=chunk)
elif event_name == "on_llm_end" and isinstance(args[0], MessageV1Types):
args[0] = LLMResult( # type: ignore[index]
generations=[[ChatGeneration(text=args[0].text, message=args[0])]]
)
else:
return
def handle_event(
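As an aside, here is a minimal sketch of the mapping the shim above applies for legacy `on_llm_new_token` consumers. The helper name `_legacy_token_args` is hypothetical, and `chunk` is assumed to be a v1 `AIMessageChunk` received via the `chunk=` kwarg; the wrapping mirrors the `on_llm_new_token` branch above.

```python
from langchain_core.outputs import ChatGenerationChunk

def _legacy_token_args(chunk):
    # Legacy handlers expect (token: str, chunk: ChatGenerationChunk); wrapping
    # the v1 chunk keeps both views of the same data available.
    return chunk.text, ChatGenerationChunk(text=chunk.text, message=chunk)
```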
@ -695,7 +714,7 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
def on_llm_new_token(
self,
token: Union[str, AIMessageChunk],
token: str,
*,
chunk: Optional[
Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
@ -793,7 +812,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
async def on_llm_new_token(
self,
token: Union[str, AIMessageChunk],
token: str,
*,
chunk: Optional[
Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
@ -1380,7 +1399,7 @@ class CallbackManager(BaseCallbackManager):
def on_chat_model_start(
self,
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
messages: Union[list[list[BaseMessage]], list[MessageV1]],
run_id: Optional[UUID] = None,
**kwargs: Any,
) -> list[CallbackManagerForLLMRun]:
@ -1388,7 +1407,7 @@ class CallbackManager(BaseCallbackManager):
Args:
serialized (dict[str, Any]): The serialized LLM.
messages (list[list[BaseMessage]]): The list of messages.
messages (list[list[BaseMessage]] | list[MessageV1]): The list of messages.
run_id (UUID, optional): The ID of the run. Defaults to None.
**kwargs (Any): Additional keyword arguments.
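To make the two accepted shapes concrete, a short sketch (only the legacy constructor is shown, since v1 message constructors are outside this diff):

```python
from langchain_core.messages import HumanMessage

# Legacy callers: one inner list per prompt batch.
legacy_messages = [[HumanMessage(content="hi")]]  # list[list[BaseMessage]]

# v1 callers instead pass a single flat batch of MessageV1 objects
# (list[MessageV1]); _convert_llm_events in this module re-wraps such a batch
# into the legacy nested shape before handlers that only understand
# list[list[BaseMessage]] are invoked.
```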
@ -1890,7 +1909,7 @@ class AsyncCallbackManager(BaseCallbackManager):
async def on_chat_model_start(
self,
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
messages: Union[list[list[BaseMessage]], list[MessageV1]],
run_id: Optional[UUID] = None,
**kwargs: Any,
) -> list[AsyncCallbackManagerForLLMRun]:
@ -1898,7 +1917,7 @@ class AsyncCallbackManager(BaseCallbackManager):
Args:
serialized (dict[str, Any]): The serialized LLM.
messages (list[list[BaseMessage]]): The list of messages.
messages (list[list[BaseMessage]] | list[MessageV1]): The list of messages.
run_id (UUID, optional): The ID of the run. Defaults to None.
**kwargs (Any): Additional keyword arguments.

View File

@ -3,7 +3,7 @@
from __future__ import annotations
import sys
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Union
from typing_extensions import override
@ -12,6 +12,7 @@ from langchain_core.callbacks.base import BaseCallbackHandler
if TYPE_CHECKING:
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.messages import BaseMessage
from langchain_core.messages.v1 import AIMessage, MessageV1
from langchain_core.outputs import LLMResult
@ -32,7 +33,7 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
def on_chat_model_start(
self,
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
messages: Union[list[list[BaseMessage]], list[MessageV1]],
**kwargs: Any,
) -> None:
"""Run when LLM starts running.
@ -54,7 +55,7 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
sys.stdout.write(token)
sys.stdout.flush()
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
def on_llm_end(self, response: Union[LLMResult, AIMessage], **kwargs: Any) -> None:
"""Run when LLM ends running.
Args:

View File

@ -4,13 +4,15 @@ import threading
from collections.abc import Generator
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any, Optional
from typing import Any, Optional, Union
from typing_extensions import override
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import AIMessage
from langchain_core.messages.ai import UsageMetadata, add_usage
from langchain_core.messages.utils import convert_from_v1_message
from langchain_core.messages.v1 import AIMessage as AIMessageV1
from langchain_core.outputs import ChatGeneration, LLMResult
@ -58,9 +60,17 @@ class UsageMetadataCallbackHandler(BaseCallbackHandler):
return str(self.usage_metadata)
@override
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
def on_llm_end(
self, response: Union[LLMResult, AIMessageV1], **kwargs: Any
) -> None:
"""Collect token usage."""
# Check for usage_metadata (langchain-core >= 0.2.2)
if isinstance(response, AIMessageV1):
response = LLMResult(
generations=[
[ChatGeneration(message=convert_from_v1_message(response))]
]
)
try:
generation = response.generations[0][0]
except IndexError:
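For context, typical usage of this handler is unchanged by the widened signature; a sketch, assuming `llm` is any chat model instance and the import path documented for recent langchain-core releases:

```python
from langchain_core.callbacks import UsageMetadataCallbackHandler

callback = UsageMetadataCallbackHandler()
llm.invoke("hello", config={"callbacks": [callback]})
# Aggregated token counts keyed by model name, regardless of whether the run
# reported a legacy LLMResult or a v1 AIMessage.
print(callback.usage_metadata)
```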

View File

@ -4,6 +4,7 @@ from collections.abc import Sequence
from typing import Optional
from langchain_core.messages import BaseMessage
from langchain_core.messages.v1 import MessageV1
def _is_openai_data_block(block: dict) -> bool:
@ -129,10 +130,7 @@ def _normalize_messages(messages: Sequence[BaseMessage]) -> list[BaseMessage]:
and _is_openai_data_block(block)
):
if formatted_message is message:
if isinstance(message, BaseMessage):
formatted_message = message.model_copy()
else:
formatted_message = copy.copy(message)
formatted_message = message.model_copy()
# Also shallow-copy content
formatted_message.content = list(formatted_message.content)
@ -142,3 +140,37 @@ def _normalize_messages(messages: Sequence[BaseMessage]) -> list[BaseMessage]:
formatted_messages.append(formatted_message)
return formatted_messages
def _normalize_messages_v1(messages: Sequence[MessageV1]) -> list[MessageV1]:
"""Extend support for message formats.
Chat models implement support for images in OpenAI Chat Completions format, as well
as other multimodal data as standard data blocks. This function extends support to
audio and file data in OpenAI Chat Completions format by converting them to standard
data blocks.
"""
formatted_messages = []
for message in messages:
formatted_message = message
if isinstance(message.content, list):
for idx, block in enumerate(message.content):
if (
isinstance(block, dict)
# Subset to (PDF) files and audio, as most relevant chat models
# support images in OAI format (and some may not yet support the
# standard data block format)
and block.get("type") in {"file", "input_audio"}
and _is_openai_data_block(block) # type: ignore[arg-type]
):
if formatted_message is message:
formatted_message = copy.copy(message)
# Also shallow-copy content
formatted_message.content = list(formatted_message.content)
formatted_message.content[idx] = ( # type: ignore[call-overload]
_convert_openai_format_to_data_block(block) # type: ignore[arg-type]
)
formatted_messages.append(formatted_message)
return formatted_messages
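A rough illustration of the rewrite performed above; the block shapes are indicative only, since the exact keys are decided by `_is_openai_data_block` and `_convert_openai_format_to_data_block`:

```python
# OpenAI Chat Completions style audio block, as accepted on input:
openai_audio_block = {
    "type": "input_audio",
    "input_audio": {"data": "<base64>", "format": "wav"},
}
# After normalization the containing message is shallow-copied and the block is
# replaced in place by a standard data block, roughly:
# {"type": "audio", "source_type": "base64", "data": "<base64>", "mime_type": "audio/wav"}
```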

View File

@ -2,7 +2,8 @@
from __future__ import annotations
from abc import ABC
import warnings
from abc import ABC, abstractmethod
from collections.abc import Mapping, Sequence
from functools import cache
from typing import (
@ -25,6 +26,7 @@ from langchain_core.messages import (
AnyMessage,
BaseMessage,
MessageLikeRepresentation,
get_buffer_string,
)
from langchain_core.messages.v1 import AIMessage as AIMessageV1
from langchain_core.prompt_values import PromptValue
@ -164,6 +166,7 @@ class BaseLanguageModel(
list[AnyMessage],
]
@abstractmethod
def generate_prompt(
self,
prompts: list[PromptValue],
@ -198,6 +201,7 @@ class BaseLanguageModel(
prompt and additional model provider-specific output.
"""
@abstractmethod
async def agenerate_prompt(
self,
prompts: list[PromptValue],
@ -241,6 +245,7 @@ class BaseLanguageModel(
raise NotImplementedError
@deprecated("0.1.7", alternative="invoke", removal="1.0")
@abstractmethod
def predict(
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
) -> str:
@ -261,6 +266,7 @@ class BaseLanguageModel(
"""
@deprecated("0.1.7", alternative="invoke", removal="1.0")
@abstractmethod
def predict_messages(
self,
messages: list[BaseMessage],
@ -285,6 +291,7 @@ class BaseLanguageModel(
"""
@deprecated("0.1.7", alternative="ainvoke", removal="1.0")
@abstractmethod
async def apredict(
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
) -> str:
@ -305,6 +312,7 @@ class BaseLanguageModel(
"""
@deprecated("0.1.7", alternative="ainvoke", removal="1.0")
@abstractmethod
async def apredict_messages(
self,
messages: list[BaseMessage],
@ -360,6 +368,33 @@ class BaseLanguageModel(
"""
return len(self.get_token_ids(text))
def get_num_tokens_from_messages(
self,
messages: list[BaseMessage],
tools: Optional[Sequence] = None,
) -> int:
"""Get the number of tokens in the messages.
Useful for checking if an input fits in a model's context window.
**Note**: the base implementation of get_num_tokens_from_messages ignores
tool schemas.
Args:
messages: The message inputs to tokenize.
tools: If provided, sequence of dict, BaseModel, function, or BaseTools
to be converted to tool schemas.
Returns:
The sum of the number of tokens across the messages.
"""
if tools is not None:
warnings.warn(
"Counting tokens in tool schemas is not yet supported. Ignoring tools.",
stacklevel=2,
)
return sum(self.get_num_tokens(get_buffer_string([m])) for m in messages)
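A usage sketch, assuming `model` is any concrete BaseLanguageModel subclass; the default implementation flattens each message with `get_buffer_string` and tokenizes the result:

```python
from langchain_core.messages import AIMessage, HumanMessage

n = model.get_num_tokens_from_messages(
    [HumanMessage(content="What's the weather?"), AIMessage(content="Sunny.")]
)
# Tool schemas, if passed, are ignored by this base implementation (a warning
# is emitted), so `n` reflects message text only.
```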
@classmethod
def _all_required_field_names(cls) -> set:
"""DEPRECATED: Kept for backwards compatibility.

View File

@ -55,7 +55,6 @@ from langchain_core.messages import (
HumanMessage,
convert_to_messages,
convert_to_openai_image_block,
get_buffer_string,
is_data_content_block,
message_chunk_to_message,
)
@ -1352,33 +1351,6 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
starter_dict["_type"] = self._llm_type
return starter_dict
def get_num_tokens_from_messages(
self,
messages: list[BaseMessage],
tools: Optional[Sequence] = None,
) -> int:
"""Get the number of tokens in the messages.
Useful for checking if an input fits in a model's context window.
**Note**: the base implementation of get_num_tokens_from_messages ignores
tool schemas.
Args:
messages: The message inputs to tokenize.
tools: If provided, sequence of dict, BaseModel, function, or BaseTools
to be converted to tool schemas.
Returns:
The sum of the number of tokens across the messages.
"""
if tools is not None:
warnings.warn(
"Counting tokens in tool schemas is not yet supported. Ignoring tools.",
stacklevel=2,
)
return sum(self.get_num_tokens(get_buffer_string([m])) for m in messages)
def bind_tools(
self,
tools: Sequence[

View File

@ -6,7 +6,7 @@ import copy
import typing
import warnings
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator, Iterator, Sequence
from collections.abc import AsyncIterator, Iterator, Mapping, Sequence
from operator import itemgetter
from typing import (
TYPE_CHECKING,
@ -22,29 +22,32 @@ from pydantic import (
BaseModel,
ConfigDict,
Field,
field_validator,
)
from typing_extensions import override
from typing_extensions import TypeAlias, override
from langchain_core.caches import BaseCache
from langchain_core.callbacks import (
AsyncCallbackManager,
AsyncCallbackManagerForLLMRun,
CallbackManager,
CallbackManagerForLLMRun,
Callbacks,
)
from langchain_core.language_models._utils import _normalize_messages
from langchain_core.language_models._utils import _normalize_messages_v1
from langchain_core.language_models.base import (
BaseLanguageModel,
LangSmithParams,
LanguageModelInput,
_get_token_ids_default_method,
_get_verbosity,
)
from langchain_core.messages import (
AIMessage,
convert_to_openai_image_block,
get_buffer_string,
is_data_content_block,
)
from langchain_core.messages.utils import (
_convert_from_v1_message,
convert_from_v1_message,
convert_to_messages_v1,
)
from langchain_core.messages.v1 import AIMessage as AIMessageV1
@ -58,6 +61,7 @@ from langchain_core.outputs import (
from langchain_core.prompt_values import PromptValue
from langchain_core.rate_limiters import BaseRateLimiter
from langchain_core.runnables import RunnableMap, RunnablePassthrough
from langchain_core.runnables.base import RunnableSerializable
from langchain_core.runnables.config import ensure_config, run_in_executor
from langchain_core.tracers._streaming import _StreamingCallbackHandler
from langchain_core.utils.function_calling import (
@ -85,14 +89,15 @@ def _generate_response_from_error(error: BaseException) -> list[AIMessageV1]:
metadata["status_code"] = response.status_code
if hasattr(error, "request_id"):
metadata["request_id"] = error.request_id
generations = [AIMessageV1(content=[], response_metadata=metadata)]
# Permit response_metadata without model_name, model_provider fields
generations = [AIMessageV1(content=[], response_metadata=metadata)] # type: ignore[arg-type]
else:
generations = []
return generations
def _format_for_tracing(messages: list[MessageV1]) -> list[MessageV1]:
def _format_for_tracing(messages: Sequence[MessageV1]) -> list[MessageV1]:
"""Format messages for tracing in on_chat_model_start.
- Update image content blocks to OpenAI Chat Completions format (backward
@ -112,7 +117,7 @@ def _format_for_tracing(messages: list[MessageV1]) -> list[MessageV1]:
# Update image content blocks to OpenAI Chat Completions format.
if (
block["type"] == "image"
and is_data_content_block(block)
and is_data_content_block(block) # type: ignore[arg-type] # permit unnecessary runtime check
and block.get("source_type") != "id"
):
if message_to_trace is message:
@ -120,7 +125,9 @@ def _format_for_tracing(messages: list[MessageV1]) -> list[MessageV1]:
message_to_trace = copy.copy(message)
message_to_trace.content = list(message_to_trace.content)
message_to_trace.content[idx] = convert_to_openai_image_block(block)
# TODO: for tracing purposes we store non-standard types (OpenAI format)
# in message content. Consider typing these block formats.
message_to_trace.content[idx] = convert_to_openai_image_block(block) # type: ignore[arg-type, call-overload]
else:
pass
messages_to_trace.append(message_to_trace)
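For reference, an indicative sketch of the tracing rewrite above; the exact output keys are determined by `convert_to_openai_image_block`:

```python
# Standard image data block as it may appear in v1 message content:
standard_image_block = {
    "type": "image",
    "source_type": "base64",
    "data": "<b64>",
    "mime_type": "image/png",
}
# Traced form, roughly OpenAI Chat Completions style:
# {"type": "image_url", "image_url": {"url": "data:image/png;base64,<b64>"}}
```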
@ -180,7 +187,7 @@ def _format_ls_structured_output(ls_structured_output_format: Optional[dict]) ->
return ls_structured_output_format_dict
class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC):
class BaseChatModelV1(RunnableSerializable[LanguageModelInput, AIMessageV1], ABC):
"""Base class for chat models.
Key imperative methods:
@ -278,12 +285,69 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC):
does not properly support streaming.
"""
cache: Union[BaseCache, bool, None] = Field(default=None, exclude=True)
"""Whether to cache the response.
* If true, will use the global cache.
* If false, will not use a cache.
* If None, will use the global cache if it's set, otherwise no cache.
* If instance of BaseCache, will use the provided cache.
Caching is not currently supported for streaming methods of models.
"""
verbose: bool = Field(default_factory=_get_verbosity, exclude=True, repr=False)
"""Whether to print out response text."""
callbacks: Callbacks = Field(default=None, exclude=True)
"""Callbacks to add to the run trace."""
tags: Optional[list[str]] = Field(default=None, exclude=True)
"""Tags to add to the run trace."""
metadata: Optional[dict[str, Any]] = Field(default=None, exclude=True)
"""Metadata to add to the run trace."""
custom_get_token_ids: Optional[Callable[[str], list[int]]] = Field(
default=None, exclude=True
)
"""Optional encoder to use for counting tokens."""
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
# --- Runnable methods ---
@field_validator("verbose", mode="before")
def set_verbose(cls, verbose: Optional[bool]) -> bool: # noqa: FBT001
"""If verbose is None, set it.
This allows users to pass in None as verbose to access the global setting.
Args:
verbose: The verbosity setting to use.
Returns:
The verbosity setting to use.
"""
if verbose is None:
return _get_verbosity()
return verbose
@property
@override
def InputType(self) -> TypeAlias:
"""Get the input type for this runnable."""
from langchain_core.prompt_values import (
ChatPromptValueConcrete,
StringPromptValue,
)
# This is a version of LanguageModelInput which replaces the abstract
# base class BaseMessage with a union of its subclasses, which makes
# for a much better schema.
return Union[
str,
Union[StringPromptValue, ChatPromptValueConcrete],
list[MessageV1],
]
@property
@override
def OutputType(self) -> Any:
@ -299,7 +363,7 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC):
return convert_to_messages_v1(model_input)
msg = (
f"Invalid input type {type(model_input)}. "
"Must be a PromptValue, str, or list of BaseMessages."
"Must be a PromptValue, str, or list of Messages."
)
raise ValueError(msg)
@ -371,7 +435,7 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC):
)
(run_manager,) = callback_manager.on_chat_model_start(
{},
[_format_for_tracing(messages)],
_format_for_tracing(messages),
invocation_params=params,
options=options,
name=config.get("run_name"),
@ -382,27 +446,27 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC):
if self.rate_limiter:
self.rate_limiter.acquire(blocking=True)
input_messages = _normalize_messages(messages)
input_messages = _normalize_messages_v1(messages)
if self._should_stream(async_api=False, **kwargs):
chunks: list[AIMessageChunkV1] = []
try:
for msg in self._stream(input_messages, **kwargs):
run_manager.on_llm_new_token(msg)
run_manager.on_llm_new_token(msg.text or "")
chunks.append(msg)
except BaseException as e:
run_manager.on_llm_error(e, response=_generate_response_from_error(e))
raise
msg = add_ai_message_chunks(chunks[0], *chunks[1:])
full_message = add_ai_message_chunks(chunks[0], *chunks[1:]).to_message()
else:
try:
msg = self._invoke(input_messages, **kwargs)
full_message = self._invoke(input_messages, **kwargs)
except BaseException as e:
run_manager.on_llm_error(e)
raise
run_manager.on_llm_end(msg)
return msg
run_manager.on_llm_end(full_message)
return full_message
@override
async def ainvoke(
@ -410,7 +474,7 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC):
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> AIMessage:
) -> AIMessageV1:
config = ensure_config(config)
messages = self._convert_input(input)
ls_structured_output_format = kwargs.pop(
@ -437,7 +501,7 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC):
)
(run_manager,) = await callback_manager.on_chat_model_start(
{},
[_format_for_tracing(messages)],
_format_for_tracing(messages),
invocation_params=params,
options=options,
name=config.get("run_name"),
@ -449,31 +513,31 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC):
await self.rate_limiter.aacquire(blocking=True)
# TODO: type openai image, audio, file types and permit in MessageV1
input_messages = _normalize_messages(messages) # type: ignore[arg-type]
input_messages = _normalize_messages_v1(messages)
if self._should_stream(async_api=True, **kwargs):
chunks: list[AIMessageChunkV1] = []
try:
async for msg in self._astream(input_messages, **kwargs):
await run_manager.on_llm_new_token(msg)
await run_manager.on_llm_new_token(msg.text or "")
chunks.append(msg)
except BaseException as e:
await run_manager.on_llm_error(
e, response=_generate_response_from_error(e)
)
raise
msg = add_ai_message_chunks(chunks[0], *chunks[1:])
full_message = add_ai_message_chunks(chunks[0], *chunks[1:]).to_message()
else:
try:
msg = await self._ainvoke(input_messages, **kwargs)
full_message = await self._ainvoke(input_messages, **kwargs)
except BaseException as e:
await run_manager.on_llm_error(
e, response=_generate_response_from_error(e)
)
raise
await run_manager.on_llm_end(msg.to_message())
return msg
await run_manager.on_llm_end(full_message)
return full_message
@override
def stream(
@ -515,7 +579,7 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC):
)
(run_manager,) = callback_manager.on_chat_model_start(
{},
[_format_for_tracing(messages)],
_format_for_tracing(messages),
invocation_params=params,
options=options,
name=config.get("run_name"),
@ -530,9 +594,9 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC):
try:
# TODO: replace this with something for new messages
input_messages = _normalize_messages(messages)
input_messages = _normalize_messages_v1(messages)
for msg in self._stream(input_messages, **kwargs):
run_manager.on_llm_new_token(msg)
run_manager.on_llm_new_token(msg.text or "")
chunks.append(msg)
yield msg
except BaseException as e:
@ -584,7 +648,7 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC):
)
(run_manager,) = await callback_manager.on_chat_model_start(
{},
[_format_for_tracing(messages)],
_format_for_tracing(messages),
invocation_params=params,
options=options,
name=config.get("run_name"),
@ -598,12 +662,12 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC):
chunks: list[AIMessageChunkV1] = []
try:
input_messages = _normalize_messages(messages)
input_messages = _normalize_messages_v1(messages)
async for msg in self._astream(
input_messages,
**kwargs,
):
await run_manager.on_llm_new_token(msg)
await run_manager.on_llm_new_token(msg.text or "")
chunks.append(msg)
yield msg
except BaseException as e:
@ -623,7 +687,7 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC):
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> dict:
params = self.dict()
params = self.dump()
params["stop"] = stop
return {**params, **kwargs}
@ -674,14 +738,14 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC):
self,
messages: list[MessageV1],
**kwargs: Any,
) -> AIMessage:
) -> AIMessageV1:
raise NotImplementedError
async def _ainvoke(
self,
messages: list[MessageV1],
**kwargs: Any,
) -> AIMessage:
) -> AIMessageV1:
return await run_in_executor(
None,
self._invoke,
@ -724,8 +788,7 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC):
def _llm_type(self) -> str:
"""Return type of chat model."""
@override
def dict(self, **kwargs: Any) -> dict:
def dump(self, **kwargs: Any) -> dict: # noqa: ARG002
"""Return a dictionary of the LLM."""
starter_dict = dict(self._identifying_params)
starter_dict["_type"] = self._llm_type
@ -903,6 +966,38 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC):
return RunnableMap(raw=llm) | parser_with_fallback
return llm | output_parser
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""
return self.lc_attributes
def get_token_ids(self, text: str) -> list[int]:
"""Return the ordered ids of the tokens in a text.
Args:
text: The string input to tokenize.
Returns:
A list of ids corresponding to the tokens in the text, in order they occur
in the text.
"""
if self.custom_get_token_ids is not None:
return self.custom_get_token_ids(text)
return _get_token_ids_default_method(text)
def get_num_tokens(self, text: str) -> int:
"""Get the number of tokens present in the text.
Useful for checking if an input fits in a model's context window.
Args:
text: The string input to tokenize.
Returns:
The integer number of tokens in the text.
"""
return len(self.get_token_ids(text))
def get_num_tokens_from_messages(
self,
messages: list[MessageV1],
@ -923,7 +1018,7 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC):
Returns:
The sum of the number of tokens across the messages.
"""
messages_v0 = [_convert_from_v1_message(message) for message in messages]
messages_v0 = [convert_from_v1_message(message) for message in messages]
if tools is not None:
warnings.warn(
"Counting tokens in tool schemas is not yet supported. Ignoring tools.",

View File

@ -327,7 +327,6 @@ def _create_message_from_message_type_v1(
ValueError: if the message type is not one of "human", "user", "ai",
"assistant", "tool", "system", or "developer".
"""
kwargs: dict[str, Any] = {}
if name is not None:
kwargs["name"] = name
if tool_call_id is not None:
@ -355,7 +354,7 @@ def _create_message_from_message_type_v1(
else:
kwargs["tool_calls"].append(tool_call)
if message_type in {"human", "user"}:
message = HumanMessageV1(content=content, **kwargs)
message: MessageV1 = HumanMessageV1(content=content, **kwargs)
elif message_type in {"ai", "assistant"}:
message = AIMessageV1(content=content, **kwargs)
elif message_type in {"system", "developer"}:
@ -375,7 +374,7 @@ def _create_message_from_message_type_v1(
return message
def _convert_from_v1_message(message: MessageV1) -> BaseMessage:
def convert_from_v1_message(message: MessageV1) -> BaseMessage:
"""Compatibility layer to convert v1 messages to current messages.
Args:
@ -469,7 +468,7 @@ def _convert_to_message(message: MessageLikeRepresentation) -> BaseMessage:
msg_type, msg_content, **msg_kwargs
)
elif isinstance(message, MessageV1Types):
message_ = _convert_from_v1_message(message)
message_ = convert_from_v1_message(message)
else:
msg = f"Unsupported message type: {type(message)}"
msg = create_message(message=msg, error_code=ErrorCode.MESSAGE_COERCION_FAILURE)
@ -501,7 +500,7 @@ def _convert_to_message_v1(message: MessageLikeRepresentation) -> MessageV1:
"""
if isinstance(message, MessageV1Types):
if isinstance(message, AIMessageChunkV1):
message_ = message.to_message()
message_: MessageV1 = message.to_message()
else:
message_ = message
elif isinstance(message, str):

View File

@ -188,7 +188,7 @@ class AIMessage:
@dataclass
class AIMessageChunk:
class AIMessageChunk(AIMessage):
"""A partial chunk of an AI message during streaming.
Represents a portion of an AI response that is delivered incrementally
@ -203,43 +203,7 @@ class AIMessageChunk:
usage_metadata: Optional metadata about token usage and costs.
"""
type: Literal["ai_chunk"] = "ai_chunk"
name: Optional[str] = None
"""An optional name for the message.
This can be used to provide a human-readable name for the message.
Usage of this field is optional, and whether it's used or not is up to the
model implementation.
"""
id: Optional[str] = None
"""Unique identifier for the message chunk.
If the provider assigns a meaningful ID, it should be used here.
"""
lc_version: str = "v1"
"""Encoding version for the message."""
content: list[types.ContentBlock] = field(init=False)
usage_metadata: Optional[UsageMetadata] = None
"""If provided, usage metadata for a message chunk, such as token counts.
These data represent incremental usage statistics, as opposed to a running total.
"""
response_metadata: ResponseMetadata = field(init=False)
"""Metadata about the response chunk.
This field should include non-standard data returned by the provider, such as
response headers, service tiers, or log probabilities.
"""
parsed: Optional[Union[dict[str, Any], BaseModel]] = None
"""Auto-parsed message contents, if applicable."""
type: Literal["ai_chunk"] = "ai_chunk" # type: ignore[assignment]
tool_call_chunks: list[types.ToolCallChunk] = field(init=False)
@ -342,7 +306,7 @@ class AIMessageChunk:
if isinstance(args_, dict):
tool_calls.append(
create_tool_call(
name=chunk.get("name", ""),
name=chunk.get("name") or "",
args=args_,
id=chunk.get("id", ""),
)
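The distinction being fixed here is subtle enough to deserve a two-line illustration:

```python
chunk = {"name": None, "args": "{}", "id": "call_1"}
chunk.get("name", "")    # -> None: the key exists, so the default is not used
chunk.get("name") or ""  # -> "":   None is coerced to the empty string
```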

View File

@ -9,7 +9,7 @@ from typing import Annotated, Any, Optional
from pydantic import SkipValidation, ValidationError
from langchain_core.exceptions import OutputParserException
from langchain_core.messages import AIMessage, InvalidToolCall, ToolCall
from langchain_core.messages import AIMessage, InvalidToolCall
from langchain_core.messages.tool import invalid_tool_call
from langchain_core.messages.tool import tool_call as create_tool_call
from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser
@ -26,7 +26,7 @@ def parse_tool_call(
partial: bool = False,
strict: bool = False,
return_id: bool = True,
) -> Optional[ToolCall]:
) -> Optional[dict[str, Any]]:
"""Parse a single tool call.
Args:

View File

@ -16,6 +16,7 @@ from typing_extensions import override
from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
from langchain_core.exceptions import TracerException # noqa: F401
from langchain_core.messages.v1 import AIMessage, AIMessageChunk, MessageV1
from langchain_core.tracers.core import _TracerCore
if TYPE_CHECKING:
@ -54,7 +55,7 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
def on_chat_model_start(
self,
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
messages: Union[list[list[BaseMessage]], list[MessageV1]],
*,
run_id: UUID,
tags: Optional[list[str]] = None,
@ -138,7 +139,9 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
self,
token: str,
*,
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
chunk: Optional[
Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
] = None,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
@ -190,7 +193,9 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
)
@override
def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> Run:
def on_llm_end(
self, response: Union[LLMResult, AIMessage], *, run_id: UUID, **kwargs: Any
) -> Run:
"""End a trace for an LLM run.
Args:
@ -562,7 +567,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
async def on_chat_model_start(
self,
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
messages: Union[list[list[BaseMessage]], list[MessageV1]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
@ -617,7 +622,9 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
self,
token: str,
*,
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
chunk: Optional[
Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
] = None,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
@ -646,7 +653,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
@override
async def on_llm_end(
self,
response: LLMResult,
response: Union[LLMResult, AIMessage],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
@ -882,7 +889,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
self,
run: Run,
token: str,
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]],
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]],
) -> None:
"""Process new LLM token."""

View File

@ -18,6 +18,13 @@ from typing import (
from langchain_core.exceptions import TracerException
from langchain_core.load import dumpd
from langchain_core.messages.utils import convert_from_v1_message
from langchain_core.messages.v1 import (
AIMessage,
AIMessageChunk,
MessageV1,
MessageV1Types,
)
from langchain_core.outputs import (
ChatGeneration,
ChatGenerationChunk,
@ -156,7 +163,7 @@ class _TracerCore(ABC):
def _create_chat_model_run(
self,
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
messages: Union[list[list[BaseMessage]], list[MessageV1]],
run_id: UUID,
tags: Optional[list[str]] = None,
parent_run_id: Optional[UUID] = None,
@ -181,6 +188,12 @@ class _TracerCore(ABC):
start_time = datetime.now(timezone.utc)
if metadata:
kwargs.update({"metadata": metadata})
if isinstance(messages[0], MessageV1Types):
# Convert from v1 messages to BaseMessage
messages = [
[convert_from_v1_message(msg) for msg in messages] # type: ignore[arg-type]
]
messages = cast("list[list[BaseMessage]]", messages)
return Run(
id=run_id,
parent_run_id=parent_run_id,
@ -230,7 +243,9 @@ class _TracerCore(ABC):
self,
token: str,
run_id: UUID,
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
chunk: Optional[
Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
] = None,
parent_run_id: Optional[UUID] = None, # noqa: ARG002
) -> Run:
"""Append token event to LLM run and return the run."""
@ -276,7 +291,15 @@ class _TracerCore(ABC):
)
return llm_run
def _complete_llm_run(self, response: LLMResult, run_id: UUID) -> Run:
def _complete_llm_run(
self, response: Union[LLMResult, AIMessage], run_id: UUID
) -> Run:
if isinstance(response, AIMessage):
response = LLMResult(
generations=[
[ChatGeneration(message=convert_from_v1_message(response))]
]
)
llm_run = self._get_run(run_id, run_type={"llm", "chat_model"})
if getattr(llm_run, "outputs", None) is None:
llm_run.outputs = {}
@ -558,7 +581,7 @@ class _TracerCore(ABC):
self,
run: Run, # noqa: ARG002
token: str, # noqa: ARG002
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]], # noqa: ARG002
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]], # noqa: ARG002
) -> Union[None, Coroutine[Any, Any, None]]:
"""Process new LLM token."""
return None

View File

@ -19,6 +19,7 @@ from typing_extensions import NotRequired, TypedDict, override
from langchain_core.callbacks.base import AsyncCallbackHandler
from langchain_core.messages import AIMessageChunk, BaseMessage, BaseMessageChunk
from langchain_core.messages.v1 import MessageV1
from langchain_core.outputs import (
ChatGenerationChunk,
GenerationChunk,
@ -43,6 +44,8 @@ if TYPE_CHECKING:
from collections.abc import AsyncIterator, Iterator, Sequence
from langchain_core.documents import Document
from langchain_core.messages.v1 import AIMessage as AIMessageV1
from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_core.tracers.log_stream import LogEntry
@ -297,7 +300,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
async def on_chat_model_start(
self,
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
messages: Union[list[list[BaseMessage]], list[MessageV1]],
*,
run_id: UUID,
tags: Optional[list[str]] = None,
@ -307,6 +310,8 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
**kwargs: Any,
) -> None:
"""Start a trace for an LLM run."""
# The cast below is safe because the type is converted in handle_event
messages = cast("list[list[BaseMessage]]", messages)
name_ = _assign_name(name, serialized)
run_type = "chat_model"
@ -407,13 +412,18 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
self,
token: str,
*,
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
chunk: Optional[
Union[GenerationChunk, ChatGenerationChunk, AIMessageChunkV1]
] = None,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> None:
"""Run on new LLM token. Only available when streaming is enabled."""
run_info = self.run_map.get(run_id)
chunk = cast(
"Optional[Union[GenerationChunk, ChatGenerationChunk]]", chunk
) # converted in handle_event
chunk_: Union[GenerationChunk, BaseMessageChunk]
if run_info is None:
@ -456,9 +466,10 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
@override
async def on_llm_end(
self, response: LLMResult, *, run_id: UUID, **kwargs: Any
self, response: Union[LLMResult, AIMessageV1], *, run_id: UUID, **kwargs: Any
) -> None:
"""End a trace for an LLM run."""
response = cast("LLMResult", response) # converted in handle_event
run_info = self.run_map.pop(run_id)
inputs_ = run_info["inputs"]

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import logging
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from uuid import UUID
from langsmith import Client
@ -21,11 +21,14 @@ from typing_extensions import override
from langchain_core.env import get_runtime_environment
from langchain_core.load import dumpd
from langchain_core.messages.utils import convert_from_v1_message
from langchain_core.messages.v1 import MessageV1Types
from langchain_core.tracers.base import BaseTracer
from langchain_core.tracers.schemas import Run
if TYPE_CHECKING:
from langchain_core.messages import BaseMessage
from langchain_core.messages.v1 import AIMessageChunk, MessageV1
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
logger = logging.getLogger(__name__)
@ -113,7 +116,7 @@ class LangChainTracer(BaseTracer):
def on_chat_model_start(
self,
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
messages: Union[list[list[BaseMessage]], list[MessageV1]],
*,
run_id: UUID,
tags: Optional[list[str]] = None,
@ -140,6 +143,12 @@ class LangChainTracer(BaseTracer):
start_time = datetime.now(timezone.utc)
if metadata:
kwargs.update({"metadata": metadata})
if isinstance(messages[0], MessageV1Types):
# Convert from v1 messages to BaseMessage
messages = [
[convert_from_v1_message(msg) for msg in messages] # type: ignore[arg-type]
]
messages = cast("list[list[BaseMessage]]", messages)
chat_model_run = Run(
id=run_id,
parent_run_id=parent_run_id,
@ -232,7 +241,9 @@ class LangChainTracer(BaseTracer):
self,
token: str,
run_id: UUID,
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
chunk: Optional[
Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
] = None,
parent_run_id: Optional[UUID] = None,
) -> Run:
"""Append token event to LLM run and return the run."""

View File

@ -32,6 +32,7 @@ if TYPE_CHECKING:
from collections.abc import AsyncIterator, Iterator, Sequence
from uuid import UUID
from langchain_core.messages.v1 import AIMessageChunk
from langchain_core.runnables.utils import Input, Output
from langchain_core.tracers.schemas import Run
@ -485,7 +486,7 @@ class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler):
self,
run: Run,
token: str,
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]],
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]],
) -> None:
"""Process new LLM token."""
index = self._key_map_by_run_id.get(run.id)

View File

@ -10,6 +10,8 @@ from typing_extensions import override
from langchain_core.callbacks.base import AsyncCallbackHandler
from langchain_core.language_models import GenericFakeChatModel
from langchain_core.messages import AIMessage, BaseMessage
from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1
from langchain_core.messages.v1 import MessageV1
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
@ -18,7 +20,7 @@ class MyCustomAsyncHandler(AsyncCallbackHandler):
async def on_chat_model_start(
self,
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
messages: Union[list[list[BaseMessage]], list[MessageV1]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
@ -35,7 +37,9 @@ class MyCustomAsyncHandler(AsyncCallbackHandler):
self,
token: str,
*,
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
chunk: Optional[
Union[GenerationChunk, ChatGenerationChunk, AIMessageChunkV1]
] = None,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[list[str]] = None,

View File

@ -9,6 +9,7 @@ from typing_extensions import override
from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
from langchain_core.messages import BaseMessage
from langchain_core.messages.v1 import MessageV1
class BaseFakeCallbackHandler(BaseModel):
@ -285,7 +286,7 @@ class FakeCallbackHandlerWithChatStart(FakeCallbackHandler):
def on_chat_model_start(
self,
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
messages: Union[list[list[BaseMessage]], list[MessageV1]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,

View File

@ -165,7 +165,7 @@ async def test_callback_handlers() -> None:
async def on_chat_model_start(
self,
serialized: dict[str, Any],
messages: Union[list[list[BaseMessage]], list[list[MessageV1]]],
messages: Union[list[list[BaseMessage]], list[MessageV1]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,

View File

@ -3,7 +3,7 @@
from __future__ import annotations
from datetime import datetime, timezone
from typing import Any
from typing import TYPE_CHECKING, Any
from uuid import uuid4
import pytest
@ -16,6 +16,9 @@ from langchain_core.outputs import LLMResult
from langchain_core.tracers.base import AsyncBaseTracer
from langchain_core.tracers.schemas import Run
if TYPE_CHECKING:
from langchain_core.messages import BaseMessage
SERIALIZED = {"id": ["llm"]}
SERIALIZED_CHAT = {"id": ["chat_model"]}
@ -84,8 +87,9 @@ async def test_tracer_chat_model_run() -> None:
"""Test tracer on a Chat Model run."""
tracer = FakeAsyncTracer()
manager = AsyncCallbackManager(handlers=[tracer])
messages: list[list[BaseMessage]] = [[HumanMessage(content="")]]
run_managers = await manager.on_chat_model_start(
serialized=SERIALIZED_CHAT, messages=[[HumanMessage(content="")]]
serialized=SERIALIZED_CHAT, messages=messages
)
compare_run = Run(
id=str(run_managers[0].run_id), # type: ignore[arg-type]

View File

@ -3,7 +3,7 @@
from __future__ import annotations
from datetime import datetime, timezone
from typing import Any
from typing import TYPE_CHECKING, Any
from unittest.mock import MagicMock
from uuid import uuid4
@ -20,6 +20,9 @@ from langchain_core.runnables import chain as as_runnable
from langchain_core.tracers.base import BaseTracer
from langchain_core.tracers.schemas import Run
if TYPE_CHECKING:
from langchain_core.messages import BaseMessage
SERIALIZED = {"id": ["llm"]}
SERIALIZED_CHAT = {"id": ["chat_model"]}
@ -89,8 +92,10 @@ def test_tracer_chat_model_run() -> None:
"""Test tracer on a Chat Model run."""
tracer = FakeTracer()
manager = CallbackManager(handlers=[tracer])
# TODO: why is this annotation needed
messages: list[list[BaseMessage]] = [[HumanMessage(content="")]]
run_managers = manager.on_chat_model_start(
serialized=SERIALIZED_CHAT, messages=[[HumanMessage(content="")]]
serialized=SERIALIZED_CHAT, messages=messages
)
compare_run = Run(
id=str(run_managers[0].run_id), # type: ignore[arg-type]

View File

@ -5,6 +5,7 @@ from collections.abc import AsyncIterator
from typing import Any, Literal, Union, cast
from langchain_core.callbacks import AsyncCallbackHandler
from langchain_core.messages.v1 import AIMessage
from langchain_core.outputs import LLMResult
from typing_extensions import override
@ -44,7 +45,9 @@ class AsyncIteratorCallbackHandler(AsyncCallbackHandler):
self.queue.put_nowait(token)
@override
async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
async def on_llm_end(
self, response: Union[LLMResult, AIMessage], **kwargs: Any
) -> None:
self.done.set()
@override

View File

@ -1,7 +1,8 @@
from __future__ import annotations
from typing import Any, Optional
from typing import Any, Optional, Union
from langchain_core.messages.v1 import AIMessage
from langchain_core.outputs import LLMResult
from typing_extensions import override
@ -75,7 +76,9 @@ class AsyncFinalIteratorCallbackHandler(AsyncIteratorCallbackHandler):
self.answer_reached = False
@override
async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
async def on_llm_end(
self, response: Union[LLMResult, AIMessage], **kwargs: Any
) -> None:
if self.answer_reached:
self.done.set()

View File

@ -2,11 +2,12 @@
import threading
from collections.abc import Sequence
from typing import Any, Optional
from typing import Any, Optional, Union
from uuid import UUID
from langchain_core.callbacks import base as base_callbacks
from langchain_core.documents import Document
from langchain_core.messages.v1 import AIMessage
from langchain_core.outputs import LLMResult
from typing_extensions import override
@ -111,7 +112,7 @@ class ProgressBarCallback(base_callbacks.BaseCallbackHandler):
@override
def on_llm_end(
self,
response: LLMResult,
response: Union[LLMResult, AIMessage],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,

View File

@ -6,6 +6,7 @@ from uuid import UUID
from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
from langchain_core.messages import BaseMessage
from langchain_core.messages.v1 import MessageV1
from pydantic import BaseModel
from typing_extensions import override
@ -281,7 +282,7 @@ class FakeCallbackHandlerWithChatStart(FakeCallbackHandler):
def on_chat_model_start(
self,
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
messages: Union[list[list[BaseMessage]], list[MessageV1]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,

View File

@ -6,6 +6,8 @@ from uuid import UUID
from langchain_core.callbacks.base import AsyncCallbackHandler
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1
from langchain_core.messages.v1 import MessageV1
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
from typing_extensions import override
@ -155,7 +157,7 @@ async def test_callback_handlers() -> None:
async def on_chat_model_start(
self,
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
messages: Union[list[list[BaseMessage]], list[MessageV1]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
@ -172,7 +174,9 @@ async def test_callback_handlers() -> None:
self,
token: str,
*,
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
chunk: Optional[
Union[GenerationChunk, ChatGenerationChunk, AIMessageChunkV1]
] = None,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[list[str]] = None,