From 8cf97e838c708e3454fcc592271c1f35ed16e613 Mon Sep 17 00:00:00 2001 From: ccurme Date: Tue, 29 Jul 2025 16:38:45 -0300 Subject: [PATCH] fix(core): lint standard outputs branch (#32311) --- libs/core/langchain_core/callbacks/base.py | 8 +- libs/core/langchain_core/callbacks/manager.py | 53 ++++-- .../callbacks/streaming_stdout.py | 7 +- libs/core/langchain_core/callbacks/usage.py | 14 +- .../langchain_core/language_models/_utils.py | 40 +++- .../langchain_core/language_models/base.py | 37 +++- .../language_models/chat_models.py | 28 --- .../language_models/v1/chat_models.py | 173 ++++++++++++++---- libs/core/langchain_core/messages/utils.py | 9 +- libs/core/langchain_core/messages/v1.py | 42 +---- .../output_parsers/openai_tools.py | 4 +- libs/core/langchain_core/tracers/base.py | 21 ++- libs/core/langchain_core/tracers/core.py | 31 +++- .../langchain_core/tracers/event_stream.py | 17 +- libs/core/langchain_core/tracers/langchain.py | 17 +- .../core/langchain_core/tracers/log_stream.py | 3 +- .../tests/benchmarks/test_async_callbacks.py | 8 +- libs/core/tests/unit_tests/fake/callbacks.py | 3 +- .../unit_tests/fake/test_fake_chat_model.py | 2 +- .../tracers/test_async_base_tracer.py | 8 +- .../unit_tests/tracers/test_base_tracer.py | 9 +- .../langchain/callbacks/streaming_aiter.py | 5 +- .../callbacks/streaming_aiter_final_only.py | 7 +- .../langchain/smith/evaluation/progress.py | 5 +- .../callbacks/fake_callback_handler.py | 3 +- .../unit_tests/llms/test_fake_chat_model.py | 8 +- 26 files changed, 384 insertions(+), 178 deletions(-) diff --git a/libs/core/langchain_core/callbacks/base.py b/libs/core/langchain_core/callbacks/base.py index f9c6f75f307..3ca427f84fc 100644 --- a/libs/core/langchain_core/callbacks/base.py +++ b/libs/core/langchain_core/callbacks/base.py @@ -66,7 +66,7 @@ class LLMManagerMixin: def on_llm_new_token( self, - token: Union[str, AIMessageChunk], + token: str, *, chunk: Optional[ Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk] @@ -265,7 +265,7 @@ class CallbackManagerMixin: def on_chat_model_start( self, serialized: dict[str, Any], - messages: Union[list[list[BaseMessage]], list[list[MessageV1]]], + messages: Union[list[list[BaseMessage]], list[MessageV1]], *, run_id: UUID, parent_run_id: Optional[UUID] = None, @@ -516,7 +516,7 @@ class AsyncCallbackHandler(BaseCallbackHandler): async def on_chat_model_start( self, serialized: dict[str, Any], - messages: Union[list[list[BaseMessage]], list[list[MessageV1]]], + messages: Union[list[list[BaseMessage]], list[MessageV1]], *, run_id: UUID, parent_run_id: Optional[UUID] = None, @@ -545,7 +545,7 @@ class AsyncCallbackHandler(BaseCallbackHandler): async def on_llm_new_token( self, - token: Union[str, AIMessageChunk], + token: str, *, chunk: Optional[ Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk] diff --git a/libs/core/langchain_core/callbacks/manager.py b/libs/core/langchain_core/callbacks/manager.py index bca2d800696..dce9f088f76 100644 --- a/libs/core/langchain_core/callbacks/manager.py +++ b/libs/core/langchain_core/callbacks/manager.py @@ -11,7 +11,6 @@ from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor from contextlib import asynccontextmanager, contextmanager from contextvars import copy_context -from dataclasses import is_dataclass from typing import ( TYPE_CHECKING, Any, @@ -38,7 +37,13 @@ from langchain_core.callbacks.base import ( ) from langchain_core.callbacks.stdout import StdOutCallbackHandler from langchain_core.messages import BaseMessage, get_buffer_string -from langchain_core.messages.v1 import AIMessage, AIMessageChunk +from langchain_core.messages.utils import convert_from_v1_message +from langchain_core.messages.v1 import ( + AIMessage, + AIMessageChunk, + MessageV1, + MessageV1Types, +) from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, LLMResult from langchain_core.tracers.schemas import Run from langchain_core.utils.env import env_var_is_set @@ -249,17 +254,31 @@ def shielded(func: Func) -> Func: def _convert_llm_events( event_name: str, args: tuple[Any, ...], kwargs: dict[str, Any] ) -> None: - if event_name == "on_chat_model_start" and isinstance(args[1], list): - for idx, item in enumerate(args[1]): - if is_dataclass(item): - args[1][idx] = item # convert to old message - elif event_name == "on_llm_new_token" and is_dataclass(args[0]): - kwargs["chunk"] = ChatGenerationChunk(text=args[0].text, message=args[0]) - args[0] = args[0].text - elif event_name == "on_llm_end" and is_dataclass(args[0]): - args[0] = LLMResult( + if ( + event_name == "on_chat_model_start" + and isinstance(args[1], list) + and args[1] + and isinstance(args[1][0], MessageV1Types) + ): + batch = [ + convert_from_v1_message(item) + for item in args[1] + if isinstance(item, MessageV1Types) + ] + args[1] = [batch] # type: ignore[index] + elif ( + event_name == "on_llm_new_token" + and "chunk" in kwargs + and isinstance(kwargs["chunk"], MessageV1Types) + ): + chunk = kwargs["chunk"] + kwargs["chunk"] = ChatGenerationChunk(text=chunk.text, message=chunk) + elif event_name == "on_llm_end" and isinstance(args[0], MessageV1Types): + args[0] = LLMResult( # type: ignore[index] generations=[[ChatGeneration(text=args[0].text, message=args[0])]] ) + else: + return def handle_event( @@ -695,7 +714,7 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin): def on_llm_new_token( self, - token: Union[str, AIMessageChunk], + token: str, *, chunk: Optional[ Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk] @@ -793,7 +812,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin): async def on_llm_new_token( self, - token: Union[str, AIMessageChunk], + token: str, *, chunk: Optional[ Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk] @@ -1380,7 +1399,7 @@ class CallbackManager(BaseCallbackManager): def on_chat_model_start( self, serialized: dict[str, Any], - messages: list[list[BaseMessage]], + messages: Union[list[list[BaseMessage]], list[MessageV1]], run_id: Optional[UUID] = None, **kwargs: Any, ) -> list[CallbackManagerForLLMRun]: @@ -1388,7 +1407,7 @@ class CallbackManager(BaseCallbackManager): Args: serialized (dict[str, Any]): The serialized LLM. - messages (list[list[BaseMessage]]): The list of messages. + messages (list[list[BaseMessage | MessageV1]]): The list of messages. run_id (UUID, optional): The ID of the run. Defaults to None. **kwargs (Any): Additional keyword arguments. @@ -1890,7 +1909,7 @@ class AsyncCallbackManager(BaseCallbackManager): async def on_chat_model_start( self, serialized: dict[str, Any], - messages: list[list[BaseMessage]], + messages: Union[list[list[BaseMessage]], list[MessageV1]], run_id: Optional[UUID] = None, **kwargs: Any, ) -> list[AsyncCallbackManagerForLLMRun]: @@ -1898,7 +1917,7 @@ class AsyncCallbackManager(BaseCallbackManager): Args: serialized (dict[str, Any]): The serialized LLM. - messages (list[list[BaseMessage]]): The list of messages. + messages (list[list[BaseMessage | MessageV1]]): The list of messages. run_id (UUID, optional): The ID of the run. Defaults to None. **kwargs (Any): Additional keyword arguments. diff --git a/libs/core/langchain_core/callbacks/streaming_stdout.py b/libs/core/langchain_core/callbacks/streaming_stdout.py index f8dbe518eac..f14d2962047 100644 --- a/libs/core/langchain_core/callbacks/streaming_stdout.py +++ b/libs/core/langchain_core/callbacks/streaming_stdout.py @@ -3,7 +3,7 @@ from __future__ import annotations import sys -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Union from typing_extensions import override @@ -12,6 +12,7 @@ from langchain_core.callbacks.base import BaseCallbackHandler if TYPE_CHECKING: from langchain_core.agents import AgentAction, AgentFinish from langchain_core.messages import BaseMessage + from langchain_core.messages.v1 import AIMessage, MessageV1 from langchain_core.outputs import LLMResult @@ -32,7 +33,7 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler): def on_chat_model_start( self, serialized: dict[str, Any], - messages: list[list[BaseMessage]], + messages: Union[list[list[BaseMessage]], list[MessageV1]], **kwargs: Any, ) -> None: """Run when LLM starts running. @@ -54,7 +55,7 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler): sys.stdout.write(token) sys.stdout.flush() - def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + def on_llm_end(self, response: Union[LLMResult, AIMessage], **kwargs: Any) -> None: """Run when LLM ends running. Args: diff --git a/libs/core/langchain_core/callbacks/usage.py b/libs/core/langchain_core/callbacks/usage.py index 0249cadec1f..bc88d4b393b 100644 --- a/libs/core/langchain_core/callbacks/usage.py +++ b/libs/core/langchain_core/callbacks/usage.py @@ -4,13 +4,15 @@ import threading from collections.abc import Generator from contextlib import contextmanager from contextvars import ContextVar -from typing import Any, Optional +from typing import Any, Optional, Union from typing_extensions import override from langchain_core.callbacks import BaseCallbackHandler from langchain_core.messages import AIMessage from langchain_core.messages.ai import UsageMetadata, add_usage +from langchain_core.messages.utils import convert_from_v1_message +from langchain_core.messages.v1 import AIMessage as AIMessageV1 from langchain_core.outputs import ChatGeneration, LLMResult @@ -58,9 +60,17 @@ class UsageMetadataCallbackHandler(BaseCallbackHandler): return str(self.usage_metadata) @override - def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + def on_llm_end( + self, response: Union[LLMResult, AIMessageV1], **kwargs: Any + ) -> None: """Collect token usage.""" # Check for usage_metadata (langchain-core >= 0.2.2) + if isinstance(response, AIMessageV1): + response = LLMResult( + generations=[ + [ChatGeneration(message=convert_from_v1_message(response))] + ] + ) try: generation = response.generations[0][0] except IndexError: diff --git a/libs/core/langchain_core/language_models/_utils.py b/libs/core/langchain_core/language_models/_utils.py index cc5d39a18c3..6d4b6c8475e 100644 --- a/libs/core/langchain_core/language_models/_utils.py +++ b/libs/core/langchain_core/language_models/_utils.py @@ -4,6 +4,7 @@ from collections.abc import Sequence from typing import Optional from langchain_core.messages import BaseMessage +from langchain_core.messages.v1 import MessageV1 def _is_openai_data_block(block: dict) -> bool: @@ -129,10 +130,7 @@ def _normalize_messages(messages: Sequence[BaseMessage]) -> list[BaseMessage]: and _is_openai_data_block(block) ): if formatted_message is message: - if isinstance(message, BaseMessage): - formatted_message = message.model_copy() - else: - formatted_message = copy.copy(message) + formatted_message = message.model_copy() # Also shallow-copy content formatted_message.content = list(formatted_message.content) @@ -142,3 +140,37 @@ def _normalize_messages(messages: Sequence[BaseMessage]) -> list[BaseMessage]: formatted_messages.append(formatted_message) return formatted_messages + + +def _normalize_messages_v1(messages: Sequence[MessageV1]) -> list[MessageV1]: + """Extend support for message formats. + + Chat models implement support for images in OpenAI Chat Completions format, as well + as other multimodal data as standard data blocks. This function extends support to + audio and file data in OpenAI Chat Completions format by converting them to standard + data blocks. + """ + formatted_messages = [] + for message in messages: + formatted_message = message + if isinstance(message.content, list): + for idx, block in enumerate(message.content): + if ( + isinstance(block, dict) + # Subset to (PDF) files and audio, as most relevant chat models + # support images in OAI format (and some may not yet support the + # standard data block format) + and block.get("type") in {"file", "input_audio"} + and _is_openai_data_block(block) # type: ignore[arg-type] + ): + if formatted_message is message: + formatted_message = copy.copy(message) + # Also shallow-copy content + formatted_message.content = list(formatted_message.content) + + formatted_message.content[idx] = ( # type: ignore[call-overload] + _convert_openai_format_to_data_block(block) # type: ignore[arg-type] + ) + formatted_messages.append(formatted_message) + + return formatted_messages diff --git a/libs/core/langchain_core/language_models/base.py b/libs/core/langchain_core/language_models/base.py index 8df8aeda0b8..1fef01a6859 100644 --- a/libs/core/langchain_core/language_models/base.py +++ b/libs/core/langchain_core/language_models/base.py @@ -2,7 +2,8 @@ from __future__ import annotations -from abc import ABC +import warnings +from abc import ABC, abstractmethod from collections.abc import Mapping, Sequence from functools import cache from typing import ( @@ -25,6 +26,7 @@ from langchain_core.messages import ( AnyMessage, BaseMessage, MessageLikeRepresentation, + get_buffer_string, ) from langchain_core.messages.v1 import AIMessage as AIMessageV1 from langchain_core.prompt_values import PromptValue @@ -164,6 +166,7 @@ class BaseLanguageModel( list[AnyMessage], ] + @abstractmethod def generate_prompt( self, prompts: list[PromptValue], @@ -198,6 +201,7 @@ class BaseLanguageModel( prompt and additional model provider-specific output. """ + @abstractmethod async def agenerate_prompt( self, prompts: list[PromptValue], @@ -241,6 +245,7 @@ class BaseLanguageModel( raise NotImplementedError @deprecated("0.1.7", alternative="invoke", removal="1.0") + @abstractmethod def predict( self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any ) -> str: @@ -261,6 +266,7 @@ class BaseLanguageModel( """ @deprecated("0.1.7", alternative="invoke", removal="1.0") + @abstractmethod def predict_messages( self, messages: list[BaseMessage], @@ -285,6 +291,7 @@ class BaseLanguageModel( """ @deprecated("0.1.7", alternative="ainvoke", removal="1.0") + @abstractmethod async def apredict( self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any ) -> str: @@ -305,6 +312,7 @@ class BaseLanguageModel( """ @deprecated("0.1.7", alternative="ainvoke", removal="1.0") + @abstractmethod async def apredict_messages( self, messages: list[BaseMessage], @@ -360,6 +368,33 @@ class BaseLanguageModel( """ return len(self.get_token_ids(text)) + def get_num_tokens_from_messages( + self, + messages: list[BaseMessage], + tools: Optional[Sequence] = None, + ) -> int: + """Get the number of tokens in the messages. + + Useful for checking if an input fits in a model's context window. + + **Note**: the base implementation of get_num_tokens_from_messages ignores + tool schemas. + + Args: + messages: The message inputs to tokenize. + tools: If provided, sequence of dict, BaseModel, function, or BaseTools + to be converted to tool schemas. + + Returns: + The sum of the number of tokens across the messages. + """ + if tools is not None: + warnings.warn( + "Counting tokens in tool schemas is not yet supported. Ignoring tools.", + stacklevel=2, + ) + return sum(self.get_num_tokens(get_buffer_string([m])) for m in messages) + @classmethod def _all_required_field_names(cls) -> set: """DEPRECATED: Kept for backwards compatibility. diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py index 0c631e8011a..f1a7c4726a9 100644 --- a/libs/core/langchain_core/language_models/chat_models.py +++ b/libs/core/langchain_core/language_models/chat_models.py @@ -55,7 +55,6 @@ from langchain_core.messages import ( HumanMessage, convert_to_messages, convert_to_openai_image_block, - get_buffer_string, is_data_content_block, message_chunk_to_message, ) @@ -1352,33 +1351,6 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): starter_dict["_type"] = self._llm_type return starter_dict - def get_num_tokens_from_messages( - self, - messages: list[BaseMessage], - tools: Optional[Sequence] = None, - ) -> int: - """Get the number of tokens in the messages. - - Useful for checking if an input fits in a model's context window. - - **Note**: the base implementation of get_num_tokens_from_messages ignores - tool schemas. - - Args: - messages: The message inputs to tokenize. - tools: If provided, sequence of dict, BaseModel, function, or BaseTools - to be converted to tool schemas. - - Returns: - The sum of the number of tokens across the messages. - """ - if tools is not None: - warnings.warn( - "Counting tokens in tool schemas is not yet supported. Ignoring tools.", - stacklevel=2, - ) - return sum(self.get_num_tokens(get_buffer_string([m])) for m in messages) - def bind_tools( self, tools: Sequence[ diff --git a/libs/core/langchain_core/language_models/v1/chat_models.py b/libs/core/langchain_core/language_models/v1/chat_models.py index bc39787c682..55e59314266 100644 --- a/libs/core/langchain_core/language_models/v1/chat_models.py +++ b/libs/core/langchain_core/language_models/v1/chat_models.py @@ -6,7 +6,7 @@ import copy import typing import warnings from abc import ABC, abstractmethod -from collections.abc import AsyncIterator, Iterator, Sequence +from collections.abc import AsyncIterator, Iterator, Mapping, Sequence from operator import itemgetter from typing import ( TYPE_CHECKING, @@ -22,29 +22,32 @@ from pydantic import ( BaseModel, ConfigDict, Field, + field_validator, ) -from typing_extensions import override +from typing_extensions import TypeAlias, override +from langchain_core.caches import BaseCache from langchain_core.callbacks import ( AsyncCallbackManager, AsyncCallbackManagerForLLMRun, CallbackManager, CallbackManagerForLLMRun, + Callbacks, ) -from langchain_core.language_models._utils import _normalize_messages +from langchain_core.language_models._utils import _normalize_messages_v1 from langchain_core.language_models.base import ( - BaseLanguageModel, LangSmithParams, LanguageModelInput, + _get_token_ids_default_method, + _get_verbosity, ) from langchain_core.messages import ( - AIMessage, convert_to_openai_image_block, get_buffer_string, is_data_content_block, ) from langchain_core.messages.utils import ( - _convert_from_v1_message, + convert_from_v1_message, convert_to_messages_v1, ) from langchain_core.messages.v1 import AIMessage as AIMessageV1 @@ -58,6 +61,7 @@ from langchain_core.outputs import ( from langchain_core.prompt_values import PromptValue from langchain_core.rate_limiters import BaseRateLimiter from langchain_core.runnables import RunnableMap, RunnablePassthrough +from langchain_core.runnables.base import RunnableSerializable from langchain_core.runnables.config import ensure_config, run_in_executor from langchain_core.tracers._streaming import _StreamingCallbackHandler from langchain_core.utils.function_calling import ( @@ -85,14 +89,15 @@ def _generate_response_from_error(error: BaseException) -> list[AIMessageV1]: metadata["status_code"] = response.status_code if hasattr(error, "request_id"): metadata["request_id"] = error.request_id - generations = [AIMessageV1(content=[], response_metadata=metadata)] + # Permit response_metadata without model_name, model_provider fields + generations = [AIMessageV1(content=[], response_metadata=metadata)] # type: ignore[arg-type] else: generations = [] return generations -def _format_for_tracing(messages: list[MessageV1]) -> list[MessageV1]: +def _format_for_tracing(messages: Sequence[MessageV1]) -> list[MessageV1]: """Format messages for tracing in on_chat_model_start. - Update image content blocks to OpenAI Chat Completions format (backward @@ -112,7 +117,7 @@ def _format_for_tracing(messages: list[MessageV1]) -> list[MessageV1]: # Update image content blocks to OpenAI # Chat Completions format. if ( block["type"] == "image" - and is_data_content_block(block) + and is_data_content_block(block) # type: ignore[arg-type] # permit unnecessary runtime check and block.get("source_type") != "id" ): if message_to_trace is message: @@ -120,7 +125,9 @@ def _format_for_tracing(messages: list[MessageV1]) -> list[MessageV1]: message_to_trace = copy.copy(message) message_to_trace.content = list(message_to_trace.content) - message_to_trace.content[idx] = convert_to_openai_image_block(block) + # TODO: for tracing purposes we store non-standard types (OpenAI format) + # in message content. Consider typing these block formats. + message_to_trace.content[idx] = convert_to_openai_image_block(block) # type: ignore[arg-type, call-overload] else: pass messages_to_trace.append(message_to_trace) @@ -180,7 +187,7 @@ def _format_ls_structured_output(ls_structured_output_format: Optional[dict]) -> return ls_structured_output_format_dict -class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC): +class BaseChatModelV1(RunnableSerializable[LanguageModelInput, AIMessageV1], ABC): """Base class for chat models. Key imperative methods: @@ -278,12 +285,69 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC): does not properly support streaming. """ + cache: Union[BaseCache, bool, None] = Field(default=None, exclude=True) + """Whether to cache the response. + + * If true, will use the global cache. + * If false, will not use a cache + * If None, will use the global cache if it's set, otherwise no cache. + * If instance of BaseCache, will use the provided cache. + + Caching is not currently supported for streaming methods of models. + """ + verbose: bool = Field(default_factory=_get_verbosity, exclude=True, repr=False) + """Whether to print out response text.""" + callbacks: Callbacks = Field(default=None, exclude=True) + """Callbacks to add to the run trace.""" + tags: Optional[list[str]] = Field(default=None, exclude=True) + """Tags to add to the run trace.""" + metadata: Optional[dict[str, Any]] = Field(default=None, exclude=True) + """Metadata to add to the run trace.""" + custom_get_token_ids: Optional[Callable[[str], list[int]]] = Field( + default=None, exclude=True + ) + """Optional encoder to use for counting tokens.""" + model_config = ConfigDict( arbitrary_types_allowed=True, ) # --- Runnable methods --- + @field_validator("verbose", mode="before") + def set_verbose(cls, verbose: Optional[bool]) -> bool: # noqa: FBT001 + """If verbose is None, set it. + + This allows users to pass in None as verbose to access the global setting. + + Args: + verbose: The verbosity setting to use. + + Returns: + The verbosity setting to use. + """ + if verbose is None: + return _get_verbosity() + return verbose + + @property + @override + def InputType(self) -> TypeAlias: + """Get the input type for this runnable.""" + from langchain_core.prompt_values import ( + ChatPromptValueConcrete, + StringPromptValue, + ) + + # This is a version of LanguageModelInput which replaces the abstract + # base class BaseMessage with a union of its subclasses, which makes + # for a much better schema. + return Union[ + str, + Union[StringPromptValue, ChatPromptValueConcrete], + list[MessageV1], + ] + @property @override def OutputType(self) -> Any: @@ -299,7 +363,7 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC): return convert_to_messages_v1(model_input) msg = ( f"Invalid input type {type(model_input)}. " - "Must be a PromptValue, str, or list of BaseMessages." + "Must be a PromptValue, str, or list of Messages." ) raise ValueError(msg) @@ -371,7 +435,7 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC): ) (run_manager,) = callback_manager.on_chat_model_start( {}, - [_format_for_tracing(messages)], + _format_for_tracing(messages), invocation_params=params, options=options, name=config.get("run_name"), @@ -382,27 +446,27 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC): if self.rate_limiter: self.rate_limiter.acquire(blocking=True) - input_messages = _normalize_messages(messages) + input_messages = _normalize_messages_v1(messages) if self._should_stream(async_api=False, **kwargs): chunks: list[AIMessageChunkV1] = [] try: for msg in self._stream(input_messages, **kwargs): - run_manager.on_llm_new_token(msg) + run_manager.on_llm_new_token(msg.text or "") chunks.append(msg) except BaseException as e: run_manager.on_llm_error(e, response=_generate_response_from_error(e)) raise - msg = add_ai_message_chunks(chunks[0], *chunks[1:]) + full_message = add_ai_message_chunks(chunks[0], *chunks[1:]).to_message() else: try: - msg = self._invoke(input_messages, **kwargs) + full_message = self._invoke(input_messages, **kwargs) except BaseException as e: run_manager.on_llm_error(e) raise - run_manager.on_llm_end(msg) - return msg + run_manager.on_llm_end(full_message) + return full_message @override async def ainvoke( @@ -410,7 +474,7 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC): input: LanguageModelInput, config: Optional[RunnableConfig] = None, **kwargs: Any, - ) -> AIMessage: + ) -> AIMessageV1: config = ensure_config(config) messages = self._convert_input(input) ls_structured_output_format = kwargs.pop( @@ -437,7 +501,7 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC): ) (run_manager,) = await callback_manager.on_chat_model_start( {}, - [_format_for_tracing(messages)], + _format_for_tracing(messages), invocation_params=params, options=options, name=config.get("run_name"), @@ -449,31 +513,31 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC): await self.rate_limiter.aacquire(blocking=True) # TODO: type openai image, audio, file types and permit in MessageV1 - input_messages = _normalize_messages(messages) # type: ignore[arg-type] + input_messages = _normalize_messages_v1(messages) if self._should_stream(async_api=True, **kwargs): chunks: list[AIMessageChunkV1] = [] try: async for msg in self._astream(input_messages, **kwargs): - await run_manager.on_llm_new_token(msg) + await run_manager.on_llm_new_token(msg.text or "") chunks.append(msg) except BaseException as e: await run_manager.on_llm_error( e, response=_generate_response_from_error(e) ) raise - msg = add_ai_message_chunks(chunks[0], *chunks[1:]) + full_message = add_ai_message_chunks(chunks[0], *chunks[1:]).to_message() else: try: - msg = await self._ainvoke(input_messages, **kwargs) + full_message = await self._ainvoke(input_messages, **kwargs) except BaseException as e: await run_manager.on_llm_error( e, response=_generate_response_from_error(e) ) raise - await run_manager.on_llm_end(msg.to_message()) - return msg + await run_manager.on_llm_end(full_message) + return full_message @override def stream( @@ -515,7 +579,7 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC): ) (run_manager,) = callback_manager.on_chat_model_start( {}, - [_format_for_tracing(messages)], + _format_for_tracing(messages), invocation_params=params, options=options, name=config.get("run_name"), @@ -530,9 +594,9 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC): try: # TODO: replace this with something for new messages - input_messages = _normalize_messages(messages) + input_messages = _normalize_messages_v1(messages) for msg in self._stream(input_messages, **kwargs): - run_manager.on_llm_new_token(msg) + run_manager.on_llm_new_token(msg.text or "") chunks.append(msg) yield msg except BaseException as e: @@ -584,7 +648,7 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC): ) (run_manager,) = await callback_manager.on_chat_model_start( {}, - [_format_for_tracing(messages)], + _format_for_tracing(messages), invocation_params=params, options=options, name=config.get("run_name"), @@ -598,12 +662,12 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC): chunks: list[AIMessageChunkV1] = [] try: - input_messages = _normalize_messages(messages) + input_messages = _normalize_messages_v1(messages) async for msg in self._astream( input_messages, **kwargs, ): - await run_manager.on_llm_new_token(msg) + await run_manager.on_llm_new_token(msg.text or "") chunks.append(msg) yield msg except BaseException as e: @@ -623,7 +687,7 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC): stop: Optional[list[str]] = None, **kwargs: Any, ) -> dict: - params = self.dict() + params = self.dump() params["stop"] = stop return {**params, **kwargs} @@ -674,14 +738,14 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC): self, messages: list[MessageV1], **kwargs: Any, - ) -> AIMessage: + ) -> AIMessageV1: raise NotImplementedError async def _ainvoke( self, messages: list[MessageV1], **kwargs: Any, - ) -> AIMessage: + ) -> AIMessageV1: return await run_in_executor( None, self._invoke, @@ -724,8 +788,7 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC): def _llm_type(self) -> str: """Return type of chat model.""" - @override - def dict(self, **kwargs: Any) -> dict: + def dump(self, **kwargs: Any) -> dict: # noqa: ARG002 """Return a dictionary of the LLM.""" starter_dict = dict(self._identifying_params) starter_dict["_type"] = self._llm_type @@ -903,6 +966,38 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC): return RunnableMap(raw=llm) | parser_with_fallback return llm | output_parser + @property + def _identifying_params(self) -> Mapping[str, Any]: + """Get the identifying parameters.""" + return self.lc_attributes + + def get_token_ids(self, text: str) -> list[int]: + """Return the ordered ids of the tokens in a text. + + Args: + text: The string input to tokenize. + + Returns: + A list of ids corresponding to the tokens in the text, in order they occur + in the text. + """ + if self.custom_get_token_ids is not None: + return self.custom_get_token_ids(text) + return _get_token_ids_default_method(text) + + def get_num_tokens(self, text: str) -> int: + """Get the number of tokens present in the text. + + Useful for checking if an input fits in a model's context window. + + Args: + text: The string input to tokenize. + + Returns: + The integer number of tokens in the text. + """ + return len(self.get_token_ids(text)) + def get_num_tokens_from_messages( self, messages: list[MessageV1], @@ -923,7 +1018,7 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC): Returns: The sum of the number of tokens across the messages. """ - messages_v0 = [_convert_from_v1_message(message) for message in messages] + messages_v0 = [convert_from_v1_message(message) for message in messages] if tools is not None: warnings.warn( "Counting tokens in tool schemas is not yet supported. Ignoring tools.", diff --git a/libs/core/langchain_core/messages/utils.py b/libs/core/langchain_core/messages/utils.py index 7df908d3142..affb0b15a53 100644 --- a/libs/core/langchain_core/messages/utils.py +++ b/libs/core/langchain_core/messages/utils.py @@ -327,7 +327,6 @@ def _create_message_from_message_type_v1( ValueError: if the message type is not one of "human", "user", "ai", "assistant", "tool", "system", or "developer". """ - kwargs: dict[str, Any] = {} if name is not None: kwargs["name"] = name if tool_call_id is not None: @@ -355,7 +354,7 @@ def _create_message_from_message_type_v1( else: kwargs["tool_calls"].append(tool_call) if message_type in {"human", "user"}: - message = HumanMessageV1(content=content, **kwargs) + message: MessageV1 = HumanMessageV1(content=content, **kwargs) elif message_type in {"ai", "assistant"}: message = AIMessageV1(content=content, **kwargs) elif message_type in {"system", "developer"}: @@ -375,7 +374,7 @@ def _create_message_from_message_type_v1( return message -def _convert_from_v1_message(message: MessageV1) -> BaseMessage: +def convert_from_v1_message(message: MessageV1) -> BaseMessage: """Compatibility layer to convert v1 messages to current messages. Args: @@ -469,7 +468,7 @@ def _convert_to_message(message: MessageLikeRepresentation) -> BaseMessage: msg_type, msg_content, **msg_kwargs ) elif isinstance(message, MessageV1Types): - message_ = _convert_from_v1_message(message) + message_ = convert_from_v1_message(message) else: msg = f"Unsupported message type: {type(message)}" msg = create_message(message=msg, error_code=ErrorCode.MESSAGE_COERCION_FAILURE) @@ -501,7 +500,7 @@ def _convert_to_message_v1(message: MessageLikeRepresentation) -> MessageV1: """ if isinstance(message, MessageV1Types): if isinstance(message, AIMessageChunkV1): - message_ = message.to_message() + message_: MessageV1 = message.to_message() else: message_ = message elif isinstance(message, str): diff --git a/libs/core/langchain_core/messages/v1.py b/libs/core/langchain_core/messages/v1.py index 4726becc1e6..313fd12ae7b 100644 --- a/libs/core/langchain_core/messages/v1.py +++ b/libs/core/langchain_core/messages/v1.py @@ -188,7 +188,7 @@ class AIMessage: @dataclass -class AIMessageChunk: +class AIMessageChunk(AIMessage): """A partial chunk of an AI message during streaming. Represents a portion of an AI response that is delivered incrementally @@ -203,43 +203,7 @@ class AIMessageChunk: usage_metadata: Optional metadata about token usage and costs. """ - type: Literal["ai_chunk"] = "ai_chunk" - - name: Optional[str] = None - """An optional name for the message. - - This can be used to provide a human-readable name for the message. - - Usage of this field is optional, and whether it's used or not is up to the - model implementation. - """ - - id: Optional[str] = None - """Unique identifier for the message chunk. - - If the provider assigns a meaningful ID, it should be used here. - """ - - lc_version: str = "v1" - """Encoding version for the message.""" - - content: list[types.ContentBlock] = field(init=False) - - usage_metadata: Optional[UsageMetadata] = None - """If provided, usage metadata for a message chunk, such as token counts. - - These data represent incremental usage statistics, as opposed to a running total. - """ - - response_metadata: ResponseMetadata = field(init=False) - """Metadata about the response chunk. - - This field should include non-standard data returned by the provider, such as - response headers, service tiers, or log probabilities. - """ - - parsed: Optional[Union[dict[str, Any], BaseModel]] = None - """Auto-parsed message contents, if applicable.""" + type: Literal["ai_chunk"] = "ai_chunk" # type: ignore[assignment] tool_call_chunks: list[types.ToolCallChunk] = field(init=False) @@ -342,7 +306,7 @@ class AIMessageChunk: if isinstance(args_, dict): tool_calls.append( create_tool_call( - name=chunk.get("name", ""), + name=chunk.get("name") or "", args=args_, id=chunk.get("id", ""), ) diff --git a/libs/core/langchain_core/output_parsers/openai_tools.py b/libs/core/langchain_core/output_parsers/openai_tools.py index e01f919606e..63495bc2d84 100644 --- a/libs/core/langchain_core/output_parsers/openai_tools.py +++ b/libs/core/langchain_core/output_parsers/openai_tools.py @@ -9,7 +9,7 @@ from typing import Annotated, Any, Optional from pydantic import SkipValidation, ValidationError from langchain_core.exceptions import OutputParserException -from langchain_core.messages import AIMessage, InvalidToolCall, ToolCall +from langchain_core.messages import AIMessage, InvalidToolCall from langchain_core.messages.tool import invalid_tool_call from langchain_core.messages.tool import tool_call as create_tool_call from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser @@ -26,7 +26,7 @@ def parse_tool_call( partial: bool = False, strict: bool = False, return_id: bool = True, -) -> Optional[ToolCall]: +) -> Optional[dict[str, Any]]: """Parse a single tool call. Args: diff --git a/libs/core/langchain_core/tracers/base.py b/libs/core/langchain_core/tracers/base.py index ee588606165..13fbe0e90ce 100644 --- a/libs/core/langchain_core/tracers/base.py +++ b/libs/core/langchain_core/tracers/base.py @@ -16,6 +16,7 @@ from typing_extensions import override from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler from langchain_core.exceptions import TracerException # noqa: F401 +from langchain_core.messages.v1 import AIMessage, AIMessageChunk, MessageV1 from langchain_core.tracers.core import _TracerCore if TYPE_CHECKING: @@ -54,7 +55,7 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC): def on_chat_model_start( self, serialized: dict[str, Any], - messages: list[list[BaseMessage]], + messages: Union[list[list[BaseMessage]], list[MessageV1]], *, run_id: UUID, tags: Optional[list[str]] = None, @@ -138,7 +139,9 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC): self, token: str, *, - chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, + chunk: Optional[ + Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk] + ] = None, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any, @@ -190,7 +193,9 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC): ) @override - def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> Run: + def on_llm_end( + self, response: Union[LLMResult, AIMessage], *, run_id: UUID, **kwargs: Any + ) -> Run: """End a trace for an LLM run. Args: @@ -562,7 +567,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC): async def on_chat_model_start( self, serialized: dict[str, Any], - messages: list[list[BaseMessage]], + messages: Union[list[list[BaseMessage]], list[MessageV1]], *, run_id: UUID, parent_run_id: Optional[UUID] = None, @@ -617,7 +622,9 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC): self, token: str, *, - chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, + chunk: Optional[ + Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk] + ] = None, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any, @@ -646,7 +653,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC): @override async def on_llm_end( self, - response: LLMResult, + response: Union[LLMResult, AIMessage], *, run_id: UUID, parent_run_id: Optional[UUID] = None, @@ -882,7 +889,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC): self, run: Run, token: str, - chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]], + chunk: Optional[Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]], ) -> None: """Process new LLM token.""" diff --git a/libs/core/langchain_core/tracers/core.py b/libs/core/langchain_core/tracers/core.py index 0a10b06ecaf..cdef8ddcb7b 100644 --- a/libs/core/langchain_core/tracers/core.py +++ b/libs/core/langchain_core/tracers/core.py @@ -18,6 +18,13 @@ from typing import ( from langchain_core.exceptions import TracerException from langchain_core.load import dumpd +from langchain_core.messages.utils import convert_from_v1_message +from langchain_core.messages.v1 import ( + AIMessage, + AIMessageChunk, + MessageV1, + MessageV1Types, +) from langchain_core.outputs import ( ChatGeneration, ChatGenerationChunk, @@ -156,7 +163,7 @@ class _TracerCore(ABC): def _create_chat_model_run( self, serialized: dict[str, Any], - messages: list[list[BaseMessage]], + messages: Union[list[list[BaseMessage]], list[MessageV1]], run_id: UUID, tags: Optional[list[str]] = None, parent_run_id: Optional[UUID] = None, @@ -181,6 +188,12 @@ class _TracerCore(ABC): start_time = datetime.now(timezone.utc) if metadata: kwargs.update({"metadata": metadata}) + if isinstance(messages[0], MessageV1Types): + # Convert from v1 messages to BaseMessage + messages = [ + [convert_from_v1_message(msg) for msg in messages] # type: ignore[arg-type] + ] + messages = cast("list[list[BaseMessage]]", messages) return Run( id=run_id, parent_run_id=parent_run_id, @@ -230,7 +243,9 @@ class _TracerCore(ABC): self, token: str, run_id: UUID, - chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, + chunk: Optional[ + Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk] + ] = None, parent_run_id: Optional[UUID] = None, # noqa: ARG002 ) -> Run: """Append token event to LLM run and return the run.""" @@ -276,7 +291,15 @@ class _TracerCore(ABC): ) return llm_run - def _complete_llm_run(self, response: LLMResult, run_id: UUID) -> Run: + def _complete_llm_run( + self, response: Union[LLMResult, AIMessage], run_id: UUID + ) -> Run: + if isinstance(response, AIMessage): + response = LLMResult( + generations=[ + [ChatGeneration(message=convert_from_v1_message(response))] + ] + ) llm_run = self._get_run(run_id, run_type={"llm", "chat_model"}) if getattr(llm_run, "outputs", None) is None: llm_run.outputs = {} @@ -558,7 +581,7 @@ class _TracerCore(ABC): self, run: Run, # noqa: ARG002 token: str, # noqa: ARG002 - chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]], # noqa: ARG002 + chunk: Optional[Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]], # noqa: ARG002 ) -> Union[None, Coroutine[Any, Any, None]]: """Process new LLM token.""" return None diff --git a/libs/core/langchain_core/tracers/event_stream.py b/libs/core/langchain_core/tracers/event_stream.py index e510356e6b6..57289431935 100644 --- a/libs/core/langchain_core/tracers/event_stream.py +++ b/libs/core/langchain_core/tracers/event_stream.py @@ -19,6 +19,7 @@ from typing_extensions import NotRequired, TypedDict, override from langchain_core.callbacks.base import AsyncCallbackHandler from langchain_core.messages import AIMessageChunk, BaseMessage, BaseMessageChunk +from langchain_core.messages.v1 import MessageV1 from langchain_core.outputs import ( ChatGenerationChunk, GenerationChunk, @@ -43,6 +44,8 @@ if TYPE_CHECKING: from collections.abc import AsyncIterator, Iterator, Sequence from langchain_core.documents import Document + from langchain_core.messages.v1 import AIMessage as AIMessageV1 + from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1 from langchain_core.runnables import Runnable, RunnableConfig from langchain_core.tracers.log_stream import LogEntry @@ -297,7 +300,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand async def on_chat_model_start( self, serialized: dict[str, Any], - messages: list[list[BaseMessage]], + messages: Union[list[list[BaseMessage]], list[MessageV1]], *, run_id: UUID, tags: Optional[list[str]] = None, @@ -307,6 +310,8 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand **kwargs: Any, ) -> None: """Start a trace for an LLM run.""" + # below cast is because type is converted in handle_event + messages = cast("list[list[BaseMessage]]", messages) name_ = _assign_name(name, serialized) run_type = "chat_model" @@ -407,13 +412,18 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand self, token: str, *, - chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, + chunk: Optional[ + Union[GenerationChunk, ChatGenerationChunk, AIMessageChunkV1] + ] = None, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: """Run on new LLM token. Only available when streaming is enabled.""" run_info = self.run_map.get(run_id) + chunk = cast( + "Optional[Union[GenerationChunk, ChatGenerationChunk]]", chunk + ) # converted in handle_event chunk_: Union[GenerationChunk, BaseMessageChunk] if run_info is None: @@ -456,9 +466,10 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand @override async def on_llm_end( - self, response: LLMResult, *, run_id: UUID, **kwargs: Any + self, response: Union[LLMResult, AIMessageV1], *, run_id: UUID, **kwargs: Any ) -> None: """End a trace for an LLM run.""" + response = cast("LLMResult", response) # converted in handle_event run_info = self.run_map.pop(run_id) inputs_ = run_info["inputs"] diff --git a/libs/core/langchain_core/tracers/langchain.py b/libs/core/langchain_core/tracers/langchain.py index 4a6d0d82344..dbd8a63a2b7 100644 --- a/libs/core/langchain_core/tracers/langchain.py +++ b/libs/core/langchain_core/tracers/langchain.py @@ -5,7 +5,7 @@ from __future__ import annotations import logging from concurrent.futures import ThreadPoolExecutor from datetime import datetime, timezone -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union, cast from uuid import UUID from langsmith import Client @@ -21,11 +21,14 @@ from typing_extensions import override from langchain_core.env import get_runtime_environment from langchain_core.load import dumpd +from langchain_core.messages.utils import convert_from_v1_message +from langchain_core.messages.v1 import MessageV1Types from langchain_core.tracers.base import BaseTracer from langchain_core.tracers.schemas import Run if TYPE_CHECKING: from langchain_core.messages import BaseMessage + from langchain_core.messages.v1 import AIMessageChunk, MessageV1 from langchain_core.outputs import ChatGenerationChunk, GenerationChunk logger = logging.getLogger(__name__) @@ -113,7 +116,7 @@ class LangChainTracer(BaseTracer): def on_chat_model_start( self, serialized: dict[str, Any], - messages: list[list[BaseMessage]], + messages: Union[list[list[BaseMessage]], list[MessageV1]], *, run_id: UUID, tags: Optional[list[str]] = None, @@ -140,6 +143,12 @@ class LangChainTracer(BaseTracer): start_time = datetime.now(timezone.utc) if metadata: kwargs.update({"metadata": metadata}) + if isinstance(messages[0], MessageV1Types): + # Convert from v1 messages to BaseMessage + messages = [ + [convert_from_v1_message(msg) for msg in messages] # type: ignore[arg-type] + ] + messages = cast("list[list[BaseMessage]]", messages) chat_model_run = Run( id=run_id, parent_run_id=parent_run_id, @@ -232,7 +241,9 @@ class LangChainTracer(BaseTracer): self, token: str, run_id: UUID, - chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, + chunk: Optional[ + Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk] + ] = None, parent_run_id: Optional[UUID] = None, ) -> Run: """Append token event to LLM run and return the run.""" diff --git a/libs/core/langchain_core/tracers/log_stream.py b/libs/core/langchain_core/tracers/log_stream.py index 5246ea8456e..0ce4b788972 100644 --- a/libs/core/langchain_core/tracers/log_stream.py +++ b/libs/core/langchain_core/tracers/log_stream.py @@ -32,6 +32,7 @@ if TYPE_CHECKING: from collections.abc import AsyncIterator, Iterator, Sequence from uuid import UUID + from langchain_core.messages.v1 import AIMessageChunk from langchain_core.runnables.utils import Input, Output from langchain_core.tracers.schemas import Run @@ -485,7 +486,7 @@ class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler): self, run: Run, token: str, - chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]], + chunk: Optional[Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]], ) -> None: """Process new LLM token.""" index = self._key_map_by_run_id.get(run.id) diff --git a/libs/core/tests/benchmarks/test_async_callbacks.py b/libs/core/tests/benchmarks/test_async_callbacks.py index 5cb58f0210e..c29598c3aa7 100644 --- a/libs/core/tests/benchmarks/test_async_callbacks.py +++ b/libs/core/tests/benchmarks/test_async_callbacks.py @@ -10,6 +10,8 @@ from typing_extensions import override from langchain_core.callbacks.base import AsyncCallbackHandler from langchain_core.language_models import GenericFakeChatModel from langchain_core.messages import AIMessage, BaseMessage +from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1 +from langchain_core.messages.v1 import MessageV1 from langchain_core.outputs import ChatGenerationChunk, GenerationChunk @@ -18,7 +20,7 @@ class MyCustomAsyncHandler(AsyncCallbackHandler): async def on_chat_model_start( self, serialized: dict[str, Any], - messages: list[list[BaseMessage]], + messages: Union[list[list[BaseMessage]], list[MessageV1]], *, run_id: UUID, parent_run_id: Optional[UUID] = None, @@ -35,7 +37,9 @@ class MyCustomAsyncHandler(AsyncCallbackHandler): self, token: str, *, - chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, + chunk: Optional[ + Union[GenerationChunk, ChatGenerationChunk, AIMessageChunkV1] + ] = None, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[list[str]] = None, diff --git a/libs/core/tests/unit_tests/fake/callbacks.py b/libs/core/tests/unit_tests/fake/callbacks.py index b8ec1778b42..30a79601f07 100644 --- a/libs/core/tests/unit_tests/fake/callbacks.py +++ b/libs/core/tests/unit_tests/fake/callbacks.py @@ -9,6 +9,7 @@ from typing_extensions import override from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler from langchain_core.messages import BaseMessage +from langchain_core.messages.v1 import MessageV1 class BaseFakeCallbackHandler(BaseModel): @@ -285,7 +286,7 @@ class FakeCallbackHandlerWithChatStart(FakeCallbackHandler): def on_chat_model_start( self, serialized: dict[str, Any], - messages: list[list[BaseMessage]], + messages: Union[list[list[BaseMessage]], list[MessageV1]], *, run_id: UUID, parent_run_id: Optional[UUID] = None, diff --git a/libs/core/tests/unit_tests/fake/test_fake_chat_model.py b/libs/core/tests/unit_tests/fake/test_fake_chat_model.py index 4d17f3b8cae..40ed46f0745 100644 --- a/libs/core/tests/unit_tests/fake/test_fake_chat_model.py +++ b/libs/core/tests/unit_tests/fake/test_fake_chat_model.py @@ -165,7 +165,7 @@ async def test_callback_handlers() -> None: async def on_chat_model_start( self, serialized: dict[str, Any], - messages: Union[list[list[BaseMessage]], list[list[MessageV1]]], + messages: Union[list[list[BaseMessage]], list[MessageV1]], *, run_id: UUID, parent_run_id: Optional[UUID] = None, diff --git a/libs/core/tests/unit_tests/tracers/test_async_base_tracer.py b/libs/core/tests/unit_tests/tracers/test_async_base_tracer.py index 1b243c03816..318174ecc95 100644 --- a/libs/core/tests/unit_tests/tracers/test_async_base_tracer.py +++ b/libs/core/tests/unit_tests/tracers/test_async_base_tracer.py @@ -3,7 +3,7 @@ from __future__ import annotations from datetime import datetime, timezone -from typing import Any +from typing import TYPE_CHECKING, Any from uuid import uuid4 import pytest @@ -16,6 +16,9 @@ from langchain_core.outputs import LLMResult from langchain_core.tracers.base import AsyncBaseTracer from langchain_core.tracers.schemas import Run +if TYPE_CHECKING: + from langchain_core.messages import BaseMessage + SERIALIZED = {"id": ["llm"]} SERIALIZED_CHAT = {"id": ["chat_model"]} @@ -84,8 +87,9 @@ async def test_tracer_chat_model_run() -> None: """Test tracer on a Chat Model run.""" tracer = FakeAsyncTracer() manager = AsyncCallbackManager(handlers=[tracer]) + messages: list[list[BaseMessage]] = [[HumanMessage(content="")]] run_managers = await manager.on_chat_model_start( - serialized=SERIALIZED_CHAT, messages=[[HumanMessage(content="")]] + serialized=SERIALIZED_CHAT, messages=messages ) compare_run = Run( id=str(run_managers[0].run_id), # type: ignore[arg-type] diff --git a/libs/core/tests/unit_tests/tracers/test_base_tracer.py b/libs/core/tests/unit_tests/tracers/test_base_tracer.py index aaa34a662f2..f80cdb3b679 100644 --- a/libs/core/tests/unit_tests/tracers/test_base_tracer.py +++ b/libs/core/tests/unit_tests/tracers/test_base_tracer.py @@ -3,7 +3,7 @@ from __future__ import annotations from datetime import datetime, timezone -from typing import Any +from typing import TYPE_CHECKING, Any from unittest.mock import MagicMock from uuid import uuid4 @@ -20,6 +20,9 @@ from langchain_core.runnables import chain as as_runnable from langchain_core.tracers.base import BaseTracer from langchain_core.tracers.schemas import Run +if TYPE_CHECKING: + from langchain_core.messages import BaseMessage + SERIALIZED = {"id": ["llm"]} SERIALIZED_CHAT = {"id": ["chat_model"]} @@ -89,8 +92,10 @@ def test_tracer_chat_model_run() -> None: """Test tracer on a Chat Model run.""" tracer = FakeTracer() manager = CallbackManager(handlers=[tracer]) + # TODO: why is this annotation needed + messages: list[list[BaseMessage]] = [[HumanMessage(content="")]] run_managers = manager.on_chat_model_start( - serialized=SERIALIZED_CHAT, messages=[[HumanMessage(content="")]] + serialized=SERIALIZED_CHAT, messages=messages ) compare_run = Run( id=str(run_managers[0].run_id), # type: ignore[arg-type] diff --git a/libs/langchain/langchain/callbacks/streaming_aiter.py b/libs/langchain/langchain/callbacks/streaming_aiter.py index 96cf78fd83a..2cf57d1f409 100644 --- a/libs/langchain/langchain/callbacks/streaming_aiter.py +++ b/libs/langchain/langchain/callbacks/streaming_aiter.py @@ -5,6 +5,7 @@ from collections.abc import AsyncIterator from typing import Any, Literal, Union, cast from langchain_core.callbacks import AsyncCallbackHandler +from langchain_core.messages.v1 import AIMessage from langchain_core.outputs import LLMResult from typing_extensions import override @@ -44,7 +45,9 @@ class AsyncIteratorCallbackHandler(AsyncCallbackHandler): self.queue.put_nowait(token) @override - async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + async def on_llm_end( + self, response: Union[LLMResult, AIMessage], **kwargs: Any + ) -> None: self.done.set() @override diff --git a/libs/langchain/langchain/callbacks/streaming_aiter_final_only.py b/libs/langchain/langchain/callbacks/streaming_aiter_final_only.py index dbf125bd175..1c37a4f43cf 100644 --- a/libs/langchain/langchain/callbacks/streaming_aiter_final_only.py +++ b/libs/langchain/langchain/callbacks/streaming_aiter_final_only.py @@ -1,7 +1,8 @@ from __future__ import annotations -from typing import Any, Optional +from typing import Any, Optional, Union +from langchain_core.messages.v1 import AIMessage from langchain_core.outputs import LLMResult from typing_extensions import override @@ -75,7 +76,9 @@ class AsyncFinalIteratorCallbackHandler(AsyncIteratorCallbackHandler): self.answer_reached = False @override - async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + async def on_llm_end( + self, response: Union[LLMResult, AIMessage], **kwargs: Any + ) -> None: if self.answer_reached: self.done.set() diff --git a/libs/langchain/langchain/smith/evaluation/progress.py b/libs/langchain/langchain/smith/evaluation/progress.py index 4282f9c76ae..ba4e14cc573 100644 --- a/libs/langchain/langchain/smith/evaluation/progress.py +++ b/libs/langchain/langchain/smith/evaluation/progress.py @@ -2,11 +2,12 @@ import threading from collections.abc import Sequence -from typing import Any, Optional +from typing import Any, Optional, Union from uuid import UUID from langchain_core.callbacks import base as base_callbacks from langchain_core.documents import Document +from langchain_core.messages.v1 import AIMessage from langchain_core.outputs import LLMResult from typing_extensions import override @@ -111,7 +112,7 @@ class ProgressBarCallback(base_callbacks.BaseCallbackHandler): @override def on_llm_end( self, - response: LLMResult, + response: Union[LLMResult, AIMessage], *, run_id: UUID, parent_run_id: Optional[UUID] = None, diff --git a/libs/langchain/tests/unit_tests/callbacks/fake_callback_handler.py b/libs/langchain/tests/unit_tests/callbacks/fake_callback_handler.py index 8b97e226ec7..cad3d1f9cfe 100644 --- a/libs/langchain/tests/unit_tests/callbacks/fake_callback_handler.py +++ b/libs/langchain/tests/unit_tests/callbacks/fake_callback_handler.py @@ -6,6 +6,7 @@ from uuid import UUID from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler from langchain_core.messages import BaseMessage +from langchain_core.messages.v1 import MessageV1 from pydantic import BaseModel from typing_extensions import override @@ -281,7 +282,7 @@ class FakeCallbackHandlerWithChatStart(FakeCallbackHandler): def on_chat_model_start( self, serialized: dict[str, Any], - messages: list[list[BaseMessage]], + messages: Union[list[list[BaseMessage]], list[MessageV1]], *, run_id: UUID, parent_run_id: Optional[UUID] = None, diff --git a/libs/langchain/tests/unit_tests/llms/test_fake_chat_model.py b/libs/langchain/tests/unit_tests/llms/test_fake_chat_model.py index e5e8de87f0f..561ac0a29a5 100644 --- a/libs/langchain/tests/unit_tests/llms/test_fake_chat_model.py +++ b/libs/langchain/tests/unit_tests/llms/test_fake_chat_model.py @@ -6,6 +6,8 @@ from uuid import UUID from langchain_core.callbacks.base import AsyncCallbackHandler from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage +from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1 +from langchain_core.messages.v1 import MessageV1 from langchain_core.outputs import ChatGenerationChunk, GenerationChunk from typing_extensions import override @@ -155,7 +157,7 @@ async def test_callback_handlers() -> None: async def on_chat_model_start( self, serialized: dict[str, Any], - messages: list[list[BaseMessage]], + messages: Union[list[list[BaseMessage]], list[MessageV1]], *, run_id: UUID, parent_run_id: Optional[UUID] = None, @@ -172,7 +174,9 @@ async def test_callback_handlers() -> None: self, token: str, *, - chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, + chunk: Optional[ + Union[GenerationChunk, ChatGenerationChunk, AIMessageChunkV1] + ] = None, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[list[str]] = None,