From 3d9e694f73da705c349d9e0952c56ac4cc8cbe07 Mon Sep 17 00:00:00 2001 From: ccurme Date: Mon, 28 Jul 2025 11:17:06 -0300 Subject: [PATCH] feat(core): start on v1 chat model (#32276) Co-authored-by: Nuno Campos --- libs/core/langchain_core/callbacks/base.py | 41 +- libs/core/langchain_core/callbacks/manager.py | 58 +- .../langchain_core/language_models/_utils.py | 6 +- .../langchain_core/language_models/base.py | 5 +- .../language_models/v1/__init__.py | 1 + .../language_models/v1/chat_models.py | 909 ++++++++++++++++++ libs/core/langchain_core/messages/utils.py | 149 +++ libs/core/langchain_core/messages/v1.py | 2 + .../unit_tests/fake/test_fake_chat_model.py | 14 +- 9 files changed, 1146 insertions(+), 39 deletions(-) create mode 100644 libs/core/langchain_core/language_models/v1/__init__.py create mode 100644 libs/core/langchain_core/language_models/v1/chat_models.py diff --git a/libs/core/langchain_core/callbacks/base.py b/libs/core/langchain_core/callbacks/base.py index 5365fcb9ef1..f9c6f75f307 100644 --- a/libs/core/langchain_core/callbacks/base.py +++ b/libs/core/langchain_core/callbacks/base.py @@ -7,6 +7,8 @@ from typing import TYPE_CHECKING, Any, Optional, Union from typing_extensions import Self +from langchain_core.messages.v1 import AIMessage, AIMessageChunk, MessageV1 + if TYPE_CHECKING: from collections.abc import Sequence from uuid import UUID @@ -64,9 +66,11 @@ class LLMManagerMixin: def on_llm_new_token( self, - token: str, + token: Union[str, AIMessageChunk], *, - chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, + chunk: Optional[ + Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk] + ] = None, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any, @@ -75,8 +79,8 @@ class LLMManagerMixin: Args: token (str): The new token. - chunk (GenerationChunk | ChatGenerationChunk): The new generated chunk, - containing content and other information. + chunk (GenerationChunk | ChatGenerationChunk | AIMessageChunk): The new + generated chunk, containing content and other information. run_id (UUID): The run ID. This is the ID of the current run. parent_run_id (UUID): The parent run ID. This is the ID of the parent run. kwargs (Any): Additional keyword arguments. @@ -84,7 +88,7 @@ class LLMManagerMixin: def on_llm_end( self, - response: LLMResult, + response: Union[LLMResult, AIMessage], *, run_id: UUID, parent_run_id: Optional[UUID] = None, @@ -93,7 +97,7 @@ class LLMManagerMixin: """Run when LLM ends running. Args: - response (LLMResult): The response which was generated. + response (LLMResult | AIMessage): The response which was generated. run_id (UUID): The run ID. This is the ID of the current run. parent_run_id (UUID): The parent run ID. This is the ID of the parent run. kwargs (Any): Additional keyword arguments. 
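The widened signatures above mean a handler may now receive either the legacy payloads or the new v1 message objects. A minimal sketch of a tolerant handler (illustrative only, not part of the patch; `TokenCollector` is an invented name):

```python
from typing import Any, Optional, Union
from uuid import UUID

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages.v1 import AIMessageChunk


class TokenCollector(BaseCallbackHandler):
    """Collects streamed text whether tokens arrive as str or AIMessageChunk."""

    accepts_new_messages = True  # opt in to v1 payloads (flag added below)

    def __init__(self) -> None:
        self.parts: list[str] = []

    def on_llm_new_token(
        self,
        token: Union[str, AIMessageChunk],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> None:
        # v1 chunks carry structured content; fall back to their text view.
        self.parts.append(token if isinstance(token, str) else token.text)
```

Handlers that leave `accepts_new_messages` at its default of False keep working: the callback manager converts v1 payloads back to the legacy types before dispatch (see `_convert_llm_events` in callbacks/manager.py below).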
@@ -261,7 +265,7 @@ class CallbackManagerMixin: def on_chat_model_start( self, serialized: dict[str, Any], - messages: list[list[BaseMessage]], + messages: Union[list[list[BaseMessage]], list[list[MessageV1]]], *, run_id: UUID, parent_run_id: Optional[UUID] = None, @@ -439,6 +443,9 @@ class BaseCallbackHandler( run_inline: bool = False """Whether to run the callback inline.""" + accepts_new_messages: bool = False + """Whether the callback accepts new message format.""" + @property def ignore_llm(self) -> bool: """Whether to ignore LLM callbacks.""" @@ -509,7 +516,7 @@ class AsyncCallbackHandler(BaseCallbackHandler): async def on_chat_model_start( self, serialized: dict[str, Any], - messages: list[list[BaseMessage]], + messages: Union[list[list[BaseMessage]], list[list[MessageV1]]], *, run_id: UUID, parent_run_id: Optional[UUID] = None, @@ -538,9 +545,11 @@ class AsyncCallbackHandler(BaseCallbackHandler): async def on_llm_new_token( self, - token: str, + token: Union[str, AIMessageChunk], *, - chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, + chunk: Optional[ + Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk] + ] = None, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[list[str]] = None, @@ -550,8 +559,8 @@ class AsyncCallbackHandler(BaseCallbackHandler): Args: token (str): The new token. - chunk (GenerationChunk | ChatGenerationChunk): The new generated chunk, - containing content and other information. + chunk (GenerationChunk | ChatGenerationChunk | AIMessageChunk): The new + generated chunk, containing content and other information. run_id (UUID): The run ID. This is the ID of the current run. parent_run_id (UUID): The parent run ID. This is the ID of the parent run. tags (Optional[list[str]]): The tags. @@ -560,7 +569,7 @@ class AsyncCallbackHandler(BaseCallbackHandler): async def on_llm_end( self, - response: LLMResult, + response: Union[LLMResult, AIMessage], *, run_id: UUID, parent_run_id: Optional[UUID] = None, @@ -570,7 +579,7 @@ class AsyncCallbackHandler(BaseCallbackHandler): """Run when LLM ends running. Args: - response (LLMResult): The response which was generated. + response (LLMResult | AIMessage): The response which was generated. run_id (UUID): The run ID. This is the ID of the current run. parent_run_id (UUID): The parent run ID. This is the ID of the parent run. tags (Optional[list[str]]): The tags. @@ -594,8 +603,8 @@ class AsyncCallbackHandler(BaseCallbackHandler): parent_run_id: The parent run ID. This is the ID of the parent run. tags: The tags. kwargs (Any): Additional keyword arguments. - - response (LLMResult): The response which was generated before - the error occurred. + - response (LLMResult | AIMessage): The response which was generated + before the error occurred. 
""" async def on_chain_start( diff --git a/libs/core/langchain_core/callbacks/manager.py b/libs/core/langchain_core/callbacks/manager.py index 3b4f1f6e30e..d07da15ee13 100644 --- a/libs/core/langchain_core/callbacks/manager.py +++ b/libs/core/langchain_core/callbacks/manager.py @@ -11,6 +11,7 @@ from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor from contextlib import asynccontextmanager, contextmanager from contextvars import copy_context +from dataclasses import is_dataclass from typing import ( TYPE_CHECKING, Any, @@ -37,6 +38,8 @@ from langchain_core.callbacks.base import ( ) from langchain_core.callbacks.stdout import StdOutCallbackHandler from langchain_core.messages import BaseMessage, get_buffer_string +from langchain_core.messages.v1 import AIMessage, AIMessageChunk +from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, LLMResult from langchain_core.tracers.schemas import Run from langchain_core.utils.env import env_var_is_set @@ -47,7 +50,7 @@ if TYPE_CHECKING: from langchain_core.agents import AgentAction, AgentFinish from langchain_core.documents import Document - from langchain_core.outputs import ChatGenerationChunk, GenerationChunk, LLMResult + from langchain_core.outputs import GenerationChunk from langchain_core.runnables.config import RunnableConfig logger = logging.getLogger(__name__) @@ -241,6 +244,22 @@ def shielded(func: Func) -> Func: return cast("Func", wrapped) +def _convert_llm_events( + event_name: str, args: tuple[Any, ...], kwargs: dict[str, Any] +) -> None: + if event_name == "on_chat_model_start" and isinstance(args[1], list): + for idx, item in enumerate(args[1]): + if is_dataclass(item): + args[1][idx] = item # convert to old message + elif event_name == "on_llm_new_token" and is_dataclass(args[0]): + kwargs["chunk"] = ChatGenerationChunk(text=args[0].text, message=args[0]) + args[0] = args[0].text + elif event_name == "on_llm_end" and is_dataclass(args[0]): + args[0] = LLMResult( + generations=[[ChatGeneration(text=args[0].text, message=args[0])]] + ) + + def handle_event( handlers: list[BaseCallbackHandler], event_name: str, @@ -269,6 +288,8 @@ def handle_event( if ignore_condition_name is None or not getattr( handler, ignore_condition_name ): + if not handler.accepts_new_messages: + _convert_llm_events(event_name, args, kwargs) event = getattr(handler, event_name)(*args, **kwargs) if asyncio.iscoroutine(event): coros.append(event) @@ -363,6 +384,8 @@ async def _ahandle_event_for_handler( ) -> None: try: if ignore_condition_name is None or not getattr(handler, ignore_condition_name): + if not handler.accepts_new_messages: + _convert_llm_events(event_name, args, kwargs) event = getattr(handler, event_name) if asyncio.iscoroutinefunction(event): await event(*args, **kwargs) @@ -670,9 +693,11 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin): def on_llm_new_token( self, - token: str, + token: Union[str, AIMessageChunk], *, - chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, + chunk: Optional[ + Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk] + ] = None, **kwargs: Any, ) -> None: """Run when LLM generates a new token. @@ -697,11 +722,11 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin): **kwargs, ) - def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + def on_llm_end(self, response: Union[LLMResult, AIMessage], **kwargs: Any) -> None: """Run when LLM ends running. Args: - response (LLMResult): The LLM result. 
+ response (LLMResult | AIMessage): The LLM result. **kwargs (Any): Additional keyword arguments. """ if not self.handlers: @@ -727,8 +752,8 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin): Args: error (Exception or KeyboardInterrupt): The error. kwargs (Any): Additional keyword arguments. - - response (LLMResult): The response which was generated before - the error occurred. + - response (LLMResult | AIMessage): The response which was generated + before the error occurred. """ if not self.handlers: return @@ -766,9 +791,11 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin): async def on_llm_new_token( self, - token: str, + token: Union[str, AIMessageChunk], *, - chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, + chunk: Optional[ + Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk] + ] = None, **kwargs: Any, ) -> None: """Run when LLM generates a new token. @@ -794,11 +821,13 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin): ) @shielded - async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + async def on_llm_end( + self, response: Union[LLMResult, AIMessage], **kwargs: Any + ) -> None: """Run when LLM ends running. Args: - response (LLMResult): The LLM result. + response (LLMResult | AIMessage): The LLM result. **kwargs (Any): Additional keyword arguments. """ if not self.handlers: @@ -825,11 +854,8 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin): Args: error (Exception or KeyboardInterrupt): The error. kwargs (Any): Additional keyword arguments. - - response (LLMResult): The response which was generated before - the error occurred. - - - + - response (LLMResult | AIMessage): The response which was generated + before the error occurred. 
""" if not self.handlers: return diff --git a/libs/core/langchain_core/language_models/_utils.py b/libs/core/langchain_core/language_models/_utils.py index 3b7c0c5debe..3d7211669fc 100644 --- a/libs/core/langchain_core/language_models/_utils.py +++ b/libs/core/langchain_core/language_models/_utils.py @@ -1,3 +1,4 @@ +import copy import re from collections.abc import Sequence from typing import Optional @@ -127,7 +128,10 @@ def _normalize_messages(messages: Sequence[BaseMessage]) -> list[BaseMessage]: and _is_openai_data_block(block) ): if formatted_message is message: - formatted_message = message.model_copy() + if isinstance(message, BaseMessage): + formatted_message = message.model_copy() + else: + formatted_message = copy.copy(message) # Also shallow-copy content formatted_message.content = list(formatted_message.content) diff --git a/libs/core/langchain_core/language_models/base.py b/libs/core/langchain_core/language_models/base.py index a9e7e4e64cc..1fef01a6859 100644 --- a/libs/core/langchain_core/language_models/base.py +++ b/libs/core/langchain_core/language_models/base.py @@ -28,6 +28,7 @@ from langchain_core.messages import ( MessageLikeRepresentation, get_buffer_string, ) +from langchain_core.messages.v1 import AIMessage as AIMessageV1 from langchain_core.prompt_values import PromptValue from langchain_core.runnables import Runnable, RunnableSerializable from langchain_core.utils import get_pydantic_field_names @@ -85,7 +86,9 @@ def _get_token_ids_default_method(text: str) -> list[int]: LanguageModelInput = Union[PromptValue, str, Sequence[MessageLikeRepresentation]] LanguageModelOutput = Union[BaseMessage, str] LanguageModelLike = Runnable[LanguageModelInput, LanguageModelOutput] -LanguageModelOutputVar = TypeVar("LanguageModelOutputVar", BaseMessage, str) +LanguageModelOutputVar = TypeVar( + "LanguageModelOutputVar", BaseMessage, str, AIMessageV1 +) def _get_verbosity() -> bool: diff --git a/libs/core/langchain_core/language_models/v1/__init__.py b/libs/core/langchain_core/language_models/v1/__init__.py new file mode 100644 index 00000000000..eaffaa213ba --- /dev/null +++ b/libs/core/langchain_core/language_models/v1/__init__.py @@ -0,0 +1 @@ +"""LangChain v1.0 chat models.""" diff --git a/libs/core/langchain_core/language_models/v1/chat_models.py b/libs/core/langchain_core/language_models/v1/chat_models.py new file mode 100644 index 00000000000..437fb876496 --- /dev/null +++ b/libs/core/langchain_core/language_models/v1/chat_models.py @@ -0,0 +1,909 @@ +"""Chat models for conversational AI.""" + +from __future__ import annotations + +import copy +import typing +from abc import ABC, abstractmethod +from collections.abc import AsyncIterator, Iterator, Sequence +from operator import itemgetter +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Literal, + Optional, + Union, + cast, +) + +from pydantic import ( + BaseModel, + ConfigDict, + Field, +) +from typing_extensions import override + +from langchain_core.callbacks import ( + AsyncCallbackManager, + AsyncCallbackManagerForLLMRun, + CallbackManager, + CallbackManagerForLLMRun, +) +from langchain_core.language_models._utils import _normalize_messages +from langchain_core.language_models.base import ( + BaseLanguageModel, + LangSmithParams, + LanguageModelInput, +) +from langchain_core.messages import ( + AIMessage, + BaseMessage, + convert_to_openai_image_block, + is_data_content_block, +) +from langchain_core.messages.utils import convert_to_messages_v1 +from langchain_core.messages.v1 import AIMessage as AIMessageV1 
+from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1
+from langchain_core.messages.v1 import HumanMessage as HumanMessageV1
+from langchain_core.messages.v1 import MessageV1, add_ai_message_chunks
+from langchain_core.outputs import (
+    ChatGeneration,
+    ChatGenerationChunk,
+)
+from langchain_core.prompt_values import PromptValue
+from langchain_core.rate_limiters import BaseRateLimiter
+from langchain_core.runnables import RunnableMap, RunnablePassthrough
+from langchain_core.runnables.config import ensure_config, run_in_executor
+from langchain_core.tracers._streaming import _StreamingCallbackHandler
+from langchain_core.utils.function_calling import (
+    convert_to_json_schema,
+    convert_to_openai_tool,
+)
+from langchain_core.utils.pydantic import TypeBaseModel, is_basemodel_subclass
+
+if TYPE_CHECKING:
+    from langchain_core.output_parsers.base import OutputParserLike
+    from langchain_core.runnables import Runnable, RunnableConfig
+    from langchain_core.tools import BaseTool
+
+
+def _generate_response_from_error(error: BaseException) -> list[AIMessageV1]:
+    if hasattr(error, "response"):
+        response = error.response
+        metadata: dict = {}
+        if hasattr(response, "headers"):
+            try:
+                metadata["headers"] = dict(response.headers)
+            except Exception:
+                metadata["headers"] = None
+        if hasattr(response, "status_code"):
+            metadata["status_code"] = response.status_code
+        if hasattr(error, "request_id"):
+            metadata["request_id"] = error.request_id
+        generations = [AIMessageV1(content=[], response_metadata=metadata)]
+    else:
+        generations = []
+
+    return generations
+
+
+def _format_for_tracing(messages: list[MessageV1]) -> list[MessageV1]:
+    """Format messages for tracing in on_chat_model_start.
+
+    - Update image content blocks to OpenAI Chat Completions format (backward
+      compatibility).
+    - Add "type" key to content blocks that have a single key.
+
+    Args:
+        messages: List of messages to format.
+
+    Returns:
+        List of messages formatted for tracing.
+    """
+    messages_to_trace = []
+    for message in messages:
+        message_to_trace = message
+        for idx, block in enumerate(message.content):
+            # Update image content blocks to OpenAI Chat Completions format.
+            if (
+                block["type"] == "image"
+                and is_data_content_block(block)
+                and block.get("source_type") != "id"
+            ):
+                if message_to_trace is message:
+                    # Shallow copy
+                    message_to_trace = copy.copy(message)
+                    message_to_trace.content = list(message_to_trace.content)
+
+                message_to_trace.content[idx] = convert_to_openai_image_block(block)
+            else:
+                pass
+        messages_to_trace.append(message_to_trace)
+
+    return messages_to_trace
+
+
+def generate_from_stream(stream: Iterator[AIMessageChunkV1]) -> AIMessageV1:
+    """Generate from a stream.
+
+    Args:
+        stream: Iterator of AIMessageChunkV1.
+
+    Returns:
+        AIMessageV1: aggregated message.
+    """
+    generation = next(stream, None)
+    if generation:
+        generation += list(stream)
+    if generation is None:
+        msg = "No generations found in stream."
+        raise ValueError(msg)
+    return generation.to_message()
+
+
+async def agenerate_from_stream(
+    stream: AsyncIterator[AIMessageChunkV1],
+) -> AIMessageV1:
+    """Async generate from a stream.
+
+    Args:
+        stream: Iterator of AIMessageChunkV1.
+
+    Returns:
+        AIMessageV1: aggregated message.
+ """ + chunks = [chunk async for chunk in stream] + return await run_in_executor(None, generate_from_stream, iter(chunks)) + + +def _format_ls_structured_output(ls_structured_output_format: Optional[dict]) -> dict: + if ls_structured_output_format: + try: + ls_structured_output_format_dict = { + "ls_structured_output_format": { + "kwargs": ls_structured_output_format.get("kwargs", {}), + "schema": convert_to_json_schema( + ls_structured_output_format["schema"] + ), + } + } + except ValueError: + ls_structured_output_format_dict = {} + else: + ls_structured_output_format_dict = {} + + return ls_structured_output_format_dict + + +class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC): + """Base class for chat models. + + Key imperative methods: + Methods that actually call the underlying model. + + +---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+ + | Method | Input | Output | Description | + +===========================+================================================================+=====================================================================+==================================================================================================+ + | `invoke` | str | list[dict | tuple | BaseMessage] | PromptValue | BaseMessage | A single chat model call. | + +---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+ + | `ainvoke` | ''' | BaseMessage | Defaults to running invoke in an async executor. | + +---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+ + | `stream` | ''' | Iterator[BaseMessageChunk] | Defaults to yielding output of invoke. | + +---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+ + | `astream` | ''' | AsyncIterator[BaseMessageChunk] | Defaults to yielding output of ainvoke. | + +---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+ + | `astream_events` | ''' | AsyncIterator[StreamEvent] | Event types: 'on_chat_model_start', 'on_chat_model_stream', 'on_chat_model_end'. | + +---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+ + | `batch` | list['''] | list[BaseMessage] | Defaults to running invoke in concurrent threads. 
| + +---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+ + | `abatch` | list['''] | list[BaseMessage] | Defaults to running ainvoke in concurrent threads. | + +---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+ + | `batch_as_completed` | list['''] | Iterator[tuple[int, Union[BaseMessage, Exception]]] | Defaults to running invoke in concurrent threads. | + +---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+ + | `abatch_as_completed` | list['''] | AsyncIterator[tuple[int, Union[BaseMessage, Exception]]] | Defaults to running ainvoke in concurrent threads. | + +---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+ + + This table provides a brief overview of the main imperative methods. Please see the base Runnable reference for full documentation. + + Key declarative methods: + Methods for creating another Runnable using the ChatModel. + + +----------------------------------+-----------------------------------------------------------------------------------------------------------+ + | Method | Description | + +==================================+===========================================================================================================+ + | `bind_tools` | Create ChatModel that can call tools. | + +----------------------------------+-----------------------------------------------------------------------------------------------------------+ + | `with_structured_output` | Create wrapper that structures model output using schema. | + +----------------------------------+-----------------------------------------------------------------------------------------------------------+ + | `with_retry` | Create wrapper that retries model calls on failure. | + +----------------------------------+-----------------------------------------------------------------------------------------------------------+ + | `with_fallbacks` | Create wrapper that falls back to other models on failure. | + +----------------------------------+-----------------------------------------------------------------------------------------------------------+ + | `configurable_fields` | Specify init args of the model that can be configured at runtime via the RunnableConfig. | + +----------------------------------+-----------------------------------------------------------------------------------------------------------+ + | `configurable_alternatives` | Specify alternative models which can be swapped in at runtime via the RunnableConfig. | + +----------------------------------+-----------------------------------------------------------------------------------------------------------+ + + This table provides a brief overview of the main declarative methods. 
Please see the reference for each method for full documentation.
+
+    Creating custom chat model:
+        Custom chat model implementations should inherit from this class.
+        Please reference the table below for information about which
+        methods and properties are required or optional for implementations.
+
+        +----------------------------------+---------------------------------------------------------------------+-------------------+
+        | Method/Property                  | Description                                                         | Required/Optional |
+        +==================================+=====================================================================+===================+
+        | `_invoke`                        | Use to generate a chat result from a prompt                         | Required          |
+        +----------------------------------+---------------------------------------------------------------------+-------------------+
+        | `_llm_type` (property)           | Used to uniquely identify the type of the model. Used for logging.  | Required          |
+        +----------------------------------+---------------------------------------------------------------------+-------------------+
+        | `_identifying_params` (property) | Represent model parameterization for tracing purposes.              | Optional          |
+        +----------------------------------+---------------------------------------------------------------------+-------------------+
+        | `_stream`                        | Use to implement streaming                                          | Optional          |
+        +----------------------------------+---------------------------------------------------------------------+-------------------+
+        | `_ainvoke`                       | Use to implement a native async method                              | Optional          |
+        +----------------------------------+---------------------------------------------------------------------+-------------------+
+        | `_astream`                       | Use to implement async version of `_stream`                         | Optional          |
+        +----------------------------------+---------------------------------------------------------------------+-------------------+
+
+        Follow the guide for more information on how to implement a custom Chat Model:
+        [Guide](https://python.langchain.com/docs/how_to/custom_chat_model/).
+
+    """  # noqa: E501
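To make the table above concrete, a minimal sketch of a custom v1 model (illustrative only, not part of the patch; the class name and canned reply are invented, and string content is assumed to be accepted by the `AIMessage` constructor as it is by the other v1 message constructors in this patch):

```python
from typing import Any

from langchain_core.language_models.v1.chat_models import BaseChatModelV1
from langchain_core.messages.v1 import AIMessage as AIMessageV1
from langchain_core.messages.v1 import MessageV1


class CannedChatModel(BaseChatModelV1):
    """Toy model that always answers with a fixed string."""

    @property
    def _llm_type(self) -> str:
        # Required: uniquely identifies the model type for logging/tracing.
        return "canned-chat-model"

    def _invoke(self, messages: list[MessageV1], **kwargs: Any) -> AIMessageV1:
        # Required: the core generation hook in this v1 base class.
        return AIMessageV1(content="Hello!")


model = CannedChatModel()
response = model.invoke("Hi there")  # str input is coerced to a HumanMessage
```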
+
+    rate_limiter: Optional[BaseRateLimiter] = Field(default=None, exclude=True)
+    "An optional rate limiter to use for limiting the number of requests."
+
+    disable_streaming: Union[bool, Literal["tool_calling"]] = False
+    """Whether to disable streaming for this model.
+
+    If streaming is bypassed, then ``stream()``/``astream()``/``astream_events()`` will
+    defer to ``invoke()``/``ainvoke()``.
+
+    - If True, will always bypass streaming case.
+    - If ``'tool_calling'``, will bypass streaming case only when the model is called
+      with a ``tools`` keyword argument. In other words, LangChain will automatically
+      switch to non-streaming behavior (``invoke()``) only when the tools argument is
+      provided. This offers the best of both worlds.
+    - If False (default), will always use streaming case if available.
+
+    The main reason for this flag is that code might be written using ``.stream()`` and
+    a user may want to swap out a given model for another model whose implementation
+    does not properly support streaming.
+    """
+
+    model_config = ConfigDict(
+        arbitrary_types_allowed=True,
+    )
+
+    # --- Runnable methods ---
+
+    @property
+    @override
+    def OutputType(self) -> Any:
+        """Get the output type for this runnable."""
+        return AIMessageV1
+
+    def _convert_input(self, model_input: LanguageModelInput) -> list[MessageV1]:
+        if isinstance(model_input, PromptValue):
+            return model_input.to_messages(output_version="v1")
+        if isinstance(model_input, str):
+            return [HumanMessageV1(content=model_input)]
+        if isinstance(model_input, Sequence):
+            return convert_to_messages_v1(model_input)
+        msg = (
+            f"Invalid input type {type(model_input)}. "
+            "Must be a PromptValue, str, or list of BaseMessages."
+        )
+        raise ValueError(msg)
+
+    def _should_stream(
+        self,
+        *,
+        async_api: bool,
+        run_manager: Optional[
+            Union[CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun]
+        ] = None,
+        **kwargs: Any,
+    ) -> bool:
+        """Determine if a given model call should hit the streaming API."""
+        sync_not_implemented = type(self)._stream == BaseChatModelV1._stream  # noqa: SLF001
+        async_not_implemented = type(self)._astream == BaseChatModelV1._astream  # noqa: SLF001
+
+        # Check if streaming is implemented.
+        if (not async_api) and sync_not_implemented:
+            return False
+        # Note, since async falls back to sync we check both here.
+        if async_api and async_not_implemented and sync_not_implemented:
+            return False
+
+        # Check if streaming has been disabled on this instance.
+        if self.disable_streaming is True:
+            return False
+        # We assume tools are passed in via "tools" kwarg in all models.
+        if self.disable_streaming == "tool_calling" and kwargs.get("tools"):
+            return False
+
+        # Check if a runtime streaming flag has been passed in.
+        if "stream" in kwargs:
+            return kwargs["stream"]
+
+        # Check if any streaming callback handlers have been passed in.
+ handlers = run_manager.handlers if run_manager else [] + return any(isinstance(h, _StreamingCallbackHandler) for h in handlers) + + @override + def invoke( + self, + input: LanguageModelInput, + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> AIMessageV1: + config = ensure_config(config) + messages = self._convert_input(input) + ls_structured_output_format = kwargs.pop( + "ls_structured_output_format", None + ) or kwargs.pop("structured_output_format", None) + ls_structured_output_format_dict = _format_ls_structured_output( + ls_structured_output_format + ) + + params = self._get_invocation_params(**kwargs) + options = {**kwargs, **ls_structured_output_format_dict} + inheritable_metadata = { + **(config.get("metadata") or {}), + **self._get_ls_params(**kwargs), + } + callback_manager = CallbackManager.configure( + config.get("callbacks"), + self.callbacks, + self.verbose, + config.get("tags"), + self.tags, + inheritable_metadata, + self.metadata, + ) + (run_manager,) = callback_manager.on_chat_model_start( + {}, + [_format_for_tracing(messages)], + invocation_params=params, + options=options, + name=config.get("run_name"), + run_id=config.pop("run_id", None), + batch_size=1, + ) + + if self.rate_limiter: + self.rate_limiter.acquire(blocking=True) + + input_messages = _normalize_messages(messages) + + if self._should_stream(async_api=False, **kwargs): + chunks: list[AIMessageChunkV1] = [] + try: + for msg in self._stream(input_messages, **kwargs): + run_manager.on_llm_new_token(msg) + chunks.append(msg) + except BaseException as e: + run_manager.on_llm_error(e, response=_generate_response_from_error(e)) + raise + msg = add_ai_message_chunks(chunks[0], *chunks[1:]) + else: + try: + msg = self._invoke(input_messages, **kwargs) + except BaseException as e: + run_manager.on_llm_error(e) + raise + + run_manager.on_llm_end(msg) + return msg + + @override + async def ainvoke( + self, + input: LanguageModelInput, + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> AIMessage: + config = ensure_config(config) + messages = self._convert_input(input) + ls_structured_output_format = kwargs.pop( + "ls_structured_output_format", None + ) or kwargs.pop("structured_output_format", None) + ls_structured_output_format_dict = _format_ls_structured_output( + ls_structured_output_format + ) + + params = self._get_invocation_params(**kwargs) + options = {**kwargs, **ls_structured_output_format_dict} + inheritable_metadata = { + **(config.get("metadata") or {}), + **self._get_ls_params(**kwargs), + } + callback_manager = AsyncCallbackManager.configure( + config.get("callbacks"), + self.callbacks, + self.verbose, + config.get("tags"), + self.tags, + inheritable_metadata, + self.metadata, + ) + (run_manager,) = await callback_manager.on_chat_model_start( + {}, + [_format_for_tracing(messages)], + invocation_params=params, + options=options, + name=config.get("run_name"), + run_id=config.pop("run_id", None), + batch_size=1, + ) + + if self.rate_limiter: + await self.rate_limiter.aacquire(blocking=True) + + # TODO: type openai image, audio, file types and permit in MessageV1 + input_messages = _normalize_messages(messages) # type: ignore[arg-type] + + if self._should_stream(async_api=True, **kwargs): + chunks: list[AIMessageChunkV1] = [] + try: + async for msg in self._astream(input_messages, **kwargs): + await run_manager.on_llm_new_token(msg) + chunks.append(msg) + except BaseException as e: + await run_manager.on_llm_error( + e, response=_generate_response_from_error(e) + ) + raise 
+ msg = add_ai_message_chunks(chunks[0], *chunks[1:]) + else: + try: + msg = await self._ainvoke(input_messages, **kwargs) + except BaseException as e: + await run_manager.on_llm_error( + e, response=_generate_response_from_error(e) + ) + raise + + await run_manager.on_llm_end(msg.to_message()) + return msg + + @override + def stream( + self, + input: LanguageModelInput, + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> Iterator[AIMessageChunkV1]: + if not self._should_stream(async_api=False, **{**kwargs, "stream": True}): + # model doesn't implement streaming, so use default implementation + yield cast( + "AIMessageChunkV1", + self.invoke(input, config=config, **kwargs), + ) + else: + config = ensure_config(config) + messages = self._convert_input(input) + ls_structured_output_format = kwargs.pop( + "ls_structured_output_format", None + ) or kwargs.pop("structured_output_format", None) + ls_structured_output_format_dict = _format_ls_structured_output( + ls_structured_output_format + ) + + params = self._get_invocation_params(**kwargs) + options = {**kwargs, **ls_structured_output_format_dict} + inheritable_metadata = { + **(config.get("metadata") or {}), + **self._get_ls_params(**kwargs), + } + callback_manager = CallbackManager.configure( + config.get("callbacks"), + self.callbacks, + self.verbose, + config.get("tags"), + self.tags, + inheritable_metadata, + self.metadata, + ) + (run_manager,) = callback_manager.on_chat_model_start( + {}, + [_format_for_tracing(messages)], + invocation_params=params, + options=options, + name=config.get("run_name"), + run_id=config.pop("run_id", None), + batch_size=1, + ) + + chunks: list[AIMessageChunkV1] = [] + + if self.rate_limiter: + self.rate_limiter.acquire(blocking=True) + + try: + # TODO: replace this with something for new messages + input_messages = _normalize_messages(messages) + for msg in self._stream(input_messages, **kwargs): + run_manager.on_llm_new_token(msg) + chunks.append(msg) + yield msg + except BaseException as e: + run_manager.on_llm_error(e, response=_generate_response_from_error(e)) + raise + + msg = add_ai_message_chunks(chunks[0], *chunks[1:]) + run_manager.on_llm_end(msg) + + @override + async def astream( + self, + input: LanguageModelInput, + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> AsyncIterator[AIMessageChunkV1]: + if not self._should_stream(async_api=True, **{**kwargs, "stream": True}): + # No async or sync stream is implemented, so fall back to ainvoke + yield cast( + "AIMessageChunkV1", + await self.ainvoke(input, config=config, **kwargs), + ) + return + + config = ensure_config(config) + messages = self._convert_input(input) + + ls_structured_output_format = kwargs.pop( + "ls_structured_output_format", None + ) or kwargs.pop("structured_output_format", None) + ls_structured_output_format_dict = _format_ls_structured_output( + ls_structured_output_format + ) + + params = self._get_invocation_params(**kwargs) + options = {**kwargs, **ls_structured_output_format_dict} + inheritable_metadata = { + **(config.get("metadata") or {}), + **self._get_ls_params(**kwargs), + } + callback_manager = AsyncCallbackManager.configure( + config.get("callbacks"), + self.callbacks, + self.verbose, + config.get("tags"), + self.tags, + inheritable_metadata, + self.metadata, + ) + (run_manager,) = await callback_manager.on_chat_model_start( + {}, + [_format_for_tracing(messages)], + invocation_params=params, + options=options, + name=config.get("run_name"), + run_id=config.pop("run_id", None), + 
batch_size=1, + ) + + if self.rate_limiter: + await self.rate_limiter.aacquire(blocking=True) + + chunks: list[AIMessageChunkV1] = [] + + try: + input_messages = _normalize_messages(messages) + async for msg in self._astream( + input_messages, + **kwargs, + ): + await run_manager.on_llm_new_token(msg) + chunks.append(msg) + yield msg + except BaseException as e: + await run_manager.on_llm_error(e, response=_generate_response_from_error(e)) + raise + + msg = add_ai_message_chunks(chunks[0], *chunks[1:]) + await run_manager.on_llm_end(msg) + + # --- Custom methods --- + + def _combine_llm_outputs(self, llm_outputs: list[Optional[dict]]) -> dict: # noqa: ARG002 + return {} + + def _get_invocation_params( + self, + stop: Optional[list[str]] = None, + **kwargs: Any, + ) -> dict: + params = self.dict() + params["stop"] = stop + return {**params, **kwargs} + + def _get_ls_params( + self, + stop: Optional[list[str]] = None, + **kwargs: Any, + ) -> LangSmithParams: + """Get standard params for tracing.""" + # get default provider from class name + default_provider = self.__class__.__name__ + if default_provider.startswith("Chat"): + default_provider = default_provider[4:].lower() + elif default_provider.endswith("Chat"): + default_provider = default_provider[:-4] + default_provider = default_provider.lower() + + ls_params = LangSmithParams(ls_provider=default_provider, ls_model_type="chat") + if stop: + ls_params["ls_stop"] = stop + + # model + if hasattr(self, "model") and isinstance(self.model, str): + ls_params["ls_model_name"] = self.model + elif hasattr(self, "model_name") and isinstance(self.model_name, str): + ls_params["ls_model_name"] = self.model_name + + # temperature + if "temperature" in kwargs and isinstance(kwargs["temperature"], float): + ls_params["ls_temperature"] = kwargs["temperature"] + elif hasattr(self, "temperature") and isinstance(self.temperature, float): + ls_params["ls_temperature"] = self.temperature + + # max_tokens + if "max_tokens" in kwargs and isinstance(kwargs["max_tokens"], int): + ls_params["ls_max_tokens"] = kwargs["max_tokens"] + elif hasattr(self, "max_tokens") and isinstance(self.max_tokens, int): + ls_params["ls_max_tokens"] = self.max_tokens + + return ls_params + + def _get_llm_string(self, stop: Optional[list[str]] = None, **kwargs: Any) -> str: + params = self._get_invocation_params(stop=stop, **kwargs) + params = {**params, **kwargs} + return str(sorted(params.items())) + + def _invoke( + self, + messages: list[MessageV1], + **kwargs: Any, + ) -> AIMessage: + raise NotImplementedError + + async def _ainvoke( + self, + messages: list[MessageV1], + **kwargs: Any, + ) -> AIMessage: + return await run_in_executor( + None, + self._invoke, + messages, + **kwargs, + ) + + def _stream( + self, + messages: list[MessageV1], + **kwargs: Any, + ) -> Iterator[AIMessageChunkV1]: + raise NotImplementedError + + async def _astream( + self, + messages: list[MessageV1], + **kwargs: Any, + ) -> AsyncIterator[AIMessageChunkV1]: + iterator = await run_in_executor( + None, + self._stream, + messages, + **kwargs, + ) + done = object() + while True: + item = await run_in_executor( + None, + next, + iterator, + done, + ) + if item is done: + break + yield item # type: ignore[misc] + + @property + @abstractmethod + def _llm_type(self) -> str: + """Return type of chat model.""" + + @override + def dict(self, **kwargs: Any) -> dict: + """Return a dictionary of the LLM.""" + starter_dict = dict(self._identifying_params) + starter_dict["_type"] = self._llm_type + return 
starter_dict + + def bind_tools( + self, + tools: Sequence[ + Union[typing.Dict[str, Any], type, Callable, BaseTool] # noqa: UP006 + ], + *, + tool_choice: Optional[Union[str]] = None, + **kwargs: Any, + ) -> Runnable[LanguageModelInput, BaseMessage]: + """Bind tools to the model. + + Args: + tools: Sequence of tools to bind to the model. + tool_choice: The tool to use. If "any" then any tool can be used. + + Returns: + A Runnable that returns a message. + """ + raise NotImplementedError + + def with_structured_output( + self, + schema: Union[typing.Dict, type], # noqa: UP006 + *, + include_raw: bool = False, + **kwargs: Any, + ) -> Runnable[LanguageModelInput, Union[typing.Dict, BaseModel]]: # noqa: UP006 + """Model wrapper that returns outputs formatted to match the given schema. + + Args: + schema: + The output schema. Can be passed in as: + - an OpenAI function/tool schema, + - a JSON Schema, + - a TypedDict class, + - or a Pydantic class. + If ``schema`` is a Pydantic class then the model output will be a + Pydantic instance of that class, and the model-generated fields will be + validated by the Pydantic class. Otherwise the model output will be a + dict and will not be validated. See :meth:`langchain_core.utils.function_calling.convert_to_openai_tool` + for more on how to properly specify types and descriptions of + schema fields when specifying a Pydantic or TypedDict class. + + include_raw: + If False then only the parsed structured output is returned. If + an error occurs during model output parsing it will be raised. If True + then both the raw model response (a BaseMessage) and the parsed model + response will be returned. If an error occurs during output parsing it + will be caught and returned as well. The final output is always a dict + with keys "raw", "parsed", and "parsing_error". + + Returns: + A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`. + + If ``include_raw`` is False and ``schema`` is a Pydantic class, Runnable outputs + an instance of ``schema`` (i.e., a Pydantic object). + + Otherwise, if ``include_raw`` is False then Runnable outputs a dict. + + If ``include_raw`` is True, then Runnable outputs a dict with keys: + - ``"raw"``: BaseMessage + - ``"parsed"``: None if there was a parsing error, otherwise the type depends on the ``schema`` as described above. + - ``"parsing_error"``: Optional[BaseException] + + Example: Pydantic schema (include_raw=False): + .. code-block:: python + + from pydantic import BaseModel + + class AnswerWithJustification(BaseModel): + '''An answer to the user question along with justification for the answer.''' + answer: str + justification: str + + llm = ChatModel(model="model-name", temperature=0) + structured_llm = llm.with_structured_output(AnswerWithJustification) + + structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers") + + # -> AnswerWithJustification( + # answer='They weigh the same', + # justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.' + # ) + + Example: Pydantic schema (include_raw=True): + .. 
code-block:: python + + from pydantic import BaseModel + + class AnswerWithJustification(BaseModel): + '''An answer to the user question along with justification for the answer.''' + answer: str + justification: str + + llm = ChatModel(model="model-name", temperature=0) + structured_llm = llm.with_structured_output(AnswerWithJustification, include_raw=True) + + structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers") + # -> { + # 'raw': AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_Ao02pnFYXD6GN1yzc0uXPsvF', 'function': {'arguments': '{"answer":"They weigh the same.","justification":"Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ."}', 'name': 'AnswerWithJustification'}, 'type': 'function'}]}), + # 'parsed': AnswerWithJustification(answer='They weigh the same.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'), + # 'parsing_error': None + # } + + Example: Dict schema (include_raw=False): + .. code-block:: python + + from pydantic import BaseModel + from langchain_core.utils.function_calling import convert_to_openai_tool + + class AnswerWithJustification(BaseModel): + '''An answer to the user question along with justification for the answer.''' + answer: str + justification: str + + dict_schema = convert_to_openai_tool(AnswerWithJustification) + llm = ChatModel(model="model-name", temperature=0) + structured_llm = llm.with_structured_output(dict_schema) + + structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers") + # -> { + # 'answer': 'They weigh the same', + # 'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume and density of the two substances differ.' + # } + + .. versionchanged:: 0.2.26 + + Added support for TypedDict class. + """ # noqa: E501 + _ = kwargs.pop("method", None) + _ = kwargs.pop("strict", None) + if kwargs: + msg = f"Received unsupported arguments {kwargs}" + raise ValueError(msg) + + from langchain_core.output_parsers.openai_tools import ( + JsonOutputKeyToolsParser, + PydanticToolsParser, + ) + + if type(self).bind_tools is BaseChatModelV1.bind_tools: + msg = "with_structured_output is not implemented for this model." 
+            raise NotImplementedError(msg)
+
+        llm = self.bind_tools(
+            [schema],
+            tool_choice="any",
+            ls_structured_output_format={
+                "kwargs": {"method": "function_calling"},
+                "schema": schema,
+            },
+        )
+        if isinstance(schema, type) and is_basemodel_subclass(schema):
+            output_parser: OutputParserLike = PydanticToolsParser(
+                tools=[cast("TypeBaseModel", schema)], first_tool_only=True
+            )
+        else:
+            key_name = convert_to_openai_tool(schema)["function"]["name"]
+            output_parser = JsonOutputKeyToolsParser(
+                key_name=key_name, first_tool_only=True
+            )
+        if include_raw:
+            parser_assign = RunnablePassthrough.assign(
+                parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
+            )
+            parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
+            parser_with_fallback = parser_assign.with_fallbacks(
+                [parser_none], exception_key="parsing_error"
+            )
+            return RunnableMap(raw=llm) | parser_with_fallback
+        return llm | output_parser
+
+
+def _gen_info_and_msg_metadata(
+    generation: Union[ChatGeneration, ChatGenerationChunk],
+) -> dict:
+    return {
+        **(generation.generation_info or {}),
+        **generation.message.response_metadata,
+    }
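The parser wiring `with_structured_output` assembles above can be exercised on its own. A minimal sketch (not part of the patch; the tool-call payload is fabricated):

```python
from langchain_core.messages import AIMessage
from langchain_core.output_parsers.openai_tools import JsonOutputKeyToolsParser

# Stand-in for a model response produced via bind_tools: empty content plus a
# tool call whose arguments carry the structured payload.
message = AIMessage(
    content="",
    tool_calls=[
        {
            "name": "AnswerWithJustification",
            "args": {"answer": "They weigh the same", "justification": "..."},
            "id": "call_123",
            "type": "tool_call",
        }
    ],
)
parser = JsonOutputKeyToolsParser(
    key_name="AnswerWithJustification", first_tool_only=True
)
parser.invoke(message)
# -> {'answer': 'They weigh the same', 'justification': '...'}
```

With ``include_raw=True`` the same parser runs inside a fallback chain, so parsing errors are surfaced under the ``parsing_error`` key instead of being raised.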
diff --git a/libs/core/langchain_core/messages/utils.py b/libs/core/langchain_core/messages/utils.py
index 24fdeacfaa9..8ed6407d4c1 100644
--- a/libs/core/langchain_core/messages/utils.py
+++ b/libs/core/langchain_core/messages/utils.py
@@ -299,6 +299,81 @@ def _create_message_from_message_type(
     return message
 
 
+def _create_message_from_message_type_v1(
+    message_type: str,
+    content: str,
+    name: Optional[str] = None,
+    tool_call_id: Optional[str] = None,
+    tool_calls: Optional[list[dict[str, Any]]] = None,
+    id: Optional[str] = None,
+    **additional_kwargs: Any,
+) -> MessageV1:
+    """Create a message from a message type and content string.
+
+    Args:
+        message_type: (str) the type of the message (e.g., "human", "ai", etc.).
+        content: (str) the content string.
+        name: (str) the name of the message. Default is None.
+        tool_call_id: (str) the tool call id. Default is None.
+        tool_calls: (list[dict[str, Any]]) the tool calls. Default is None.
+        id: (str) the id of the message. Default is None.
+        additional_kwargs: (dict[str, Any]) additional keyword arguments.
+
+    Returns:
+        a message of the appropriate type.
+
+    Raises:
+        ValueError: if the message type is not one of "human", "user", "ai",
+            "assistant", "tool", "system", or "developer".
+    """
+    kwargs: dict[str, Any] = {}
+    if name is not None:
+        kwargs["name"] = name
+    if tool_call_id is not None:
+        kwargs["tool_call_id"] = tool_call_id
+    if additional_kwargs and (response_metadata := additional_kwargs.pop("response_metadata", None)):
+        kwargs["response_metadata"] = response_metadata
+    if id is not None:
+        kwargs["id"] = id
+    if tool_calls is not None:
+        kwargs["tool_calls"] = []
+        for tool_call in tool_calls:
+            # Convert OpenAI-format tool call to LangChain format.
+            if "function" in tool_call:
+                args = tool_call["function"]["arguments"]
+                if isinstance(args, str):
+                    args = json.loads(args, strict=False)
+                kwargs["tool_calls"].append(
+                    {
+                        "name": tool_call["function"]["name"],
+                        "args": args,
+                        "id": tool_call["id"],
+                        "type": "tool_call",
+                    }
+                )
+            else:
+                kwargs["tool_calls"].append(tool_call)
+    if message_type in {"human", "user"}:
+        message = HumanMessageV1(content=content, **kwargs)
+    elif message_type in {"ai", "assistant"}:
+        message = AIMessageV1(content=content, **kwargs)
+    elif message_type in {"system", "developer"}:
+        if message_type == "developer":
+            kwargs["custom_role"] = "developer"
+        message = SystemMessageV1(content=content, **kwargs)
+    elif message_type == "tool":
+        artifact = kwargs.pop("artifact", None)
+        message = ToolMessageV1(content=content, artifact=artifact, **kwargs)
+    else:
+        msg = (
+            f"Unexpected message type: '{message_type}'. Use one of 'human',"
+            " 'user', 'ai', 'assistant', 'tool', 'system', or 'developer'."
+        )
+        msg = create_message(message=msg, error_code=ErrorCode.MESSAGE_COERCION_FAILURE)
+        raise ValueError(msg)
+    return message
+
+
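To illustrate the coercion paths the helper above handles, a sketch using `convert_to_messages_v1` (added further down in this diff); the OpenAI-style tool call is invented:

```python
from langchain_core.messages.utils import convert_to_messages_v1

messages = convert_to_messages_v1(
    [
        ("system", "You are terse."),  # (role, content) tuple
        "What is 2 + 2?",  # bare string becomes a human message
        {
            "role": "assistant",
            "content": "",
            # OpenAI-format tool call, converted to LangChain format above
            "tool_calls": [
                {
                    "id": "call_1",
                    "function": {"name": "add", "arguments": '{"a": 2, "b": 2}'},
                }
            ],
        },
    ]
)
```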
 def _convert_from_v1_message(message: MessageV1) -> BaseMessage:
     """Compatibility layer to convert v1 messages to current messages.
 
@@ -403,6 +478,61 @@ def _convert_to_message(message: MessageLikeRepresentation) -> BaseMessage:
     return message_
 
 
+def _convert_to_message_v1(message: MessageLikeRepresentation) -> MessageV1:
+    """Instantiate a message from a variety of message formats.
+
+    The message format can be one of the following:
+
+    - BaseMessagePromptTemplate
+    - BaseMessage
+    - 2-tuple of (role string, template); e.g., ("human", "{user_input}")
+    - dict: a message dict with role and content keys
+    - string: shorthand for ("human", template); e.g., "{user_input}"
+
+    Args:
+        message: a representation of a message in one of the supported formats.
+
+    Returns:
+        an instance of a message or a message template.
+
+    Raises:
+        NotImplementedError: if the message type is not supported.
+        ValueError: if the message dict does not contain the required keys.
+    """
+    if isinstance(message, MessageV1Types):
+        message_ = message
+    elif isinstance(message, str):
+        message_ = _create_message_from_message_type_v1("human", message)
+    elif isinstance(message, Sequence) and len(message) == 2:
+        # mypy doesn't realise this can't be a string given the previous branch
+        message_type_str, template = message  # type: ignore[misc]
+        message_ = _create_message_from_message_type_v1(message_type_str, template)
+    elif isinstance(message, dict):
+        msg_kwargs = message.copy()
+        try:
+            try:
+                msg_type = msg_kwargs.pop("role")
+            except KeyError:
+                msg_type = msg_kwargs.pop("type")
+            # None msg content is not allowed
+            msg_content = msg_kwargs.pop("content") or ""
+        except KeyError as e:
+            msg = f"Message dict must contain 'role' and 'content' keys, got {message}"
+            msg = create_message(
+                message=msg, error_code=ErrorCode.MESSAGE_COERCION_FAILURE
+            )
+            raise ValueError(msg) from e
+        message_ = _create_message_from_message_type_v1(
+            msg_type, msg_content, **msg_kwargs
+        )
+    else:
+        msg = f"Unsupported message type: {type(message)}"
+        msg = create_message(message=msg, error_code=ErrorCode.MESSAGE_COERCION_FAILURE)
+        raise NotImplementedError(msg)
+
+    return message_
+
+
 def convert_to_messages(
     messages: Union[Iterable[MessageLikeRepresentation], PromptValue],
 ) -> list[BaseMessage]:
@@ -422,6 +552,25 @@ def convert_to_messages(
     return [_convert_to_message(m) for m in messages]
 
 
+def convert_to_messages_v1(
+    messages: Union[Iterable[MessageLikeRepresentation], PromptValue],
+) -> list[MessageV1]:
+    """Convert a sequence of messages to a list of v1 messages.
+
+    Args:
+        messages: Sequence of messages to convert.
+
+    Returns:
+        list of messages (MessageV1).
+    """
+    # Import here to avoid circular imports
+    from langchain_core.prompt_values import PromptValue
+
+    if isinstance(messages, PromptValue):
+        return messages.to_messages(output_version="v1")
+    return [_convert_to_message_v1(m) for m in messages]
+
+
 def _runnable_support(func: Callable) -> Callable:
     @overload
     def wrapped(
diff --git a/libs/core/langchain_core/messages/v1.py b/libs/core/langchain_core/messages/v1.py
index b784cbcfe2f..9ff2eaed4ab 100644
--- a/libs/core/langchain_core/messages/v1.py
+++ b/libs/core/langchain_core/messages/v1.py
@@ -366,6 +366,8 @@ def add_ai_message_chunks(
     left: AIMessageChunk, *others: AIMessageChunk
 ) -> AIMessageChunk:
     """Add multiple AIMessageChunks together."""
+    if not others:
+        return left
     content = merge_content(
         cast("list[str | dict[Any, Any]]", left.content),
         *(cast("list[str | dict[Any, Any]]", o.content) for o in others),
diff --git a/libs/core/tests/unit_tests/fake/test_fake_chat_model.py b/libs/core/tests/unit_tests/fake/test_fake_chat_model.py
index ce262797535..4d17f3b8cae 100644
--- a/libs/core/tests/unit_tests/fake/test_fake_chat_model.py
+++ b/libs/core/tests/unit_tests/fake/test_fake_chat_model.py
@@ -15,6 +15,8 @@ from langchain_core.language_models import (
     ParrotFakeChatModel,
 )
 from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage, HumanMessage
+from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1
+from langchain_core.messages.v1 import MessageV1
 from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
 from tests.unit_tests.stubs import (
     _any_id_ai_message,
@@ -157,13 +159,13 @@ async def test_callback_handlers() -> None:
     """Verify that model is implemented correctly with handlers working."""
 
     class MyCustomAsyncHandler(AsyncCallbackHandler):
-        def __init__(self, store: list[str]) -> None:
+        def
__init__(self, store: list[Union[str, AIMessageChunkV1]]) -> None: self.store = store async def on_chat_model_start( self, serialized: dict[str, Any], - messages: list[list[BaseMessage]], + messages: Union[list[list[BaseMessage]], list[list[MessageV1]]], *, run_id: UUID, parent_run_id: Optional[UUID] = None, @@ -178,9 +180,11 @@ async def test_callback_handlers() -> None: @override async def on_llm_new_token( self, - token: str, + token: Union[str, AIMessageChunkV1], *, - chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, + chunk: Optional[ + Union[GenerationChunk, ChatGenerationChunk, AIMessageChunkV1] + ] = None, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[list[str]] = None, @@ -194,7 +198,7 @@ async def test_callback_handlers() -> None: ] ) model = GenericFakeChatModel(messages=infinite_cycle) - tokens: list[str] = [] + tokens: list[Union[str, AIMessageChunkV1]] = [] # New model results = [ chunk
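Relatedly, the early-return guard added to `add_ai_message_chunks` in messages/v1.py above can be exercised directly. A sketch (assuming the chunk constructor accepts plain-string content, as the other v1 message constructors in this patch do):

```python
from langchain_core.messages.v1 import AIMessageChunk, add_ai_message_chunks

chunks = [AIMessageChunk(content="Hello, "), AIMessageChunk(content="world!")]
merged = add_ai_message_chunks(chunks[0], *chunks[1:])  # contents are merged
single = add_ai_message_chunks(chunks[0])  # new guard: returns `left` unchanged
```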