diff --git a/libs/core/langchain_core/callbacks/base.py b/libs/core/langchain_core/callbacks/base.py
index 5365fcb9ef1..f9c6f75f307 100644
--- a/libs/core/langchain_core/callbacks/base.py
+++ b/libs/core/langchain_core/callbacks/base.py
@@ -7,6 +7,8 @@ from typing import TYPE_CHECKING, Any, Optional, Union
from typing_extensions import Self
+from langchain_core.messages.v1 import AIMessage, AIMessageChunk, MessageV1
+
if TYPE_CHECKING:
from collections.abc import Sequence
from uuid import UUID
@@ -64,9 +66,11 @@ class LLMManagerMixin:
def on_llm_new_token(
self,
- token: str,
+ token: Union[str, AIMessageChunk],
*,
- chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
+ chunk: Optional[
+ Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
+ ] = None,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
@@ -75,8 +79,8 @@ class LLMManagerMixin:
Args:
token (str): The new token.
- chunk (GenerationChunk | ChatGenerationChunk): The new generated chunk,
- containing content and other information.
+ chunk (GenerationChunk | ChatGenerationChunk | AIMessageChunk): The new
+ generated chunk, containing content and other information.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
kwargs (Any): Additional keyword arguments.
@@ -84,7 +88,7 @@ class LLMManagerMixin:
def on_llm_end(
self,
- response: LLMResult,
+ response: Union[LLMResult, AIMessage],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
@@ -93,7 +97,7 @@ class LLMManagerMixin:
"""Run when LLM ends running.
Args:
- response (LLMResult): The response which was generated.
+ response (LLMResult | AIMessage): The response which was generated.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
kwargs (Any): Additional keyword arguments.
@@ -261,7 +265,7 @@ class CallbackManagerMixin:
def on_chat_model_start(
self,
serialized: dict[str, Any],
- messages: list[list[BaseMessage]],
+ messages: Union[list[list[BaseMessage]], list[list[MessageV1]]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
@@ -439,6 +443,9 @@ class BaseCallbackHandler(
run_inline: bool = False
"""Whether to run the callback inline."""
+ accepts_new_messages: bool = False
+ """Whether the callback accepts new message format."""
+
@property
def ignore_llm(self) -> bool:
"""Whether to ignore LLM callbacks."""
@@ -509,7 +516,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
async def on_chat_model_start(
self,
serialized: dict[str, Any],
- messages: list[list[BaseMessage]],
+ messages: Union[list[list[BaseMessage]], list[list[MessageV1]]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
@@ -538,9 +545,11 @@ class AsyncCallbackHandler(BaseCallbackHandler):
async def on_llm_new_token(
self,
- token: str,
+ token: Union[str, AIMessageChunk],
*,
- chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
+ chunk: Optional[
+ Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
+ ] = None,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[list[str]] = None,
@@ -550,8 +559,8 @@ class AsyncCallbackHandler(BaseCallbackHandler):
Args:
token (str): The new token.
- chunk (GenerationChunk | ChatGenerationChunk): The new generated chunk,
- containing content and other information.
+ chunk (GenerationChunk | ChatGenerationChunk | AIMessageChunk): The new
+ generated chunk, containing content and other information.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
tags (Optional[list[str]]): The tags.
@@ -560,7 +569,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
async def on_llm_end(
self,
- response: LLMResult,
+ response: Union[LLMResult, AIMessage],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
@@ -570,7 +579,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
"""Run when LLM ends running.
Args:
- response (LLMResult): The response which was generated.
+ response (LLMResult | AIMessage): The response which was generated.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
tags (Optional[list[str]]): The tags.
@@ -594,8 +603,8 @@ class AsyncCallbackHandler(BaseCallbackHandler):
parent_run_id: The parent run ID. This is the ID of the parent run.
tags: The tags.
kwargs (Any): Additional keyword arguments.
- - response (LLMResult): The response which was generated before
- the error occurred.
+ - response (LLMResult | AIMessage): The response which was generated
+ before the error occurred.
"""
async def on_chain_start(
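Handlers that want the new v1 message objects opt in through the new ``accepts_new_messages`` flag; handlers that leave it at the default ``False`` keep receiving legacy payloads via the conversion shim added in ``manager.py`` below. A minimal sketch of an opted-in handler (the class name is made up; it assumes the v1 ``AIMessageChunk`` exposes ``.text``, as used elsewhere in this diff):

    from typing import Any, Optional, Union
    from uuid import UUID

    from langchain_core.callbacks.base import BaseCallbackHandler
    from langchain_core.messages.v1 import AIMessage, AIMessageChunk
    from langchain_core.outputs import LLMResult


    class V1PrintingHandler(BaseCallbackHandler):
        """Hypothetical handler that opts into the v1 message format."""

        accepts_new_messages = True  # skip the legacy-format conversion in handle_event

        def on_llm_new_token(
            self,
            token: Union[str, AIMessageChunk],
            *,
            chunk: Optional[Any] = None,
            run_id: UUID,
            parent_run_id: Optional[UUID] = None,
            **kwargs: Any,
        ) -> None:
            # With accepts_new_messages=True the manager passes the chunk through as-is.
            text = token.text if isinstance(token, AIMessageChunk) else token
            print(text, end="", flush=True)

        def on_llm_end(
            self,
            response: Union[LLMResult, AIMessage],
            *,
            run_id: UUID,
            parent_run_id: Optional[UUID] = None,
            **kwargs: Any,
        ) -> None:
            print()  # newline once the run finishes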
diff --git a/libs/core/langchain_core/callbacks/manager.py b/libs/core/langchain_core/callbacks/manager.py
index 56fc1bb67ba..bca2d800696 100644
--- a/libs/core/langchain_core/callbacks/manager.py
+++ b/libs/core/langchain_core/callbacks/manager.py
@@ -11,6 +11,7 @@ from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager, contextmanager
from contextvars import copy_context
+from dataclasses import is_dataclass
from typing import (
TYPE_CHECKING,
Any,
@@ -37,6 +38,8 @@ from langchain_core.callbacks.base import (
)
from langchain_core.callbacks.stdout import StdOutCallbackHandler
from langchain_core.messages import BaseMessage, get_buffer_string
+from langchain_core.messages.v1 import AIMessage, AIMessageChunk
+from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, LLMResult
from langchain_core.tracers.schemas import Run
from langchain_core.utils.env import env_var_is_set
@@ -47,7 +50,7 @@ if TYPE_CHECKING:
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.documents import Document
- from langchain_core.outputs import ChatGenerationChunk, GenerationChunk, LLMResult
+ from langchain_core.outputs import GenerationChunk
from langchain_core.runnables.config import RunnableConfig
logger = logging.getLogger(__name__)
@@ -243,6 +246,22 @@ def shielded(func: Func) -> Func:
return cast("Func", wrapped)
+def _convert_llm_events(
+    event_name: str, args: tuple[Any, ...], kwargs: dict[str, Any]
+) -> tuple[tuple[Any, ...], dict[str, Any]]:
+    """Convert v1 message payloads into legacy formats for older handlers."""
+    converted = list(args)
+    if event_name == "on_chat_model_start" and isinstance(args[1], list):
+        for idx, item in enumerate(args[1]):
+            if is_dataclass(item):
+                # TODO: convert v1 message to the legacy BaseMessage format.
+                args[1][idx] = item
+    elif event_name == "on_llm_new_token" and is_dataclass(args[0]):
+        kwargs["chunk"] = ChatGenerationChunk(text=args[0].text, message=args[0])
+        converted[0] = args[0].text
+    elif event_name == "on_llm_end" and is_dataclass(args[0]):
+        converted[0] = LLMResult(
+            generations=[[ChatGeneration(text=args[0].text, message=args[0])]]
+        )
+    return tuple(converted), kwargs
+
+
def handle_event(
handlers: list[BaseCallbackHandler],
event_name: str,
@@ -271,6 +290,8 @@ def handle_event(
if ignore_condition_name is None or not getattr(
handler, ignore_condition_name
):
+ if not handler.accepts_new_messages:
+                args, kwargs = _convert_llm_events(event_name, args, kwargs)
event = getattr(handler, event_name)(*args, **kwargs)
if asyncio.iscoroutine(event):
coros.append(event)
@@ -365,6 +386,8 @@ async def _ahandle_event_for_handler(
) -> None:
try:
if ignore_condition_name is None or not getattr(handler, ignore_condition_name):
+ if not handler.accepts_new_messages:
+                args, kwargs = _convert_llm_events(event_name, args, kwargs)
event = getattr(handler, event_name)
if asyncio.iscoroutinefunction(event):
await event(*args, **kwargs)
@@ -672,9 +695,11 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
def on_llm_new_token(
self,
- token: str,
+ token: Union[str, AIMessageChunk],
*,
- chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
+ chunk: Optional[
+ Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
+ ] = None,
**kwargs: Any,
) -> None:
"""Run when LLM generates a new token.
@@ -699,11 +724,11 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
**kwargs,
)
- def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+ def on_llm_end(self, response: Union[LLMResult, AIMessage], **kwargs: Any) -> None:
"""Run when LLM ends running.
Args:
- response (LLMResult): The LLM result.
+ response (LLMResult | AIMessage): The LLM result.
**kwargs (Any): Additional keyword arguments.
"""
if not self.handlers:
@@ -729,8 +754,8 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
Args:
error (Exception or KeyboardInterrupt): The error.
kwargs (Any): Additional keyword arguments.
- - response (LLMResult): The response which was generated before
- the error occurred.
+ - response (LLMResult | AIMessage): The response which was generated
+ before the error occurred.
"""
if not self.handlers:
return
@@ -768,9 +793,11 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
async def on_llm_new_token(
self,
- token: str,
+ token: Union[str, AIMessageChunk],
*,
- chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
+ chunk: Optional[
+ Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
+ ] = None,
**kwargs: Any,
) -> None:
"""Run when LLM generates a new token.
@@ -796,11 +823,13 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
)
@shielded
- async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+ async def on_llm_end(
+ self, response: Union[LLMResult, AIMessage], **kwargs: Any
+ ) -> None:
"""Run when LLM ends running.
Args:
- response (LLMResult): The LLM result.
+ response (LLMResult | AIMessage): The LLM result.
**kwargs (Any): Additional keyword arguments.
"""
if not self.handlers:
@@ -827,11 +856,8 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
Args:
error (Exception or KeyboardInterrupt): The error.
kwargs (Any): Additional keyword arguments.
- - response (LLMResult): The response which was generated before
- the error occurred.
-
-
-
+ - response (LLMResult | AIMessage): The response which was generated
+ before the error occurred.
"""
if not self.handlers:
return
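Handlers that do not set ``accepts_new_messages`` continue to work unchanged: ``_convert_llm_events`` rewrites the v1 payloads before dispatch, so ``on_llm_new_token`` still receives a ``str`` token (plus a ``ChatGenerationChunk`` in ``chunk``) and ``on_llm_end`` still receives an ``LLMResult``. A sketch of such a legacy handler (class name is illustrative):

    from typing import Any, Optional, Union
    from uuid import UUID

    from langchain_core.callbacks.base import BaseCallbackHandler
    from langchain_core.outputs import ChatGenerationChunk, GenerationChunk, LLMResult


    class LegacyTokenCounter(BaseCallbackHandler):
        """Hypothetical pre-v1 handler; accepts_new_messages stays False."""

        def __init__(self) -> None:
            self.tokens = 0

        def on_llm_new_token(
            self,
            token: str,  # the shim passes chunk.text here, so this stays a str
            *,
            chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
            run_id: UUID,
            parent_run_id: Optional[UUID] = None,
            **kwargs: Any,
        ) -> None:
            self.tokens += 1

        def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
            # The shim wraps a v1 AIMessage as LLMResult(generations=[[ChatGeneration(...)]]).
            print(f"Run finished after {self.tokens} streamed tokens")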
diff --git a/libs/core/langchain_core/language_models/_utils.py b/libs/core/langchain_core/language_models/_utils.py
index 883f8c855ea..cc5d39a18c3 100644
--- a/libs/core/langchain_core/language_models/_utils.py
+++ b/libs/core/langchain_core/language_models/_utils.py
@@ -1,3 +1,4 @@
+import copy
import re
from collections.abc import Sequence
from typing import Optional
@@ -128,7 +129,10 @@ def _normalize_messages(messages: Sequence[BaseMessage]) -> list[BaseMessage]:
and _is_openai_data_block(block)
):
if formatted_message is message:
- formatted_message = message.model_copy()
+ if isinstance(message, BaseMessage):
+ formatted_message = message.model_copy()
+ else:
+ formatted_message = copy.copy(message)
# Also shallow-copy content
formatted_message.content = list(formatted_message.content)
diff --git a/libs/core/langchain_core/language_models/base.py b/libs/core/langchain_core/language_models/base.py
index a9e7e4e64cc..1fef01a6859 100644
--- a/libs/core/langchain_core/language_models/base.py
+++ b/libs/core/langchain_core/language_models/base.py
@@ -28,6 +28,7 @@ from langchain_core.messages import (
MessageLikeRepresentation,
get_buffer_string,
)
+from langchain_core.messages.v1 import AIMessage as AIMessageV1
from langchain_core.prompt_values import PromptValue
from langchain_core.runnables import Runnable, RunnableSerializable
from langchain_core.utils import get_pydantic_field_names
@@ -85,7 +86,9 @@ def _get_token_ids_default_method(text: str) -> list[int]:
LanguageModelInput = Union[PromptValue, str, Sequence[MessageLikeRepresentation]]
LanguageModelOutput = Union[BaseMessage, str]
LanguageModelLike = Runnable[LanguageModelInput, LanguageModelOutput]
-LanguageModelOutputVar = TypeVar("LanguageModelOutputVar", BaseMessage, str)
+LanguageModelOutputVar = TypeVar(
+ "LanguageModelOutputVar", BaseMessage, str, AIMessageV1
+)
def _get_verbosity() -> bool:
diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py
index d37ca232052..e96003d26fe 100644
--- a/libs/core/langchain_core/language_models/chat_models.py
+++ b/libs/core/langchain_core/language_models/chat_models.py
@@ -58,7 +58,9 @@ from langchain_core.messages import (
is_data_content_block,
message_chunk_to_message,
)
+from langchain_core.messages import content_blocks as types
from langchain_core.messages.ai import _LC_ID_PREFIX
+from langchain_core.messages.v1 import AIMessage as AIMessageV1
from langchain_core.outputs import (
ChatGeneration,
ChatGenerationChunk,
@@ -220,6 +222,23 @@ def _format_ls_structured_output(ls_structured_output_format: Optional[dict]) ->
return ls_structured_output_format_dict
+def _convert_to_v1(message: AIMessage) -> AIMessageV1:
+ """Best-effort conversion of a V0 AIMessage to V1."""
+    content: list[types.ContentBlock] = []
+    if isinstance(message.content, str) and message.content:
+        content = [{"type": "text", "text": message.content}]
+
+ for tool_call in message.tool_calls:
+ content.append(tool_call)
+
+ return AIMessageV1(
+ content=content,
+ usage_metadata=message.usage_metadata,
+ response_metadata=message.response_metadata,
+ )
+
+
class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
"""Base class for chat models.
@@ -328,6 +347,20 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
does not properly support streaming.
"""
+ output_version: str = "v0"
+ """Version of AIMessage output format to use.
+
+ This field is used to roll-out new output formats for chat model AIMessages
+ in a backwards-compatible way.
+
+ ``'v1'`` standardizes output format using a list of typed ContentBlock dicts. We
+ recommend this for new applications.
+
+ All chat models currently support the default of ``"v0"``.
+
+ .. versionadded:: 0.4
+ """
+
@model_validator(mode="before")
@classmethod
def raise_deprecation(cls, values: dict) -> Any:
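Setting ``output_version="v1"`` on a chat model opts its outputs into the new content-block format, and ``_convert_to_v1`` is the best-effort bridge from legacy messages. A rough sketch of the conversion it performs (``_convert_to_v1`` is a private helper; the values are illustrative):

    from langchain_core.language_models.chat_models import _convert_to_v1
    from langchain_core.messages import AIMessage

    # A legacy (v0) message: string content plus one tool call.
    legacy = AIMessage(
        content="Looking that up now.",
        tool_calls=[
            {
                "name": "search",
                "args": {"query": "weather"},
                "id": "call_1",
                "type": "tool_call",
            }
        ],
    )

    v1_message = _convert_to_v1(legacy)
    # v1_message.content is now a list of typed blocks, roughly:
    # [
    #     {"type": "text", "text": "Looking that up now."},
    #     {"type": "tool_call", "name": "search", "args": {...}, "id": "call_1"},
    # ]
    # usage_metadata and response_metadata are carried over unchanged.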
diff --git a/libs/core/langchain_core/language_models/v1/__init__.py b/libs/core/langchain_core/language_models/v1/__init__.py
new file mode 100644
index 00000000000..eaffaa213ba
--- /dev/null
+++ b/libs/core/langchain_core/language_models/v1/__init__.py
@@ -0,0 +1 @@
+"""LangChain v1.0 chat models."""
diff --git a/libs/core/langchain_core/language_models/v1/chat_models.py b/libs/core/langchain_core/language_models/v1/chat_models.py
new file mode 100644
index 00000000000..437fb876496
--- /dev/null
+++ b/libs/core/langchain_core/language_models/v1/chat_models.py
@@ -0,0 +1,909 @@
+"""Chat models for conversational AI."""
+
+from __future__ import annotations
+
+import copy
+import typing
+from abc import ABC, abstractmethod
+from collections.abc import AsyncIterator, Iterator, Sequence
+from operator import itemgetter
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Literal,
+ Optional,
+ Union,
+ cast,
+)
+
+from pydantic import (
+ BaseModel,
+ ConfigDict,
+ Field,
+)
+from typing_extensions import override
+
+from langchain_core.callbacks import (
+ AsyncCallbackManager,
+ AsyncCallbackManagerForLLMRun,
+ CallbackManager,
+ CallbackManagerForLLMRun,
+)
+from langchain_core.language_models._utils import _normalize_messages
+from langchain_core.language_models.base import (
+ BaseLanguageModel,
+ LangSmithParams,
+ LanguageModelInput,
+)
+from langchain_core.messages import (
+ AIMessage,
+ BaseMessage,
+ convert_to_openai_image_block,
+ is_data_content_block,
+)
+from langchain_core.messages.utils import convert_to_messages_v1
+from langchain_core.messages.v1 import AIMessage as AIMessageV1
+from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1
+from langchain_core.messages.v1 import HumanMessage as HumanMessageV1
+from langchain_core.messages.v1 import MessageV1, add_ai_message_chunks
+from langchain_core.outputs import (
+ ChatGeneration,
+ ChatGenerationChunk,
+)
+from langchain_core.prompt_values import PromptValue
+from langchain_core.rate_limiters import BaseRateLimiter
+from langchain_core.runnables import RunnableMap, RunnablePassthrough
+from langchain_core.runnables.config import ensure_config, run_in_executor
+from langchain_core.tracers._streaming import _StreamingCallbackHandler
+from langchain_core.utils.function_calling import (
+ convert_to_json_schema,
+ convert_to_openai_tool,
+)
+from langchain_core.utils.pydantic import TypeBaseModel, is_basemodel_subclass
+
+if TYPE_CHECKING:
+ from langchain_core.output_parsers.base import OutputParserLike
+ from langchain_core.runnables import Runnable, RunnableConfig
+ from langchain_core.tools import BaseTool
+
+
+def _generate_response_from_error(error: BaseException) -> list[AIMessageV1]:
+ if hasattr(error, "response"):
+ response = error.response
+ metadata: dict = {}
+ if hasattr(response, "headers"):
+ try:
+ metadata["headers"] = dict(response.headers)
+ except Exception:
+ metadata["headers"] = None
+ if hasattr(response, "status_code"):
+ metadata["status_code"] = response.status_code
+ if hasattr(error, "request_id"):
+ metadata["request_id"] = error.request_id
+ generations = [AIMessageV1(content=[], response_metadata=metadata)]
+ else:
+ generations = []
+
+ return generations
+
+
+def _format_for_tracing(messages: list[MessageV1]) -> list[MessageV1]:
+ """Format messages for tracing in on_chat_model_start.
+
+ - Update image content blocks to OpenAI Chat Completions format (backward
+ compatibility).
+ - Add "type" key to content blocks that have a single key.
+
+ Args:
+ messages: List of messages to format.
+
+ Returns:
+ List of messages formatted for tracing.
+ """
+ messages_to_trace = []
+ for message in messages:
+ message_to_trace = message
+ for idx, block in enumerate(message.content):
+            # Update image content blocks to OpenAI Chat Completions format.
+ if (
+ block["type"] == "image"
+ and is_data_content_block(block)
+ and block.get("source_type") != "id"
+ ):
+ if message_to_trace is message:
+ # Shallow copy
+ message_to_trace = copy.copy(message)
+ message_to_trace.content = list(message_to_trace.content)
+
+ message_to_trace.content[idx] = convert_to_openai_image_block(block)
+ messages_to_trace.append(message_to_trace)
+
+ return messages_to_trace
+
+
+def generate_from_stream(stream: Iterator[AIMessageChunkV1]) -> AIMessageV1:
+ """Generate from a stream.
+
+ Args:
+ stream: Iterator of AIMessageChunkV1.
+
+ Returns:
+ AIMessageV1: aggregated message.
+ """
+ generation = next(stream, None)
+ if generation:
+ generation += list(stream)
+ if generation is None:
+ msg = "No generations found in stream."
+ raise ValueError(msg)
+ return generation.to_message()
+
+
+async def agenerate_from_stream(
+ stream: AsyncIterator[AIMessageChunkV1],
+) -> AIMessageV1:
+ """Async generate from a stream.
+
+ Args:
+ stream: Iterator of AIMessageChunkV1.
+
+ Returns:
+ AIMessageV1: aggregated message.
+ """
+ chunks = [chunk async for chunk in stream]
+ return await run_in_executor(None, generate_from_stream, iter(chunks))
+
+
+def _format_ls_structured_output(ls_structured_output_format: Optional[dict]) -> dict:
+ if ls_structured_output_format:
+ try:
+ ls_structured_output_format_dict = {
+ "ls_structured_output_format": {
+ "kwargs": ls_structured_output_format.get("kwargs", {}),
+ "schema": convert_to_json_schema(
+ ls_structured_output_format["schema"]
+ ),
+ }
+ }
+ except ValueError:
+ ls_structured_output_format_dict = {}
+ else:
+ ls_structured_output_format_dict = {}
+
+ return ls_structured_output_format_dict
+
+
+class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC):
+ """Base class for chat models.
+
+ Key imperative methods:
+ Methods that actually call the underlying model.
+
+ +---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+
+ | Method | Input | Output | Description |
+ +===========================+================================================================+=====================================================================+==================================================================================================+
+ | `invoke` | str | list[dict | tuple | BaseMessage] | PromptValue | BaseMessage | A single chat model call. |
+ +---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+
+ | `ainvoke` | ''' | BaseMessage | Defaults to running invoke in an async executor. |
+ +---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+
+ | `stream` | ''' | Iterator[BaseMessageChunk] | Defaults to yielding output of invoke. |
+ +---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+
+ | `astream` | ''' | AsyncIterator[BaseMessageChunk] | Defaults to yielding output of ainvoke. |
+ +---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+
+ | `astream_events` | ''' | AsyncIterator[StreamEvent] | Event types: 'on_chat_model_start', 'on_chat_model_stream', 'on_chat_model_end'. |
+ +---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+
+ | `batch` | list['''] | list[BaseMessage] | Defaults to running invoke in concurrent threads. |
+ +---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+
+ | `abatch` | list['''] | list[BaseMessage] | Defaults to running ainvoke in concurrent threads. |
+ +---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+
+ | `batch_as_completed` | list['''] | Iterator[tuple[int, Union[BaseMessage, Exception]]] | Defaults to running invoke in concurrent threads. |
+ +---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+
+ | `abatch_as_completed` | list['''] | AsyncIterator[tuple[int, Union[BaseMessage, Exception]]] | Defaults to running ainvoke in concurrent threads. |
+ +---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+
+
+ This table provides a brief overview of the main imperative methods. Please see the base Runnable reference for full documentation.
+
+ Key declarative methods:
+ Methods for creating another Runnable using the ChatModel.
+
+ +----------------------------------+-----------------------------------------------------------------------------------------------------------+
+ | Method | Description |
+ +==================================+===========================================================================================================+
+ | `bind_tools` | Create ChatModel that can call tools. |
+ +----------------------------------+-----------------------------------------------------------------------------------------------------------+
+ | `with_structured_output` | Create wrapper that structures model output using schema. |
+ +----------------------------------+-----------------------------------------------------------------------------------------------------------+
+ | `with_retry` | Create wrapper that retries model calls on failure. |
+ +----------------------------------+-----------------------------------------------------------------------------------------------------------+
+ | `with_fallbacks` | Create wrapper that falls back to other models on failure. |
+ +----------------------------------+-----------------------------------------------------------------------------------------------------------+
+ | `configurable_fields` | Specify init args of the model that can be configured at runtime via the RunnableConfig. |
+ +----------------------------------+-----------------------------------------------------------------------------------------------------------+
+ | `configurable_alternatives` | Specify alternative models which can be swapped in at runtime via the RunnableConfig. |
+ +----------------------------------+-----------------------------------------------------------------------------------------------------------+
+
+ This table provides a brief overview of the main declarative methods. Please see the reference for each method for full documentation.
+
+ Creating custom chat model:
+ Custom chat model implementations should inherit from this class.
+ Please reference the table below for information about which
+ methods and properties are required or optional for implementations.
+
+ +----------------------------------+--------------------------------------------------------------------+-------------------+
+ | Method/Property | Description | Required/Optional |
+ +==================================+====================================================================+===================+
+    | `_invoke`                        | Use to generate a chat result from a prompt                        | Required          |
+ +----------------------------------+--------------------------------------------------------------------+-------------------+
+ | `_llm_type` (property) | Used to uniquely identify the type of the model. Used for logging. | Required |
+ +----------------------------------+--------------------------------------------------------------------+-------------------+
+ | `_identifying_params` (property) | Represent model parameterization for tracing purposes. | Optional |
+ +----------------------------------+--------------------------------------------------------------------+-------------------+
+ | `_stream` | Use to implement streaming | Optional |
+ +----------------------------------+--------------------------------------------------------------------+-------------------+
+    | `_ainvoke`                       | Use to implement a native async method                             | Optional          |
+ +----------------------------------+--------------------------------------------------------------------+-------------------+
+ | `_astream` | Use to implement async version of `_stream` | Optional |
+ +----------------------------------+--------------------------------------------------------------------+-------------------+
+
+ Follow the guide for more information on how to implement a custom Chat Model:
+ [Guide](https://python.langchain.com/docs/how_to/custom_chat_model/).
+
+ """ # noqa: E501
+
+ rate_limiter: Optional[BaseRateLimiter] = Field(default=None, exclude=True)
+ "An optional rate limiter to use for limiting the number of requests."
+
+ disable_streaming: Union[bool, Literal["tool_calling"]] = False
+ """Whether to disable streaming for this model.
+
+ If streaming is bypassed, then ``stream()``/``astream()``/``astream_events()`` will
+ defer to ``invoke()``/``ainvoke()``.
+
+ - If True, will always bypass streaming case.
+ - If ``'tool_calling'``, will bypass streaming case only when the model is called
+ with a ``tools`` keyword argument. In other words, LangChain will automatically
+ switch to non-streaming behavior (``invoke()``) only when the tools argument is
+ provided. This offers the best of both worlds.
+ - If False (default), will always use streaming case if available.
+
+ The main reason for this flag is that code might be written using ``.stream()`` and
+    a user may want to swap out a given model for another model whose implementation
+ does not properly support streaming.
+ """
+
+ model_config = ConfigDict(
+ arbitrary_types_allowed=True,
+ )
+
+ # --- Runnable methods ---
+
+ @property
+ @override
+ def OutputType(self) -> Any:
+ """Get the output type for this runnable."""
+ return AIMessageV1
+
+ def _convert_input(self, model_input: LanguageModelInput) -> list[MessageV1]:
+ if isinstance(model_input, PromptValue):
+ return model_input.to_messages(output_version="v1")
+ if isinstance(model_input, str):
+ return [HumanMessageV1(content=model_input)]
+ if isinstance(model_input, Sequence):
+ return convert_to_messages_v1(model_input)
+ msg = (
+ f"Invalid input type {type(model_input)}. "
+ "Must be a PromptValue, str, or list of BaseMessages."
+ )
+ raise ValueError(msg)
+
+ def _should_stream(
+ self,
+ *,
+ async_api: bool,
+ run_manager: Optional[
+ Union[CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun]
+ ] = None,
+ **kwargs: Any,
+ ) -> bool:
+ """Determine if a given model call should hit the streaming API."""
+ sync_not_implemented = type(self)._stream == BaseChatModelV1._stream # noqa: SLF001
+ async_not_implemented = type(self)._astream == BaseChatModelV1._astream # noqa: SLF001
+
+ # Check if streaming is implemented.
+ if (not async_api) and sync_not_implemented:
+ return False
+ # Note, since async falls back to sync we check both here.
+ if async_api and async_not_implemented and sync_not_implemented:
+ return False
+
+ # Check if streaming has been disabled on this instance.
+ if self.disable_streaming is True:
+ return False
+ # We assume tools are passed in via "tools" kwarg in all models.
+ if self.disable_streaming == "tool_calling" and kwargs.get("tools"):
+ return False
+
+ # Check if a runtime streaming flag has been passed in.
+ if "stream" in kwargs:
+ return kwargs["stream"]
+
+ # Check if any streaming callback handlers have been passed in.
+ handlers = run_manager.handlers if run_manager else []
+ return any(isinstance(h, _StreamingCallbackHandler) for h in handlers)
+
+ @override
+ def invoke(
+ self,
+ input: LanguageModelInput,
+ config: Optional[RunnableConfig] = None,
+ **kwargs: Any,
+ ) -> AIMessageV1:
+ config = ensure_config(config)
+ messages = self._convert_input(input)
+ ls_structured_output_format = kwargs.pop(
+ "ls_structured_output_format", None
+ ) or kwargs.pop("structured_output_format", None)
+ ls_structured_output_format_dict = _format_ls_structured_output(
+ ls_structured_output_format
+ )
+
+ params = self._get_invocation_params(**kwargs)
+ options = {**kwargs, **ls_structured_output_format_dict}
+ inheritable_metadata = {
+ **(config.get("metadata") or {}),
+ **self._get_ls_params(**kwargs),
+ }
+ callback_manager = CallbackManager.configure(
+ config.get("callbacks"),
+ self.callbacks,
+ self.verbose,
+ config.get("tags"),
+ self.tags,
+ inheritable_metadata,
+ self.metadata,
+ )
+ (run_manager,) = callback_manager.on_chat_model_start(
+ {},
+ [_format_for_tracing(messages)],
+ invocation_params=params,
+ options=options,
+ name=config.get("run_name"),
+ run_id=config.pop("run_id", None),
+ batch_size=1,
+ )
+
+ if self.rate_limiter:
+ self.rate_limiter.acquire(blocking=True)
+
+ input_messages = _normalize_messages(messages)
+
+ if self._should_stream(async_api=False, **kwargs):
+ chunks: list[AIMessageChunkV1] = []
+ try:
+ for msg in self._stream(input_messages, **kwargs):
+ run_manager.on_llm_new_token(msg)
+ chunks.append(msg)
+ except BaseException as e:
+ run_manager.on_llm_error(e, response=_generate_response_from_error(e))
+ raise
+ msg = add_ai_message_chunks(chunks[0], *chunks[1:])
+ else:
+ try:
+ msg = self._invoke(input_messages, **kwargs)
+ except BaseException as e:
+                run_manager.on_llm_error(e, response=_generate_response_from_error(e))
+ raise
+
+ run_manager.on_llm_end(msg)
+ return msg
+
+ @override
+ async def ainvoke(
+ self,
+ input: LanguageModelInput,
+ config: Optional[RunnableConfig] = None,
+ **kwargs: Any,
+    ) -> AIMessageV1:
+ config = ensure_config(config)
+ messages = self._convert_input(input)
+ ls_structured_output_format = kwargs.pop(
+ "ls_structured_output_format", None
+ ) or kwargs.pop("structured_output_format", None)
+ ls_structured_output_format_dict = _format_ls_structured_output(
+ ls_structured_output_format
+ )
+
+ params = self._get_invocation_params(**kwargs)
+ options = {**kwargs, **ls_structured_output_format_dict}
+ inheritable_metadata = {
+ **(config.get("metadata") or {}),
+ **self._get_ls_params(**kwargs),
+ }
+ callback_manager = AsyncCallbackManager.configure(
+ config.get("callbacks"),
+ self.callbacks,
+ self.verbose,
+ config.get("tags"),
+ self.tags,
+ inheritable_metadata,
+ self.metadata,
+ )
+ (run_manager,) = await callback_manager.on_chat_model_start(
+ {},
+ [_format_for_tracing(messages)],
+ invocation_params=params,
+ options=options,
+ name=config.get("run_name"),
+ run_id=config.pop("run_id", None),
+ batch_size=1,
+ )
+
+ if self.rate_limiter:
+ await self.rate_limiter.aacquire(blocking=True)
+
+ # TODO: type openai image, audio, file types and permit in MessageV1
+ input_messages = _normalize_messages(messages) # type: ignore[arg-type]
+
+ if self._should_stream(async_api=True, **kwargs):
+ chunks: list[AIMessageChunkV1] = []
+ try:
+ async for msg in self._astream(input_messages, **kwargs):
+ await run_manager.on_llm_new_token(msg)
+ chunks.append(msg)
+ except BaseException as e:
+ await run_manager.on_llm_error(
+ e, response=_generate_response_from_error(e)
+ )
+ raise
+ msg = add_ai_message_chunks(chunks[0], *chunks[1:])
+ else:
+ try:
+ msg = await self._ainvoke(input_messages, **kwargs)
+ except BaseException as e:
+ await run_manager.on_llm_error(
+ e, response=_generate_response_from_error(e)
+ )
+ raise
+
+ await run_manager.on_llm_end(msg.to_message())
+ return msg
+
+ @override
+ def stream(
+ self,
+ input: LanguageModelInput,
+ config: Optional[RunnableConfig] = None,
+ **kwargs: Any,
+ ) -> Iterator[AIMessageChunkV1]:
+ if not self._should_stream(async_api=False, **{**kwargs, "stream": True}):
+ # model doesn't implement streaming, so use default implementation
+ yield cast(
+ "AIMessageChunkV1",
+ self.invoke(input, config=config, **kwargs),
+ )
+ else:
+ config = ensure_config(config)
+ messages = self._convert_input(input)
+ ls_structured_output_format = kwargs.pop(
+ "ls_structured_output_format", None
+ ) or kwargs.pop("structured_output_format", None)
+ ls_structured_output_format_dict = _format_ls_structured_output(
+ ls_structured_output_format
+ )
+
+ params = self._get_invocation_params(**kwargs)
+ options = {**kwargs, **ls_structured_output_format_dict}
+ inheritable_metadata = {
+ **(config.get("metadata") or {}),
+ **self._get_ls_params(**kwargs),
+ }
+ callback_manager = CallbackManager.configure(
+ config.get("callbacks"),
+ self.callbacks,
+ self.verbose,
+ config.get("tags"),
+ self.tags,
+ inheritable_metadata,
+ self.metadata,
+ )
+ (run_manager,) = callback_manager.on_chat_model_start(
+ {},
+ [_format_for_tracing(messages)],
+ invocation_params=params,
+ options=options,
+ name=config.get("run_name"),
+ run_id=config.pop("run_id", None),
+ batch_size=1,
+ )
+
+ chunks: list[AIMessageChunkV1] = []
+
+ if self.rate_limiter:
+ self.rate_limiter.acquire(blocking=True)
+
+ try:
+ # TODO: replace this with something for new messages
+ input_messages = _normalize_messages(messages)
+ for msg in self._stream(input_messages, **kwargs):
+ run_manager.on_llm_new_token(msg)
+ chunks.append(msg)
+ yield msg
+ except BaseException as e:
+ run_manager.on_llm_error(e, response=_generate_response_from_error(e))
+ raise
+
+ msg = add_ai_message_chunks(chunks[0], *chunks[1:])
+ run_manager.on_llm_end(msg)
+
+ @override
+ async def astream(
+ self,
+ input: LanguageModelInput,
+ config: Optional[RunnableConfig] = None,
+ **kwargs: Any,
+ ) -> AsyncIterator[AIMessageChunkV1]:
+ if not self._should_stream(async_api=True, **{**kwargs, "stream": True}):
+ # No async or sync stream is implemented, so fall back to ainvoke
+ yield cast(
+ "AIMessageChunkV1",
+ await self.ainvoke(input, config=config, **kwargs),
+ )
+ return
+
+ config = ensure_config(config)
+ messages = self._convert_input(input)
+
+ ls_structured_output_format = kwargs.pop(
+ "ls_structured_output_format", None
+ ) or kwargs.pop("structured_output_format", None)
+ ls_structured_output_format_dict = _format_ls_structured_output(
+ ls_structured_output_format
+ )
+
+ params = self._get_invocation_params(**kwargs)
+ options = {**kwargs, **ls_structured_output_format_dict}
+ inheritable_metadata = {
+ **(config.get("metadata") or {}),
+ **self._get_ls_params(**kwargs),
+ }
+ callback_manager = AsyncCallbackManager.configure(
+ config.get("callbacks"),
+ self.callbacks,
+ self.verbose,
+ config.get("tags"),
+ self.tags,
+ inheritable_metadata,
+ self.metadata,
+ )
+ (run_manager,) = await callback_manager.on_chat_model_start(
+ {},
+ [_format_for_tracing(messages)],
+ invocation_params=params,
+ options=options,
+ name=config.get("run_name"),
+ run_id=config.pop("run_id", None),
+ batch_size=1,
+ )
+
+ if self.rate_limiter:
+ await self.rate_limiter.aacquire(blocking=True)
+
+ chunks: list[AIMessageChunkV1] = []
+
+ try:
+ input_messages = _normalize_messages(messages)
+ async for msg in self._astream(
+ input_messages,
+ **kwargs,
+ ):
+ await run_manager.on_llm_new_token(msg)
+ chunks.append(msg)
+ yield msg
+ except BaseException as e:
+ await run_manager.on_llm_error(e, response=_generate_response_from_error(e))
+ raise
+
+ msg = add_ai_message_chunks(chunks[0], *chunks[1:])
+ await run_manager.on_llm_end(msg)
+
+ # --- Custom methods ---
+
+ def _combine_llm_outputs(self, llm_outputs: list[Optional[dict]]) -> dict: # noqa: ARG002
+ return {}
+
+ def _get_invocation_params(
+ self,
+ stop: Optional[list[str]] = None,
+ **kwargs: Any,
+ ) -> dict:
+ params = self.dict()
+ params["stop"] = stop
+ return {**params, **kwargs}
+
+ def _get_ls_params(
+ self,
+ stop: Optional[list[str]] = None,
+ **kwargs: Any,
+ ) -> LangSmithParams:
+ """Get standard params for tracing."""
+ # get default provider from class name
+ default_provider = self.__class__.__name__
+ if default_provider.startswith("Chat"):
+ default_provider = default_provider[4:].lower()
+ elif default_provider.endswith("Chat"):
+ default_provider = default_provider[:-4]
+ default_provider = default_provider.lower()
+
+ ls_params = LangSmithParams(ls_provider=default_provider, ls_model_type="chat")
+ if stop:
+ ls_params["ls_stop"] = stop
+
+ # model
+ if hasattr(self, "model") and isinstance(self.model, str):
+ ls_params["ls_model_name"] = self.model
+ elif hasattr(self, "model_name") and isinstance(self.model_name, str):
+ ls_params["ls_model_name"] = self.model_name
+
+ # temperature
+ if "temperature" in kwargs and isinstance(kwargs["temperature"], float):
+ ls_params["ls_temperature"] = kwargs["temperature"]
+ elif hasattr(self, "temperature") and isinstance(self.temperature, float):
+ ls_params["ls_temperature"] = self.temperature
+
+ # max_tokens
+ if "max_tokens" in kwargs and isinstance(kwargs["max_tokens"], int):
+ ls_params["ls_max_tokens"] = kwargs["max_tokens"]
+ elif hasattr(self, "max_tokens") and isinstance(self.max_tokens, int):
+ ls_params["ls_max_tokens"] = self.max_tokens
+
+ return ls_params
+
+ def _get_llm_string(self, stop: Optional[list[str]] = None, **kwargs: Any) -> str:
+ params = self._get_invocation_params(stop=stop, **kwargs)
+ params = {**params, **kwargs}
+ return str(sorted(params.items()))
+
+ def _invoke(
+ self,
+ messages: list[MessageV1],
+ **kwargs: Any,
+    ) -> AIMessageV1:
+ raise NotImplementedError
+
+ async def _ainvoke(
+ self,
+ messages: list[MessageV1],
+ **kwargs: Any,
+    ) -> AIMessageV1:
+ return await run_in_executor(
+ None,
+ self._invoke,
+ messages,
+ **kwargs,
+ )
+
+ def _stream(
+ self,
+ messages: list[MessageV1],
+ **kwargs: Any,
+ ) -> Iterator[AIMessageChunkV1]:
+ raise NotImplementedError
+
+ async def _astream(
+ self,
+ messages: list[MessageV1],
+ **kwargs: Any,
+ ) -> AsyncIterator[AIMessageChunkV1]:
+ iterator = await run_in_executor(
+ None,
+ self._stream,
+ messages,
+ **kwargs,
+ )
+ done = object()
+ while True:
+ item = await run_in_executor(
+ None,
+ next,
+ iterator,
+ done,
+ )
+ if item is done:
+ break
+ yield item # type: ignore[misc]
+
+ @property
+ @abstractmethod
+ def _llm_type(self) -> str:
+ """Return type of chat model."""
+
+ @override
+ def dict(self, **kwargs: Any) -> dict:
+ """Return a dictionary of the LLM."""
+ starter_dict = dict(self._identifying_params)
+ starter_dict["_type"] = self._llm_type
+ return starter_dict
+
+ def bind_tools(
+ self,
+ tools: Sequence[
+ Union[typing.Dict[str, Any], type, Callable, BaseTool] # noqa: UP006
+ ],
+ *,
+        tool_choice: Optional[str] = None,
+ **kwargs: Any,
+ ) -> Runnable[LanguageModelInput, BaseMessage]:
+ """Bind tools to the model.
+
+ Args:
+ tools: Sequence of tools to bind to the model.
+ tool_choice: The tool to use. If "any" then any tool can be used.
+
+ Returns:
+ A Runnable that returns a message.
+ """
+ raise NotImplementedError
+
+ def with_structured_output(
+ self,
+ schema: Union[typing.Dict, type], # noqa: UP006
+ *,
+ include_raw: bool = False,
+ **kwargs: Any,
+ ) -> Runnable[LanguageModelInput, Union[typing.Dict, BaseModel]]: # noqa: UP006
+ """Model wrapper that returns outputs formatted to match the given schema.
+
+ Args:
+ schema:
+ The output schema. Can be passed in as:
+ - an OpenAI function/tool schema,
+ - a JSON Schema,
+ - a TypedDict class,
+ - or a Pydantic class.
+ If ``schema`` is a Pydantic class then the model output will be a
+ Pydantic instance of that class, and the model-generated fields will be
+ validated by the Pydantic class. Otherwise the model output will be a
+ dict and will not be validated. See :meth:`langchain_core.utils.function_calling.convert_to_openai_tool`
+ for more on how to properly specify types and descriptions of
+ schema fields when specifying a Pydantic or TypedDict class.
+
+ include_raw:
+ If False then only the parsed structured output is returned. If
+ an error occurs during model output parsing it will be raised. If True
+ then both the raw model response (a BaseMessage) and the parsed model
+ response will be returned. If an error occurs during output parsing it
+ will be caught and returned as well. The final output is always a dict
+ with keys "raw", "parsed", and "parsing_error".
+
+ Returns:
+ A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`.
+
+ If ``include_raw`` is False and ``schema`` is a Pydantic class, Runnable outputs
+ an instance of ``schema`` (i.e., a Pydantic object).
+
+ Otherwise, if ``include_raw`` is False then Runnable outputs a dict.
+
+ If ``include_raw`` is True, then Runnable outputs a dict with keys:
+ - ``"raw"``: BaseMessage
+ - ``"parsed"``: None if there was a parsing error, otherwise the type depends on the ``schema`` as described above.
+ - ``"parsing_error"``: Optional[BaseException]
+
+ Example: Pydantic schema (include_raw=False):
+ .. code-block:: python
+
+ from pydantic import BaseModel
+
+ class AnswerWithJustification(BaseModel):
+ '''An answer to the user question along with justification for the answer.'''
+ answer: str
+ justification: str
+
+ llm = ChatModel(model="model-name", temperature=0)
+ structured_llm = llm.with_structured_output(AnswerWithJustification)
+
+ structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
+
+ # -> AnswerWithJustification(
+ # answer='They weigh the same',
+ # justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'
+ # )
+
+ Example: Pydantic schema (include_raw=True):
+ .. code-block:: python
+
+ from pydantic import BaseModel
+
+ class AnswerWithJustification(BaseModel):
+ '''An answer to the user question along with justification for the answer.'''
+ answer: str
+ justification: str
+
+ llm = ChatModel(model="model-name", temperature=0)
+ structured_llm = llm.with_structured_output(AnswerWithJustification, include_raw=True)
+
+ structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
+ # -> {
+ # 'raw': AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_Ao02pnFYXD6GN1yzc0uXPsvF', 'function': {'arguments': '{"answer":"They weigh the same.","justification":"Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ."}', 'name': 'AnswerWithJustification'}, 'type': 'function'}]}),
+ # 'parsed': AnswerWithJustification(answer='They weigh the same.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'),
+ # 'parsing_error': None
+ # }
+
+ Example: Dict schema (include_raw=False):
+ .. code-block:: python
+
+ from pydantic import BaseModel
+ from langchain_core.utils.function_calling import convert_to_openai_tool
+
+ class AnswerWithJustification(BaseModel):
+ '''An answer to the user question along with justification for the answer.'''
+ answer: str
+ justification: str
+
+ dict_schema = convert_to_openai_tool(AnswerWithJustification)
+ llm = ChatModel(model="model-name", temperature=0)
+ structured_llm = llm.with_structured_output(dict_schema)
+
+ structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
+ # -> {
+ # 'answer': 'They weigh the same',
+ # 'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume and density of the two substances differ.'
+ # }
+
+ .. versionchanged:: 0.2.26
+
+ Added support for TypedDict class.
+ """ # noqa: E501
+ _ = kwargs.pop("method", None)
+ _ = kwargs.pop("strict", None)
+ if kwargs:
+ msg = f"Received unsupported arguments {kwargs}"
+ raise ValueError(msg)
+
+ from langchain_core.output_parsers.openai_tools import (
+ JsonOutputKeyToolsParser,
+ PydanticToolsParser,
+ )
+
+ if type(self).bind_tools is BaseChatModelV1.bind_tools:
+ msg = "with_structured_output is not implemented for this model."
+ raise NotImplementedError(msg)
+
+ llm = self.bind_tools(
+ [schema],
+ tool_choice="any",
+ ls_structured_output_format={
+ "kwargs": {"method": "function_calling"},
+ "schema": schema,
+ },
+ )
+ if isinstance(schema, type) and is_basemodel_subclass(schema):
+ output_parser: OutputParserLike = PydanticToolsParser(
+ tools=[cast("TypeBaseModel", schema)], first_tool_only=True
+ )
+ else:
+ key_name = convert_to_openai_tool(schema)["function"]["name"]
+ output_parser = JsonOutputKeyToolsParser(
+ key_name=key_name, first_tool_only=True
+ )
+ if include_raw:
+ parser_assign = RunnablePassthrough.assign(
+ parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
+ )
+ parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
+ parser_with_fallback = parser_assign.with_fallbacks(
+ [parser_none], exception_key="parsing_error"
+ )
+ return RunnableMap(raw=llm) | parser_with_fallback
+ return llm | output_parser
+
+
+def _gen_info_and_msg_metadata(
+ generation: Union[ChatGeneration, ChatGenerationChunk],
+) -> dict:
+ return {
+ **(generation.generation_info or {}),
+ **generation.message.response_metadata,
+ }
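Per the "Creating custom chat model" table in the class docstring, a subclass needs ``_invoke`` and ``_llm_type`` at a minimum (``_stream``/``_ainvoke``/``_astream`` are optional). A minimal, hypothetical sketch under those assumptions (class name, model string, and reply text are made up):

    from typing import Any

    from langchain_core.language_models.v1.chat_models import BaseChatModelV1
    from langchain_core.messages.v1 import AIMessage as AIMessageV1
    from langchain_core.messages.v1 import MessageV1


    class ParrotChatModelV1(BaseChatModelV1):
        """Toy model that always answers with a fixed text block."""

        @property
        def _llm_type(self) -> str:
            return "parrot-chat-v1"

        @property
        def _identifying_params(self) -> dict[str, Any]:
            # Optional: surfaced in tracing via dict()/_get_invocation_params.
            return {"model": "parrot-1"}

        def _invoke(self, messages: list[MessageV1], **kwargs: Any) -> AIMessageV1:
            # A real implementation would call a provider API with `messages` here.
            return AIMessageV1(
                content=[{"type": "text", "text": "Polly wants a cracker."}]
            )


    model = ParrotChatModelV1()
    result = model.invoke("Say something")  # str input becomes a HumanMessageV1
    print(result.text)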
diff --git a/libs/core/langchain_core/messages/__init__.py b/libs/core/langchain_core/messages/__init__.py
index fe87e964af2..0faf0447295 100644
--- a/libs/core/langchain_core/messages/__init__.py
+++ b/libs/core/langchain_core/messages/__init__.py
@@ -33,6 +33,24 @@ if TYPE_CHECKING:
)
from langchain_core.messages.chat import ChatMessage, ChatMessageChunk
from langchain_core.messages.content_blocks import (
+ Annotation,
+ AudioContentBlock,
+ Citation,
+ CodeInterpreterCall,
+ CodeInterpreterOutput,
+ CodeInterpreterResult,
+ ContentBlock,
+ DataContentBlock,
+ FileContentBlock,
+ ImageContentBlock,
+ NonStandardAnnotation,
+ NonStandardContentBlock,
+ PlainTextContentBlock,
+ ReasoningContentBlock,
+ TextContentBlock,
+ VideoContentBlock,
+ WebSearchCall,
+ WebSearchResult,
convert_to_openai_data_block,
convert_to_openai_image_block,
is_data_content_block,
@@ -65,24 +83,42 @@ if TYPE_CHECKING:
__all__ = (
"AIMessage",
"AIMessageChunk",
+ "Annotation",
"AnyMessage",
+ "AudioContentBlock",
"BaseMessage",
"BaseMessageChunk",
"ChatMessage",
"ChatMessageChunk",
+ "Citation",
+ "CodeInterpreterCall",
+ "CodeInterpreterOutput",
+ "CodeInterpreterResult",
+ "ContentBlock",
+ "DataContentBlock",
+ "FileContentBlock",
"FunctionMessage",
"FunctionMessageChunk",
"HumanMessage",
"HumanMessageChunk",
+ "ImageContentBlock",
"InvalidToolCall",
"MessageLikeRepresentation",
+ "NonStandardAnnotation",
+ "NonStandardContentBlock",
+ "PlainTextContentBlock",
+ "ReasoningContentBlock",
"RemoveMessage",
"SystemMessage",
"SystemMessageChunk",
+ "TextContentBlock",
"ToolCall",
"ToolCallChunk",
"ToolMessage",
"ToolMessageChunk",
+ "VideoContentBlock",
+ "WebSearchCall",
+ "WebSearchResult",
"_message_from_dict",
"convert_to_messages",
"convert_to_openai_data_block",
@@ -103,25 +139,43 @@ __all__ = (
_dynamic_imports = {
"AIMessage": "ai",
"AIMessageChunk": "ai",
+ "Annotation": "content_blocks",
+ "AudioContentBlock": "content_blocks",
"BaseMessage": "base",
"BaseMessageChunk": "base",
"merge_content": "base",
"message_to_dict": "base",
"messages_to_dict": "base",
+ "Citation": "content_blocks",
+ "ContentBlock": "content_blocks",
"ChatMessage": "chat",
"ChatMessageChunk": "chat",
+ "CodeInterpreterCall": "content_blocks",
+ "CodeInterpreterOutput": "content_blocks",
+ "CodeInterpreterResult": "content_blocks",
+ "DataContentBlock": "content_blocks",
+ "FileContentBlock": "content_blocks",
"FunctionMessage": "function",
"FunctionMessageChunk": "function",
"HumanMessage": "human",
"HumanMessageChunk": "human",
+ "NonStandardAnnotation": "content_blocks",
+ "NonStandardContentBlock": "content_blocks",
+ "PlainTextContentBlock": "content_blocks",
+ "ReasoningContentBlock": "content_blocks",
"RemoveMessage": "modifier",
"SystemMessage": "system",
"SystemMessageChunk": "system",
+ "WebSearchCall": "content_blocks",
+ "WebSearchResult": "content_blocks",
+ "ImageContentBlock": "content_blocks",
"InvalidToolCall": "tool",
+ "TextContentBlock": "content_blocks",
"ToolCall": "tool",
"ToolCallChunk": "tool",
"ToolMessage": "tool",
"ToolMessageChunk": "tool",
+ "VideoContentBlock": "content_blocks",
"AnyMessage": "utils",
"MessageLikeRepresentation": "utils",
"_message_from_dict": "utils",
diff --git a/libs/core/langchain_core/messages/content_blocks.py b/libs/core/langchain_core/messages/content_blocks.py
index 83a66fb123a..a2efd0fef9c 100644
--- a/libs/core/langchain_core/messages/content_blocks.py
+++ b/libs/core/langchain_core/messages/content_blocks.py
@@ -1,110 +1,782 @@
-"""Types for content blocks."""
+"""Standard, multimodal content blocks for Large Language Model I/O.
+
+.. warning::
+ This module is under active development. The API is unstable and subject to
+ change in future releases.
+
+This module provides a standardized data structure for representing inputs to and
+outputs from Large Language Models. The core abstraction is the **Content Block**, a
+``TypedDict`` that can represent a piece of text, an image, a tool call, or other
+structured data.
+
+Data **not yet mapped** to a standard block may be represented using the
+``NonStandardContentBlock``, which allows for provider-specific data to be included
+without losing the benefits of type checking and validation.
+
+Furthermore, provider-specific fields *within* a standard block will be allowed as extra
+keys on the TypedDict per `PEP 728 <https://peps.python.org/pep-0728/>`__. This allows
+for flexibility in the data structure while maintaining a consistent interface.
+
+**Example using ``extra_items=Any``:**
+
+.. code-block:: python
+
+ from langchain_core.messages.content_blocks import TextContentBlock
+ from typing import Any
+
+ my_block: TextContentBlock = {
+ "type": "text",
+ "text": "Hello, world!",
+ "extra_field": "This is allowed",
+ "another_field": 42, # Any type is allowed
+ }
+
+ # A type checker that supports PEP 728 would validate the object above.
+ # Accessing the provider-specific key is possible, and its type is 'Any'.
+ block_extra_field = my_block["extra_field"]
+
+.. warning::
+    Type checkers such as mypy do not yet support `PEP 728 <https://peps.python.org/pep-0728/>`__,
+ so you may see type errors when using provider-specific fields. These are safe to
+ ignore, as the fields are still validated at runtime.
+
+**Rationale**
+
+Different LLM providers use distinct and incompatible API schemas. This module
+introduces a unified, provider-agnostic format to standardize these interactions. A
+message to or from a model is simply a `list` of `ContentBlock` objects, allowing for
+the natural interleaving of text, images, and other content in a single, ordered
+sequence.
+
+An adapter for a specific provider is responsible for translating this standard list of
+blocks into the format required by its API.
+
+**Key Block Types**
+
+The module defines several types of content blocks, including:
+
+- **``TextContentBlock``**: Standard text.
+- **``ImageContentBlock``**, **``AudioContentBlock``**, **``VideoContentBlock``**: For
+ multimodal data.
+- **``ToolCallContentBlock``**, **``ToolOutputContentBlock``**: For function calling.
+- **``ReasoningContentBlock``**: To capture a model's thought process.
+- **``Citation``**: For annotations that link generated text to a source document.
+
+**Example Usage**
+
+.. code-block:: python
+
+    from langchain_core.messages.content_blocks import (
+        ContentBlock,
+        ImageContentBlock,
+        TextContentBlock,
+    )
+
+    multimodal_content: list[ContentBlock] = [
+ TextContentBlock(type="text", text="What is shown in this image?"),
+ ImageContentBlock(
+ type="image",
+ url="https://www.langchain.com/images/brand/langchain_logo_text_w_white.png",
+ mime_type="image/png",
+ ),
+ ]
+""" # noqa: E501
import warnings
-from typing import Any, Literal, Union
+from typing import Any, Literal, Optional, Union
-from pydantic import TypeAdapter, ValidationError
-from typing_extensions import NotRequired, TypedDict
+from typing_extensions import NotRequired, TypedDict, get_args, get_origin
+
+# --- Text and annotations ---
-class BaseDataContentBlock(TypedDict, total=False):
- """Base class for data content blocks."""
+class Citation(TypedDict):
+ """Annotation for citing data from a document.
+
+ .. note::
+ ``start/end`` indices refer to the **response text**,
+ not the source text. This means that the indices are relative to the model's
+ response, not the original document (as specified in the ``url``).
+ """
+
+ type: Literal["citation"]
+ """Type of the content block."""
+
+ id: NotRequired[str]
+ """Content block identifier. Either:
+
+ - Generated by the provider (e.g., OpenAI's file ID)
+ - Generated by LangChain upon creation (as ``UUID4``)
+ """
+
+ url: NotRequired[str]
+ """URL of the document source."""
+
+ # For future consideration, if needed:
+ # provenance: NotRequired[str]
+ # """Provenance of the document, e.g., "Wikipedia", "arXiv", etc.
+
+ # Included for future compatibility; not currently implemented.
+ # """
+
+ title: NotRequired[str]
+ """Source document title.
+
+ For example, the page title for a web page or the title of a paper.
+ """
+
+ start_index: NotRequired[int]
+ """Start index of the **response text** (``TextContentBlock.text``) for which the
+ annotation applies."""
+
+ end_index: NotRequired[int]
+ """End index of the **response text** (``TextContentBlock.text``) for which the
+ annotation applies."""
+
+ cited_text: NotRequired[str]
+ """Excerpt of source text being cited."""
+
+ # NOTE: not including spans for the raw document text (such as `text_start_index`
+ # and `text_end_index`) as this is not currently supported by any provider. The
+ # thinking is that the `cited_text` should be sufficient for most use cases, and it
+ # is difficult to reliably extract spans from the raw document text across file
+ # formats or encoding schemes.
+
+
+class NonStandardAnnotation(TypedDict):
+ """Provider-specific annotation format."""
+
+ type: Literal["non_standard_annotation"]
+ """Type of the content block."""
+
+ id: NotRequired[str]
+ """Content block identifier. Either:
+
+ - Generated by the provider (e.g., OpenAI's file ID)
+ - Generated by LangChain upon creation (as ``UUID4``)
+ """
+
+ value: dict[str, Any]
+ """Provider-specific annotation data."""
+
+
+Annotation = Union[Citation, NonStandardAnnotation]
+
+
+class TextContentBlock(TypedDict):
+ """Content block for text output.
+
+ This typically represents the main text content of a message, such as the response
+ from a language model or the text of a user message.
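+
+    Example (illustrative values):
+
+        .. code-block:: python
+
+            block: TextContentBlock = {
+                "type": "text",
+                "text": "The answer is 42.",
+                "annotations": [
+                    {"type": "citation", "cited_text": "42"},
+                ],
+            }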
+ """
+
+ type: Literal["text"]
+ """Type of the content block."""
+
+ id: NotRequired[str]
+ """Content block identifier. Either:
+
+ - Generated by the provider (e.g., OpenAI's file ID)
+ - Generated by LangChain upon creation (as ``UUID4``)
+ """
+
+ text: str
+ """Block text."""
+
+ annotations: NotRequired[list[Annotation]]
+ """Citations and other annotations."""
+
+ index: NotRequired[int]
+ """Index of block in aggregate response. Used during streaming."""
+
+
+# --- Tool calls ---
+class ToolCall(TypedDict):
+ """Represents a request to call a tool.
+
+ Example:
+
+ .. code-block:: python
+
+ {
+ "name": "foo",
+ "args": {"a": 1},
+ "id": "123"
+ }
+
+ This represents a request to call the tool named "foo" with arguments {"a": 1}
+ and an identifier of "123".
+ """
+
+ name: str
+ """The name of the tool to be called."""
+ args: dict[str, Any]
+ """The arguments to the tool call."""
+ id: Optional[str]
+ """An identifier associated with the tool call.
+
+ An identifier is needed to associate a tool call request with a tool
+ call result in events when multiple concurrent tool calls are made.
+ """
+ index: NotRequired[int]
+ """Index of block in aggregate response. Used during streaming."""
+ type: Literal["tool_call"]
+
+
+class InvalidToolCall(TypedDict):
+ """Allowance for errors made by LLM.
+
+    Here we add an ``error`` key to surface errors made during generation
+    (e.g., invalid JSON arguments).
+ """
+
+ name: Optional[str]
+ """The name of the tool to be called."""
+ args: Optional[str]
+ """The arguments to the tool call."""
+ id: Optional[str]
+ """An identifier associated with the tool call."""
+ error: Optional[str]
+ """An error message associated with the tool call."""
+ type: Literal["invalid_tool_call"]
+
+
+class ToolCallChunk(TypedDict):
+ """A chunk of a tool call (e.g., as part of a stream).
+
+ When merging ToolCallChunks (e.g., via AIMessageChunk.__add__),
+ all string attributes are concatenated. Chunks are only merged if their
+ values of `index` are equal and not None.
+
+ Example:
+
+ .. code-block:: python
+
+ left_chunks = [ToolCallChunk(name="foo", args='{"a":', index=0)]
+ right_chunks = [ToolCallChunk(name=None, args='1}', index=0)]
+
+ (
+ AIMessageChunk(content="", tool_call_chunks=left_chunks)
+ + AIMessageChunk(content="", tool_call_chunks=right_chunks)
+ ).tool_call_chunks == [ToolCallChunk(name='foo', args='{"a":1}', index=0)]
+ """
+
+ name: Optional[str]
+ """The name of the tool to be called."""
+ args: Optional[str]
+ """The arguments to the tool call."""
+ id: Optional[str]
+ """An identifier associated with the tool call."""
+ index: Optional[int]
+ """The index of the tool call in a sequence."""
+ type: NotRequired[Literal["tool_call_chunk"]]
+
+
+# --- Provider tool calls (built-in tools) ---
+# Note: These are not standard tool calls, but rather provider-specific built-in tools.
+
+
+# Web search
+class WebSearchCall(TypedDict):
+ """Content block for a built-in web search tool call."""
+
+ type: Literal["web_search_call"]
+ """Type of the content block."""
+
+ id: NotRequired[str]
+ """Content block identifier. Either:
+
+ - Generated by the provider (e.g., OpenAI's file ID)
+ - Generated by LangChain upon creation (as ``UUID4``)
+ """
+
+ query: NotRequired[str]
+ """The search query used in the web search tool call."""
+
+ index: NotRequired[int]
+ """Index of block in aggregate response. Used during streaming."""
+
+
+class WebSearchResult(TypedDict):
+ """Content block for the result of a built-in web search tool call."""
+
+ type: Literal["web_search_result"]
+ """Type of the content block."""
+
+ id: NotRequired[str]
+ """Content block identifier. Either:
+
+ - Generated by the provider (e.g., OpenAI's file ID)
+ - Generated by LangChain upon creation (as ``UUID4``)
+ """
+
+ urls: NotRequired[list[str]]
+ """List of URLs returned by the web search tool call."""
+
+ index: NotRequired[int]
+ """Index of block in aggregate response. Used during streaming."""
+
+
+# Code interpreter
+
+
+# Call
+class CodeInterpreterCall(TypedDict):
+ """Content block for a built-in code interpreter tool call."""
+
+ type: Literal["code_interpreter_call"]
+ """Type of the content block."""
+
+ id: NotRequired[str]
+ """Content block identifier. Either:
+
+ - Generated by the provider (e.g., OpenAI's file ID)
+ - Generated by LangChain upon creation (as ``UUID4``)
+ """
+
+ language: NotRequired[str]
+ """The programming language used in the code interpreter tool call."""
+
+ code: NotRequired[str]
+ """The code to be executed by the code interpreter."""
+
+ index: NotRequired[int]
+ """Index of block in aggregate response. Used during streaming."""
+
+
+# Result block is CodeInterpreterResult
+class CodeInterpreterOutput(TypedDict):
+ """Content block for the output of a singular code interpreter tool call.
+
+ Full output of a code interpreter tool call is represented by
+ ``CodeInterpreterResult`` which is a list of these blocks.
+ """
+
+ type: Literal["code_interpreter_output"]
+ """Type of the content block."""
+
+ id: NotRequired[str]
+ """Content block identifier. Either:
+
+ - Generated by the provider (e.g., OpenAI's file ID)
+ - Generated by LangChain upon creation (as ``UUID4``)
+ """
+
+ return_code: NotRequired[int]
+ """Return code of the executed code.
+
+ Example: 0 for success, non-zero for failure.
+ """
+
+ stderr: NotRequired[str]
+ """Standard error output of the executed code."""
+
+ stdout: NotRequired[str]
+ """Standard output of the executed code."""
+
+ file_ids: NotRequired[list[str]]
+ """List of file IDs generated by the code interpreter."""
+
+ index: NotRequired[int]
+ """Index of block in aggregate response. Used during streaming."""
+
+
+class CodeInterpreterResult(TypedDict):
+ """Content block for the result of a code interpreter tool call."""
+
+ type: Literal["code_interpreter_result"]
+ """Type of the content block."""
+
+ id: NotRequired[str]
+ """Content block identifier. Either:
+
+ - Generated by the provider (e.g., OpenAI's file ID)
+ - Generated by LangChain upon creation (as ``UUID4``)
+ """
+
+ output: list[CodeInterpreterOutput]
+ """List of outputs from the code interpreter tool call."""
+
+ index: NotRequired[int]
+ """Index of block in aggregate response. Used during streaming."""
+
+
+# --- Reasoning ---
+class ReasoningContentBlock(TypedDict):
+ """Content block for reasoning output."""
+
+ type: Literal["reasoning"]
+ """Type of the content block."""
+
+ id: NotRequired[str]
+ """Content block identifier. Either:
+
+ - Generated by the provider (e.g., OpenAI's file ID)
+ - Generated by LangChain upon creation (as ``UUID4``)
+ """
+
+ reasoning: NotRequired[str]
+ """Reasoning text.
+
+    Either the thought summary or the raw reasoning text itself. This is often
+    parsed from ``<think>`` tags in the model's response.
+ """
+
+ thought_signature: NotRequired[str]
+ """Opaque state handle representation of the model's internal thought process.
+
+ Maintains the context of the model's thinking across multiple interactions
+ (e.g. multi-turn conversations) since many APIs are stateless.
+
+    Not to be used to verify the authenticity or integrity of the response; use
+    ``signature`` for that.
+
+ Examples:
+ - https://ai.google.dev/gemini-api/docs/thinking#signatures
+ """
+
+ signature: NotRequired[str]
+ """Signature of the reasoning content block used to verify **authenticity**.
+
+    Prevents the model's reasoning process from being modified or fabricated.
+
+ Examples:
+ - https://docs.anthropic.com/en/docs/build-with-claude/context-windows#the-context-window-with-extended-thinking-and-tool-use
+ """
+
+ index: NotRequired[int]
+ """Index of block in aggregate response. Used during streaming."""
+
+
+# --- Multi-modal ---
+
+
+# Note: `title` and `context` are fields that could be used to provide additional
+# information about the file, such as a description or summary of its content.
+# E.g. with Claude, you can provide a context for a file which is passed to the model.
+class ImageContentBlock(TypedDict):
+ """Content block for image data."""
+
+ type: Literal["image"]
+ """Type of the content block."""
+
+ id: NotRequired[str]
+ """Content block identifier. Either:
+
+ - Generated by the provider (e.g., OpenAI's file ID)
+ - Generated by LangChain upon creation (as ``UUID4``)
+ """
+
+ file_id: NotRequired[str]
+ """ID of the image file, e.g., from a file storage system."""
mime_type: NotRequired[str]
- """MIME type of the content block (if needed)."""
+ """MIME type of the image. Required for base64.
+
+    `Examples from IANA
+    <https://www.iana.org/assignments/media-types/media-types.xhtml#image>`__
+ """
-class URLContentBlock(BaseDataContentBlock):
- """Content block for data from a URL."""
+ index: NotRequired[int]
+ """Index of block in aggregate response. Used during streaming."""
- type: Literal["image", "audio", "file"]
- """Type of the content block."""
- source_type: Literal["url"]
- """Source type (url)."""
- url: str
- """URL for data."""
+ url: NotRequired[str]
+ """URL of the image."""
-
-class Base64ContentBlock(BaseDataContentBlock):
- """Content block for inline data from a base64 string."""
-
- type: Literal["image", "audio", "file"]
- """Type of the content block."""
- source_type: Literal["base64"]
- """Source type (base64)."""
- data: str
+ base64: NotRequired[str]
"""Data as a base64 string."""
+ # title: NotRequired[str]
+ # """Title of the image."""
-class PlainTextContentBlock(BaseDataContentBlock):
- """Content block for plain text data (e.g., from a document)."""
+ # context: NotRequired[str]
+ # """Context for the image, e.g., a description or summary of the image's content.""" # noqa: E501
+
+
+class VideoContentBlock(TypedDict):
+ """Content block for video data."""
+
+ type: Literal["video"]
+ """Type of the content block."""
+
+ id: NotRequired[str]
+ """Content block identifier. Either:
+
+ - Generated by the provider (e.g., OpenAI's file ID)
+ - Generated by LangChain upon creation (as ``UUID4``)
+ """
+
+ file_id: NotRequired[str]
+ """ID of the video file, e.g., from a file storage system."""
+
+ mime_type: NotRequired[str]
+ """MIME type of the video. Required for base64.
+
+    `Examples from IANA
+    <https://www.iana.org/assignments/media-types/media-types.xhtml#video>`__
+ """
+
+ index: NotRequired[int]
+ """Index of block in aggregate response. Used during streaming."""
+
+ url: NotRequired[str]
+ """URL of the video."""
+
+ base64: NotRequired[str]
+ """Data as a base64 string."""
+
+ # title: NotRequired[str]
+ # """Title of the video."""
+
+ # context: NotRequired[str]
+ # """Context for the video, e.g., a description or summary of the video's content.""" # noqa: E501
+
+
+class AudioContentBlock(TypedDict):
+ """Content block for audio data."""
+
+ type: Literal["audio"]
+ """Type of the content block."""
+
+ id: NotRequired[str]
+ """Content block identifier. Either:
+
+ - Generated by the provider (e.g., OpenAI's file ID)
+ - Generated by LangChain upon creation (as ``UUID4``)
+ """
+
+ file_id: NotRequired[str]
+ """ID of the audio file, e.g., from a file storage system."""
+
+ mime_type: NotRequired[str]
+ """MIME type of the audio. Required for base64.
+
+    `Examples from IANA
+    <https://www.iana.org/assignments/media-types/media-types.xhtml#audio>`__
+ """
+
+ index: NotRequired[int]
+ """Index of block in aggregate response. Used during streaming."""
+
+ url: NotRequired[str]
+ """URL of the audio."""
+
+ base64: NotRequired[str]
+ """Data as a base64 string."""
+
+ # title: NotRequired[str]
+ # """Title of the audio."""
+
+ # context: NotRequired[str]
+ # """Context for the audio, e.g., a description or summary of the audio's content.""" # noqa: E501
+
+
+class PlainTextContentBlock(TypedDict):
+ """Content block for plaintext data (e.g., from a document).
+
+ .. note::
+ Title and context are optional fields that may be passed to the model. See
+ Anthropic `example `__.
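+
+    Example (illustrative values):
+
+        .. code-block:: python
+
+            block: PlainTextContentBlock = {
+                "type": "text-plain",
+                "mime_type": "text/plain",
+                "text": "Full contents of a document.",
+                "title": "My document",
+            }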
+ """
+
+ type: Literal["text-plain"]
+ """Type of the content block."""
+
+ id: NotRequired[str]
+ """Content block identifier. Either:
+
+ - Generated by the provider (e.g., OpenAI's file ID)
+ - Generated by LangChain upon creation (as ``UUID4``)
+ """
+
+ file_id: NotRequired[str]
+ """ID of the plaintext file, e.g., from a file storage system."""
+
+ mime_type: Literal["text/plain"]
+ """MIME type of the file. Required for base64."""
+
+ index: NotRequired[int]
+ """Index of block in aggregate response. Used during streaming."""
+
+ url: NotRequired[str]
+ """URL of the plaintext."""
+
+ base64: NotRequired[str]
+ """Data as a base64 string."""
+
+ text: NotRequired[str]
+ """Plaintext content. This is optional if the data is provided as base64."""
+
+ title: NotRequired[str]
+ """Title of the text data, e.g., the title of a document."""
+
+ context: NotRequired[str]
+ """Context for the text, e.g., a description or summary of the text's content."""
+
+
+class FileContentBlock(TypedDict):
+ """Content block for file data.
+
+ This block is intended for files that are not images, audio, or plaintext. For
+ example, it can be used for PDFs, Word documents, etc.
+
+ If the file is an image, audio, or plaintext, you should use the corresponding
+ content block type (e.g., ``ImageContentBlock``, ``AudioContentBlock``,
+ ``PlainTextContentBlock``).
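+
+    Example (illustrative values; the base64 payload is elided):
+
+        .. code-block:: python
+
+            block: FileContentBlock = {
+                "type": "file",
+                "mime_type": "application/pdf",
+                "base64": "<base64-encoded bytes>",
+            }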
+ """
type: Literal["file"]
"""Type of the content block."""
- source_type: Literal["text"]
- """Source type (text)."""
- text: str
- """Text data."""
+
+ id: NotRequired[str]
+ """Content block identifier. Either:
+
+ - Generated by the provider (e.g., OpenAI's file ID)
+ - Generated by LangChain upon creation (as ``UUID4``)
+ """
+
+ file_id: NotRequired[str]
+ """ID of the file, e.g., from a file storage system."""
+
+ mime_type: NotRequired[str]
+ """MIME type of the file. Required for base64.
+
+    `Examples from IANA
+    <https://www.iana.org/assignments/media-types/media-types.xhtml>`__
+ """
+
+ index: NotRequired[int]
+ """Index of block in aggregate response. Used during streaming."""
+
+ url: NotRequired[str]
+ """URL of the file."""
+
+ base64: NotRequired[str]
+ """Data as a base64 string."""
+
+ # title: NotRequired[str]
+ # """Title of the file, e.g., the name of a document or file."""
+
+ # context: NotRequired[str]
+ # """Context for the file, e.g., a description or summary of the file's content."""
-class IDContentBlock(TypedDict):
- """Content block for data specified by an identifier."""
+# Future modalities to consider:
+# - 3D models
+# - Tabular data
- type: Literal["image", "audio", "file"]
+
+# Non-standard
+class NonStandardContentBlock(TypedDict):
+    """Content block for provider-specific data.
+
+ This block contains data for which there is not yet a standard type.
+
+    The purpose of this block is simply to hold a provider-specific payload.
+    If a provider's non-standard output includes reasoning and tool calls, it is the
+    adapter's job to parse that payload and emit the corresponding standard
+    ``ReasoningContentBlock`` and ``ToolCall`` blocks.
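+
+    Example (illustrative payload):
+
+        .. code-block:: python
+
+            block: NonStandardContentBlock = {
+                "type": "non_standard",
+                "value": {"provider_field": 123},
+            }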
+ """
+
+ type: Literal["non_standard"]
"""Type of the content block."""
- source_type: Literal["id"]
- """Source type (id)."""
- id: str
- """Identifier for data source."""
+
+ id: NotRequired[str]
+ """Content block identifier. Either:
+
+ - Generated by the provider (e.g., OpenAI's file ID)
+ - Generated by LangChain upon creation (as ``UUID4``)
+ """
+
+ value: dict[str, Any]
+ """Provider-specific data."""
+
+ index: NotRequired[int]
+ """Index of block in aggregate response. Used during streaming."""
+# --- Aliases ---
DataContentBlock = Union[
- URLContentBlock,
- Base64ContentBlock,
+ ImageContentBlock,
+ VideoContentBlock,
+ AudioContentBlock,
PlainTextContentBlock,
- IDContentBlock,
+ FileContentBlock,
]
-_DataContentBlockAdapter: TypeAdapter[DataContentBlock] = TypeAdapter(DataContentBlock)
+ToolContentBlock = Union[
+ ToolCall,
+ CodeInterpreterCall,
+ CodeInterpreterOutput,
+ CodeInterpreterResult,
+ WebSearchCall,
+ WebSearchResult,
+]
+
+ContentBlock = Union[
+ TextContentBlock,
+ ToolCall,
+ ReasoningContentBlock,
+ NonStandardContentBlock,
+ DataContentBlock,
+ ToolContentBlock,
+]
-def is_data_content_block(
- content_block: dict,
-) -> bool:
+def _extract_typedict_type_values(union_type: Any) -> set[str]:
+ """Extract the values of the 'type' field from a TypedDict union type."""
+ result: set[str] = set()
+ for value in get_args(union_type):
+ annotation = value.__annotations__["type"]
+ if get_origin(annotation) is Literal:
+ result.update(get_args(annotation))
+ else:
+ msg = f"{value} 'type' is not a Literal"
+ raise ValueError(msg)
+ return result
+
+
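+# All values of the ``type`` field across the ContentBlock union,
+# e.g. "text", "tool_call", "reasoning", "image", ...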
+KNOWN_BLOCK_TYPES = {
+ bt for bt in get_args(ContentBlock) for bt in get_args(bt.__annotations__["type"])
+}
+
+
+def is_data_content_block(block: dict) -> bool:
"""Check if the content block is a standard data content block.
Args:
- content_block: The content block to check.
+ block: The content block to check.
Returns:
True if the content block is a data content block, False otherwise.
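+
+    Example (illustrative values):
+
+        .. code-block:: python
+
+            block = {"type": "image", "url": "https://example.com/cat.png"}
+            is_data_content_block(block)  # -> True
+
+            is_data_content_block({"type": "text", "text": "Hello"})  # -> False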
"""
- try:
- _ = _DataContentBlockAdapter.validate_python(content_block)
- except ValidationError:
- return False
- else:
- return True
+ return block.get("type") in (
+ "audio",
+ "image",
+ "video",
+ "file",
+ "text-plain",
+ ) and any(
+ key in block
+ for key in (
+ "url",
+ "base64",
+ "file_id",
+ "text",
+ "source_type", # backwards compatibility
+ )
+ )
-def convert_to_openai_image_block(content_block: dict[str, Any]) -> dict:
+def convert_to_openai_image_block(block: dict[str, Any]) -> dict:
"""Convert image content block to format expected by OpenAI Chat Completions API."""
- if content_block["source_type"] == "url":
+ if "url" in block:
return {
"type": "image_url",
"image_url": {
- "url": content_block["url"],
+ "url": block["url"],
},
}
- if content_block["source_type"] == "base64":
- if "mime_type" not in content_block:
+ if "base64" in block or block.get("source_type") == "base64":
+ if "mime_type" not in block:
error_message = "mime_type key is required for base64 data."
raise ValueError(error_message)
- mime_type = content_block["mime_type"]
+ mime_type = block["mime_type"]
+ base64_data = block["data"] if "data" in block else block["base64"]
return {
"type": "image_url",
"image_url": {
- "url": f"data:{mime_type};base64,{content_block['data']}",
+ "url": f"data:{mime_type};base64,{base64_data}",
},
}
error_message = "Unsupported source type. Only 'url' and 'base64' are supported."
@@ -117,8 +789,9 @@ def convert_to_openai_data_block(block: dict) -> dict:
formatted_block = convert_to_openai_image_block(block)
elif block["type"] == "file":
- if block["source_type"] == "base64":
- file = {"file_data": f"data:{block['mime_type']};base64,{block['data']}"}
+ if "base64" in block or block.get("source_type") == "base64":
+ base64_data = block["data"] if "source_type" in block else block["base64"]
+ file = {"file_data": f"data:{block['mime_type']};base64,{base64_data}"}
if filename := block.get("filename"):
file["filename"] = filename
elif (metadata := block.get("metadata")) and ("filename" in metadata):
@@ -126,27 +799,28 @@ def convert_to_openai_data_block(block: dict) -> dict:
else:
warnings.warn(
"OpenAI may require a filename for file inputs. Specify a filename "
- "in the content block: {'type': 'file', 'source_type': 'base64', "
- "'mime_type': 'application/pdf', 'data': '...', "
- "'filename': 'my-pdf'}",
+ "in the content block: {'type': 'file', 'mime_type': "
+ "'application/pdf', 'base64': '...', 'filename': 'my-pdf'}",
stacklevel=1,
)
formatted_block = {"type": "file", "file": file}
- elif block["source_type"] == "id":
- formatted_block = {"type": "file", "file": {"file_id": block["id"]}}
+ elif "file_id" in block or block.get("source_type") == "id":
+ file_id = block["id"] if "source_type" in block else block["file_id"]
+ formatted_block = {"type": "file", "file": {"file_id": file_id}}
else:
- error_msg = "source_type base64 or id is required for file blocks."
+ error_msg = "Keys base64 or file_id required for file blocks."
raise ValueError(error_msg)
elif block["type"] == "audio":
- if block["source_type"] == "base64":
+ if "base64" in block or block.get("source_type") == "base64":
+ base64_data = block["data"] if "source_type" in block else block["base64"]
audio_format = block["mime_type"].split("/")[-1]
formatted_block = {
"type": "input_audio",
- "input_audio": {"data": block["data"], "format": audio_format},
+ "input_audio": {"data": base64_data, "format": audio_format},
}
else:
- error_msg = "source_type base64 is required for audio blocks."
+ error_msg = "Key base64 is required for audio blocks."
raise ValueError(error_msg)
else:
error_msg = f"Block of type {block['type']} is not supported."
diff --git a/libs/core/langchain_core/messages/modifier.py b/libs/core/langchain_core/messages/modifier.py
index 08b7e79b69c..5f1602a4908 100644
--- a/libs/core/langchain_core/messages/modifier.py
+++ b/libs/core/langchain_core/messages/modifier.py
@@ -13,7 +13,7 @@ class RemoveMessage(BaseMessage):
def __init__(
self,
- id: str, # noqa: A002
+ id: str,
**kwargs: Any,
) -> None:
"""Create a RemoveMessage.
diff --git a/libs/core/langchain_core/messages/tool.py b/libs/core/langchain_core/messages/tool.py
index 1f8a519a7dc..181c80443d5 100644
--- a/libs/core/langchain_core/messages/tool.py
+++ b/libs/core/langchain_core/messages/tool.py
@@ -5,9 +5,12 @@ from typing import Any, Literal, Optional, Union
from uuid import UUID
from pydantic import Field, model_validator
-from typing_extensions import NotRequired, TypedDict, override
+from typing_extensions import override
from langchain_core.messages.base import BaseMessage, BaseMessageChunk, merge_content
+from langchain_core.messages.content_blocks import InvalidToolCall as InvalidToolCall
+from langchain_core.messages.content_blocks import ToolCall as ToolCall
+from langchain_core.messages.content_blocks import ToolCallChunk as ToolCallChunk
from langchain_core.utils._merge import merge_dicts, merge_obj
@@ -177,42 +180,11 @@ class ToolMessageChunk(ToolMessage, BaseMessageChunk):
return super().__add__(other)
-class ToolCall(TypedDict):
- """Represents a request to call a tool.
-
- Example:
-
- .. code-block:: python
-
- {
- "name": "foo",
- "args": {"a": 1},
- "id": "123"
- }
-
- This represents a request to call the tool named "foo" with arguments {"a": 1}
- and an identifier of "123".
-
- """
-
- name: str
- """The name of the tool to be called."""
- args: dict[str, Any]
- """The arguments to the tool call."""
- id: Optional[str]
- """An identifier associated with the tool call.
-
- An identifier is needed to associate a tool call request with a tool
- call result in events when multiple concurrent tool calls are made.
- """
- type: NotRequired[Literal["tool_call"]]
-
-
def tool_call(
*,
name: str,
args: dict[str, Any],
- id: Optional[str], # noqa: A002
+ id: Optional[str],
) -> ToolCall:
"""Create a tool call.
@@ -224,43 +196,11 @@ def tool_call(
return ToolCall(name=name, args=args, id=id, type="tool_call")
-class ToolCallChunk(TypedDict):
- """A chunk of a tool call (e.g., as part of a stream).
-
- When merging ToolCallChunks (e.g., via AIMessageChunk.__add__),
- all string attributes are concatenated. Chunks are only merged if their
- values of `index` are equal and not None.
-
- Example:
-
- .. code-block:: python
-
- left_chunks = [ToolCallChunk(name="foo", args='{"a":', index=0)]
- right_chunks = [ToolCallChunk(name=None, args='1}', index=0)]
-
- (
- AIMessageChunk(content="", tool_call_chunks=left_chunks)
- + AIMessageChunk(content="", tool_call_chunks=right_chunks)
- ).tool_call_chunks == [ToolCallChunk(name='foo', args='{"a":1}', index=0)]
-
- """
-
- name: Optional[str]
- """The name of the tool to be called."""
- args: Optional[str]
- """The arguments to the tool call."""
- id: Optional[str]
- """An identifier associated with the tool call."""
- index: Optional[int]
- """The index of the tool call in a sequence."""
- type: NotRequired[Literal["tool_call_chunk"]]
-
-
def tool_call_chunk(
*,
name: Optional[str] = None,
args: Optional[str] = None,
- id: Optional[str] = None, # noqa: A002
+ id: Optional[str] = None,
index: Optional[int] = None,
) -> ToolCallChunk:
"""Create a tool call chunk.
@@ -276,29 +216,11 @@ def tool_call_chunk(
)
-class InvalidToolCall(TypedDict):
- """Allowance for errors made by LLM.
-
- Here we add an `error` key to surface errors made during generation
- (e.g., invalid JSON arguments.)
- """
-
- name: Optional[str]
- """The name of the tool to be called."""
- args: Optional[str]
- """The arguments to the tool call."""
- id: Optional[str]
- """An identifier associated with the tool call."""
- error: Optional[str]
- """An error message associated with the tool call."""
- type: NotRequired[Literal["invalid_tool_call"]]
-
-
def invalid_tool_call(
*,
name: Optional[str] = None,
args: Optional[str] = None,
- id: Optional[str] = None, # noqa: A002
+ id: Optional[str] = None,
error: Optional[str] = None,
) -> InvalidToolCall:
"""Create an invalid tool call.
diff --git a/libs/core/langchain_core/messages/utils.py b/libs/core/langchain_core/messages/utils.py
index 11c044eb438..ed9e8d39745 100644
--- a/libs/core/langchain_core/messages/utils.py
+++ b/libs/core/langchain_core/messages/utils.py
@@ -40,6 +40,12 @@ from langchain_core.messages.human import HumanMessage, HumanMessageChunk
from langchain_core.messages.modifier import RemoveMessage
from langchain_core.messages.system import SystemMessage, SystemMessageChunk
from langchain_core.messages.tool import ToolCall, ToolMessage, ToolMessageChunk
+from langchain_core.messages.v1 import AIMessage as AIMessageV1
+from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1
+from langchain_core.messages.v1 import HumanMessage as HumanMessageV1
+from langchain_core.messages.v1 import MessageV1, MessageV1Types
+from langchain_core.messages.v1 import SystemMessage as SystemMessageV1
+from langchain_core.messages.v1 import ToolMessage as ToolMessageV1
if TYPE_CHECKING:
from langchain_text_splitters import TextSplitter
@@ -203,7 +209,7 @@ def message_chunk_to_message(chunk: BaseMessageChunk) -> BaseMessage:
MessageLikeRepresentation = Union[
- BaseMessage, list[str], tuple[str, str], str, dict[str, Any]
+ BaseMessage, list[str], tuple[str, str], str, dict[str, Any], MessageV1
]
@@ -213,7 +219,7 @@ def _create_message_from_message_type(
name: Optional[str] = None,
tool_call_id: Optional[str] = None,
tool_calls: Optional[list[dict[str, Any]]] = None,
- id: Optional[str] = None, # noqa: A002
+ id: Optional[str] = None,
**additional_kwargs: Any,
) -> BaseMessage:
"""Create a message from a message type and content string.
@@ -294,6 +300,128 @@ def _create_message_from_message_type(
return message
+def _create_message_from_message_type_v1(
+ message_type: str,
+ content: str,
+ name: Optional[str] = None,
+ tool_call_id: Optional[str] = None,
+ tool_calls: Optional[list[dict[str, Any]]] = None,
+ id: Optional[str] = None,
+    **additional_kwargs: Any,
+) -> MessageV1:
+ """Create a message from a message type and content string.
+
+ Args:
+ message_type: (str) the type of the message (e.g., "human", "ai", etc.).
+ content: (str) the content string.
+ name: (str) the name of the message. Default is None.
+ tool_call_id: (str) the tool call id. Default is None.
+ tool_calls: (list[dict[str, Any]]) the tool calls. Default is None.
+ id: (str) the id of the message. Default is None.
+        additional_kwargs: (dict[str, Any]) additional keyword arguments.
+
+ Returns:
+ a message of the appropriate type.
+
+ Raises:
+ ValueError: if the message type is not one of "human", "user", "ai",
+ "assistant", "tool", "system", or "developer".
+ """
+ kwargs: dict[str, Any] = {}
+ if name is not None:
+ kwargs["name"] = name
+ if tool_call_id is not None:
+ kwargs["tool_call_id"] = tool_call_id
+    if additional_kwargs and (
+        response_metadata := additional_kwargs.pop("response_metadata", None)
+    ):
+        kwargs["response_metadata"] = response_metadata
+ if id is not None:
+ kwargs["id"] = id
+ if tool_calls is not None:
+ kwargs["tool_calls"] = []
+ for tool_call in tool_calls:
+ # Convert OpenAI-format tool call to LangChain format.
+ if "function" in tool_call:
+ args = tool_call["function"]["arguments"]
+ if isinstance(args, str):
+ args = json.loads(args, strict=False)
+ kwargs["tool_calls"].append(
+ {
+ "name": tool_call["function"]["name"],
+ "args": args,
+ "id": tool_call["id"],
+ "type": "tool_call",
+ }
+ )
+ else:
+ kwargs["tool_calls"].append(tool_call)
+ if message_type in {"human", "user"}:
+ message = HumanMessageV1(content=content, **kwargs)
+ elif message_type in {"ai", "assistant"}:
+ message = AIMessageV1(content=content, **kwargs)
+ elif message_type in {"system", "developer"}:
+ if message_type == "developer":
+ kwargs["custom_role"] = "developer"
+ message = SystemMessageV1(content=content, **kwargs)
+ elif message_type == "tool":
+        artifact = additional_kwargs.pop("artifact", None)
+ message = ToolMessageV1(content=content, artifact=artifact, **kwargs)
+ else:
+ msg = (
+ f"Unexpected message type: '{message_type}'. Use one of 'human',"
+            f" 'user', 'ai', 'assistant', 'tool', 'system', or 'developer'."
+ )
+ msg = create_message(message=msg, error_code=ErrorCode.MESSAGE_COERCION_FAILURE)
+ raise ValueError(msg)
+ return message
+
+
+def _convert_from_v1_message(message: MessageV1) -> BaseMessage:
+ """Compatibility layer to convert v1 messages to current messages.
+
+ Args:
+ message: MessageV1 instance to convert.
+
+ Returns:
+ BaseMessage: Converted message instance.
+ """
+ # type ignores here are because AIMessageV1.content is a list of dicts.
+ # AIMessageV0.content expects str or list[str | dict].
+ if isinstance(message, AIMessageV1):
+ return AIMessage(
+ content=message.content, # type: ignore[arg-type]
+ id=message.id,
+ name=message.name,
+ tool_calls=message.tool_calls,
+ response_metadata=message.response_metadata,
+ )
+ if isinstance(message, AIMessageChunkV1):
+ return AIMessageChunk(
+ content=message.content, # type: ignore[arg-type]
+ id=message.id,
+ name=message.name,
+ tool_call_chunks=message.tool_call_chunks,
+ response_metadata=message.response_metadata,
+ )
+ if isinstance(message, HumanMessageV1):
+ return HumanMessage(
+ content=message.content, # type: ignore[arg-type]
+ id=message.id,
+ name=message.name,
+ )
+ if isinstance(message, SystemMessageV1):
+ return SystemMessage(
+ content=message.content, # type: ignore[arg-type]
+ id=message.id,
+ )
+ if isinstance(message, ToolMessageV1):
+ return ToolMessage(
+ content=message.content, # type: ignore[arg-type]
+ id=message.id,
+ )
+ message = f"Unsupported message type: {type(message)}"
+ raise NotImplementedError(message)
+
+
def _convert_to_message(message: MessageLikeRepresentation) -> BaseMessage:
"""Instantiate a message from a variety of message formats.
@@ -341,6 +469,63 @@ def _convert_to_message(message: MessageLikeRepresentation) -> BaseMessage:
message_ = _create_message_from_message_type(
msg_type, msg_content, **msg_kwargs
)
+ elif isinstance(message, MessageV1Types):
+ message_ = _convert_from_v1_message(message)
+ else:
+ msg = f"Unsupported message type: {type(message)}"
+ msg = create_message(message=msg, error_code=ErrorCode.MESSAGE_COERCION_FAILURE)
+ raise NotImplementedError(msg)
+
+ return message_
+
+
+def _convert_to_message_v1(message: MessageLikeRepresentation) -> MessageV1:
+ """Instantiate a message from a variety of message formats.
+
+ The message format can be one of the following:
+
+    - MessageV1 instance (returned as-is)
+ - 2-tuple of (role string, template); e.g., ("human", "{user_input}")
+ - dict: a message dict with role and content keys
+ - string: shorthand for ("human", template); e.g., "{user_input}"
+
+ Args:
+ message: a representation of a message in one of the supported formats.
+
+ Returns:
+        an instance of a message.
+
+ Raises:
+ NotImplementedError: if the message type is not supported.
+ ValueError: if the message dict does not contain the required keys.
+ """
+ if isinstance(message, MessageV1Types):
+ message_ = message
+ elif isinstance(message, str):
+ message_ = _create_message_from_message_type_v1("human", message)
+ elif isinstance(message, Sequence) and len(message) == 2:
+ # mypy doesn't realise this can't be a string given the previous branch
+ message_type_str, template = message # type: ignore[misc]
+ message_ = _create_message_from_message_type_v1(message_type_str, template)
+ elif isinstance(message, dict):
+ msg_kwargs = message.copy()
+ try:
+ try:
+ msg_type = msg_kwargs.pop("role")
+ except KeyError:
+ msg_type = msg_kwargs.pop("type")
+ # None msg content is not allowed
+ msg_content = msg_kwargs.pop("content") or ""
+ except KeyError as e:
+ msg = f"Message dict must contain 'role' and 'content' keys, got {message}"
+ msg = create_message(
+ message=msg, error_code=ErrorCode.MESSAGE_COERCION_FAILURE
+ )
+ raise ValueError(msg) from e
+ message_ = _create_message_from_message_type_v1(
+ msg_type, msg_content, **msg_kwargs
+ )
else:
msg = f"Unsupported message type: {type(message)}"
msg = create_message(message=msg, error_code=ErrorCode.MESSAGE_COERCION_FAILURE)
@@ -368,6 +553,25 @@ def convert_to_messages(
return [_convert_to_message(m) for m in messages]
+def convert_to_messages_v1(
+ messages: Union[Iterable[MessageLikeRepresentation], PromptValue],
+) -> list[MessageV1]:
+ """Convert a sequence of messages to a list of messages.
+
+ Args:
+ messages: Sequence of messages to convert.
+
+ Returns:
+        list of messages (``MessageV1``).
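+
+    Example (illustrative message contents):
+
+        .. code-block:: python
+
+            msgs = convert_to_messages_v1(
+                [
+                    ("system", "You are a helpful assistant."),
+                    {"role": "user", "content": "Hello!"},
+                ]
+            )
+            # -> [SystemMessage(...), HumanMessage(...)]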
+ """
+ # Import here to avoid circular imports
+ from langchain_core.prompt_values import PromptValue
+
+ if isinstance(messages, PromptValue):
+ return messages.to_messages(output_version="v1")
+ return [_convert_to_message_v1(m) for m in messages]
+
+
def _runnable_support(func: Callable) -> Callable:
@overload
def wrapped(
@@ -1007,10 +1211,11 @@ def convert_to_openai_messages(
oai_messages: list = []
- if is_single := isinstance(messages, (BaseMessage, dict, str)):
+ if is_single := isinstance(messages, (BaseMessage, dict, str, MessageV1Types)):
messages = [messages]
- messages = convert_to_messages(messages)
+ # TODO: resolve type ignore here
+ messages = convert_to_messages(messages) # type: ignore[arg-type]
for i, message in enumerate(messages):
oai_msg: dict = {"role": _get_message_openai_role(message)}
diff --git a/libs/core/langchain_core/messages/v1.py b/libs/core/langchain_core/messages/v1.py
new file mode 100644
index 00000000000..9ff2eaed4ab
--- /dev/null
+++ b/libs/core/langchain_core/messages/v1.py
@@ -0,0 +1,568 @@
+"""LangChain 1.0 message format."""
+
+import json
+import uuid
+from dataclasses import dataclass, field
+from typing import Any, Literal, Optional, TypedDict, Union, cast, get_args
+
+import langchain_core.messages.content_blocks as types
+from langchain_core.messages.ai import _LC_ID_PREFIX, UsageMetadata, add_usage
+from langchain_core.messages.base import merge_content
+from langchain_core.messages.tool import (
+ ToolCallChunk,
+)
+from langchain_core.messages.tool import (
+ invalid_tool_call as create_invalid_tool_call,
+)
+from langchain_core.messages.tool import tool_call as create_tool_call
+from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
+from langchain_core.utils._merge import merge_dicts, merge_lists
+from langchain_core.utils.json import parse_partial_json
+
+
+def _ensure_id(id_val: Optional[str]) -> str:
+ """Ensure the ID is a valid string, generating a new UUID if not provided.
+
+ Args:
+ id_val: Optional string ID value to validate.
+
+ Returns:
+ A valid string ID, either the provided value or a new UUID.
+ """
+ return id_val or str(uuid.uuid4())
+
+
+class Provider(TypedDict):
+ """Information about the provider that generated the message.
+
+ Contains metadata about the AI provider and model used to generate content.
+
+ Attributes:
+ name: Name and version of the provider that created the content block.
+ model_name: Name of the model that generated the content block.
+ """
+
+ name: str
+ """Name and version of the provider that created the content block."""
+ model_name: str
+ """Name of the model that generated the content block."""
+
+
+@dataclass
+class AIMessage:
+ """A message generated by an AI assistant.
+
+ Represents a response from an AI model, including text content, tool calls,
+ and metadata about the generation process.
+
+ Attributes:
+ id: Unique identifier for the message.
+ type: Message type identifier, always "ai".
+ name: Optional human-readable name for the message.
+ lc_version: Encoding version for the message.
+ content: List of content blocks containing the message data.
+ tool_calls: Optional list of tool calls made by the AI.
+ invalid_tool_calls: Optional list of tool calls that failed validation.
+        usage_metadata: Optional metadata about token usage.
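+
+    Example (a minimal illustration):
+
+        .. code-block:: python
+
+            msg = AIMessage("Hello, world!")
+            msg.text  # -> "Hello, world!"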
+ """
+
+ type: Literal["ai"] = "ai"
+
+ name: Optional[str] = None
+ """An optional name for the message.
+
+ This can be used to provide a human-readable name for the message.
+
+ Usage of this field is optional, and whether it's used or not is up to the
+ model implementation.
+ """
+
+ id: Optional[str] = None
+ """Unique identifier for the message.
+
+ If the provider assigns a meaningful ID, it should be used here.
+ """
+
+ lc_version: str = "v1"
+ """Encoding version for the message."""
+
+ content: list[types.ContentBlock] = field(default_factory=list)
+
+ usage_metadata: Optional[UsageMetadata] = None
+ """If provided, usage metadata for a message, such as token counts."""
+
+ response_metadata: dict = field(default_factory=dict)
+ """Metadata about the response.
+
+ This field should include non-standard data returned by the provider, such as
+ response headers, service tiers, or log probabilities.
+ """
+
+ def __init__(
+ self,
+ content: Union[str, list[types.ContentBlock]],
+ id: Optional[str] = None,
+ name: Optional[str] = None,
+ lc_version: str = "v1",
+ response_metadata: Optional[dict] = None,
+ usage_metadata: Optional[UsageMetadata] = None,
+ ):
+ """Initialize an AI message.
+
+ Args:
+ content: Message content as string or list of content blocks.
+ id: Optional unique identifier for the message.
+ name: Optional human-readable name for the message.
+ lc_version: Encoding version for the message.
+ response_metadata: Optional metadata about the response.
+ usage_metadata: Optional metadata about token usage.
+ """
+ if isinstance(content, str):
+ self.content = [{"type": "text", "text": content}]
+ else:
+ self.content = content
+
+ self.id = id
+ self.name = name
+ self.lc_version = lc_version
+ self.usage_metadata = usage_metadata
+ if response_metadata is None:
+ self.response_metadata = {}
+ else:
+ self.response_metadata = response_metadata
+
+ self._tool_calls: list[types.ToolCall] = []
+ self._invalid_tool_calls: list[types.InvalidToolCall] = []
+
+ @property
+ def text(self) -> Optional[str]:
+ """Extract all text content from the AI message as a string."""
+ text_blocks = [block for block in self.content if block["type"] == "text"]
+ if text_blocks:
+ return "".join(block["text"] for block in text_blocks)
+ return None
+
+ @property
+ def tool_calls(self) -> list[types.ToolCall]: # update once we fix branch
+ """Get the tool calls made by the AI."""
+ if self._tool_calls:
+ return self._tool_calls
+ tool_calls = [block for block in self.content if block["type"] == "tool_call"]
+ if tool_calls:
+ self._tool_calls = tool_calls
+ return self._tool_calls
+
+ @tool_calls.setter
+ def tool_calls(self, value: list[types.ToolCall]) -> None:
+ """Set the tool calls for the AI message."""
+ self._tool_calls = value
+
+
+@dataclass
+class AIMessageChunk:
+ """A partial chunk of an AI message during streaming.
+
+ Represents a portion of an AI response that is delivered incrementally
+ during streaming generation. Contains partial content and metadata.
+
+ Attributes:
+ id: Unique identifier for the message chunk.
+ type: Message type identifier, always "ai_chunk".
+ name: Optional human-readable name for the message.
+ content: List of content blocks containing partial message data.
+ tool_call_chunks: Optional list of partial tool call data.
+ usage_metadata: Optional metadata about token usage and costs.
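+
+    Example (a minimal sketch of merging streamed chunks):
+
+        .. code-block:: python
+
+            full = AIMessageChunk("Hello, ") + AIMessageChunk("world!")
+            full.text  # -> "Hello, world!"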
+ """
+
+ type: Literal["ai_chunk"] = "ai_chunk"
+
+ name: Optional[str] = None
+ """An optional name for the message.
+
+ This can be used to provide a human-readable name for the message.
+
+ Usage of this field is optional, and whether it's used or not is up to the
+ model implementation.
+ """
+
+ id: Optional[str] = None
+ """Unique identifier for the message chunk.
+
+ If the provider assigns a meaningful ID, it should be used here.
+ """
+
+ lc_version: str = "v1"
+ """Encoding version for the message."""
+
+ content: list[types.ContentBlock] = field(init=False)
+
+ usage_metadata: Optional[UsageMetadata] = None
+ """If provided, usage metadata for a message chunk, such as token counts.
+
+ These data represent incremental usage statistics, as opposed to a running total.
+ """
+
+ response_metadata: dict = field(init=False)
+ """Metadata about the response chunk.
+
+ This field should include non-standard data returned by the provider, such as
+ response headers, service tiers, or log probabilities.
+ """
+
+ tool_call_chunks: list[types.ToolCallChunk] = field(init=False)
+
+ def __init__(
+ self,
+ content: Union[str, list[types.ContentBlock]],
+ id: Optional[str] = None,
+ name: Optional[str] = None,
+ lc_version: str = "v1",
+ response_metadata: Optional[dict] = None,
+ usage_metadata: Optional[UsageMetadata] = None,
+ tool_call_chunks: Optional[list[types.ToolCallChunk]] = None,
+ ):
+ """Initialize an AI message.
+
+ Args:
+ content: Message content as string or list of content blocks.
+ id: Optional unique identifier for the message.
+ name: Optional human-readable name for the message.
+ lc_version: Encoding version for the message.
+ response_metadata: Optional metadata about the response.
+ usage_metadata: Optional metadata about token usage.
+ tool_call_chunks: Optional list of partial tool call data.
+ """
+ if isinstance(content, str):
+ self.content = [{"type": "text", "text": content, "index": 0}]
+ else:
+ self.content = content
+
+ self.id = id
+ self.name = name
+ self.lc_version = lc_version
+ self.usage_metadata = usage_metadata
+ if response_metadata is None:
+ self.response_metadata = {}
+ else:
+ self.response_metadata = response_metadata
+ if tool_call_chunks is None:
+ self.tool_call_chunks: list[types.ToolCallChunk] = []
+ else:
+ self.tool_call_chunks = tool_call_chunks
+
+ self._tool_calls: list[types.ToolCall] = []
+ self._invalid_tool_calls: list[types.InvalidToolCall] = []
+ self._init_tool_calls()
+
+ def _init_tool_calls(self) -> None:
+        """Initialize tool calls and invalid tool calls from tool call chunks.
+
+        Chunks whose arguments cannot be parsed as (partial) JSON objects are
+        recorded as invalid tool calls.
+        """
+ self._tool_calls = []
+ self._invalid_tool_calls = []
+ if not self.tool_call_chunks:
+ if self._tool_calls:
+ self.tool_call_chunks = [
+ create_tool_call_chunk(
+ name=tc["name"],
+ args=json.dumps(tc["args"]),
+ id=tc["id"],
+ index=None,
+ )
+ for tc in self._tool_calls
+ ]
+ if self._invalid_tool_calls:
+ tool_call_chunks = self.tool_call_chunks
+ tool_call_chunks.extend(
+ [
+ create_tool_call_chunk(
+ name=tc["name"], args=tc["args"], id=tc["id"], index=None
+ )
+ for tc in self._invalid_tool_calls
+ ]
+ )
+ self.tool_call_chunks = tool_call_chunks
+
+ tool_calls = []
+ invalid_tool_calls = []
+
+ def add_chunk_to_invalid_tool_calls(chunk: ToolCallChunk) -> None:
+ invalid_tool_calls.append(
+ create_invalid_tool_call(
+ name=chunk["name"],
+ args=chunk["args"],
+ id=chunk["id"],
+ error=None,
+ )
+ )
+
+ for chunk in self.tool_call_chunks:
+ try:
+ args_ = parse_partial_json(chunk["args"]) if chunk["args"] != "" else {} # type: ignore[arg-type]
+ if isinstance(args_, dict):
+ tool_calls.append(
+ create_tool_call(
+ name=chunk["name"] or "",
+ args=args_,
+ id=chunk["id"],
+ )
+ )
+ else:
+ add_chunk_to_invalid_tool_calls(chunk)
+ except Exception:
+ add_chunk_to_invalid_tool_calls(chunk)
+ self._tool_calls = tool_calls
+ self._invalid_tool_calls = invalid_tool_calls
+
+ @property
+ def text(self) -> Optional[str]:
+ """Extract all text content from the AI message as a string."""
+ text_blocks = [block for block in self.content if block["type"] == "text"]
+ if text_blocks:
+ return "".join(block["text"] for block in text_blocks)
+ return None
+
+ @property
+ def reasoning(self) -> Optional[str]:
+ """Extract all reasoning text from the AI message as a string."""
+        reasoning_blocks = [
+            block for block in self.content if block["type"] == "reasoning"
+        ]
+        if reasoning_blocks:
+            return "".join(
+                block.get("reasoning", "") for block in reasoning_blocks
+            )
+        return None
+
+ @property
+ def tool_calls(self) -> list[types.ToolCall]:
+ """Get the tool calls made by the AI."""
+ if self._tool_calls:
+ return self._tool_calls
+ tool_calls = [block for block in self.content if block["type"] == "tool_call"]
+ if tool_calls:
+ self._tool_calls = tool_calls
+ return self._tool_calls
+
+ @tool_calls.setter
+ def tool_calls(self, value: list[types.ToolCall]) -> None:
+ """Set the tool calls for the AI message."""
+ self._tool_calls = value
+
+ def __add__(self, other: Any) -> "AIMessageChunk":
+ """Add AIMessageChunk to this one."""
+ if isinstance(other, AIMessageChunk):
+ return add_ai_message_chunks(self, other)
+ if isinstance(other, (list, tuple)) and all(
+ isinstance(o, AIMessageChunk) for o in other
+ ):
+ return add_ai_message_chunks(self, *other)
+ error_msg = "Can only add AIMessageChunk or sequence of AIMessageChunk."
+ raise NotImplementedError(error_msg)
+
+
+def add_ai_message_chunks(
+ left: AIMessageChunk, *others: AIMessageChunk
+) -> AIMessageChunk:
+ """Add multiple AIMessageChunks together."""
+ if not others:
+ return left
+ content = merge_content(
+ cast("list[str | dict[Any, Any]]", left.content),
+ *(cast("list[str | dict[Any, Any]]", o.content) for o in others),
+ )
+ response_metadata = merge_dicts(
+ left.response_metadata, *(o.response_metadata for o in others)
+ )
+
+ # Merge tool call chunks
+ if raw_tool_calls := merge_lists(
+ left.tool_call_chunks, *(o.tool_call_chunks for o in others)
+ ):
+ tool_call_chunks = [
+ create_tool_call_chunk(
+ name=rtc.get("name"),
+ args=rtc.get("args"),
+ index=rtc.get("index"),
+ id=rtc.get("id"),
+ )
+ for rtc in raw_tool_calls
+ ]
+ else:
+ tool_call_chunks = []
+
+ # Token usage
+ if left.usage_metadata or any(o.usage_metadata is not None for o in others):
+ usage_metadata: Optional[UsageMetadata] = left.usage_metadata
+ for other in others:
+ usage_metadata = add_usage(usage_metadata, other.usage_metadata)
+ else:
+ usage_metadata = None
+
+ chunk_id = None
+ candidates = [left.id] + [o.id for o in others]
+ # first pass: pick the first non-run-* id
+ for id_ in candidates:
+ if id_ and not id_.startswith(_LC_ID_PREFIX):
+ chunk_id = id_
+ break
+ else:
+ # second pass: no provider-assigned id found, just take the first non-null
+ for id_ in candidates:
+ if id_:
+ chunk_id = id_
+ break
+
+ return left.__class__(
+ content=cast("list[types.ContentBlock]", content),
+ tool_call_chunks=tool_call_chunks,
+ response_metadata=response_metadata,
+ usage_metadata=usage_metadata,
+ id=chunk_id,
+ )
+
+
+@dataclass
+class HumanMessage:
+ """A message from a human user.
+
+ Represents input from a human user in a conversation, containing text
+ or other content types like images.
+
+ Attributes:
+ id: Unique identifier for the message.
+ content: List of content blocks containing the user's input.
+ name: Optional human-readable name for the message.
+ type: Message type identifier, always "human".
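+
+    Example (illustrative):
+
+        .. code-block:: python
+
+            msg = HumanMessage("What is the capital of France?")
+            msg.text()  # -> "What is the capital of France?"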
+ """
+
+ id: str
+ content: list[types.ContentBlock]
+ name: Optional[str] = None
+ """An optional name for the message.
+
+ This can be used to provide a human-readable name for the message.
+
+ Usage of this field is optional, and whether it's used or not is up to the
+ model implementation.
+ """
+ type: Literal["human"] = "human"
+ """The type of the message. Must be a string that is unique to the message type.
+
+ The purpose of this field is to allow for easy identification of the message type
+ when deserializing messages.
+ """
+
+    def __init__(
+        self,
+        content: Union[str, list[types.ContentBlock]],
+        id: Optional[str] = None,
+        name: Optional[str] = None,
+    ):
+        """Initialize a human message.
+
+        Args:
+            content: Message content as string or list of content blocks.
+            id: Optional unique identifier for the message.
+            name: Optional human-readable name for the message.
+        """
+        self.id = _ensure_id(id)
+        self.name = name
+        if isinstance(content, str):
+            self.content = [{"type": "text", "text": content}]
+        else:
+            self.content = content
+
+ def text(self) -> str:
+ """Extract all text content from the message.
+
+ Returns:
+ Concatenated string of all text blocks in the message.
+ """
+ return "".join(
+ block["text"] for block in self.content if block["type"] == "text"
+ )
+
+
+@dataclass
+class SystemMessage:
+ """A system message containing instructions or context.
+
+ Represents system-level instructions or context that guides the AI's
+ behavior and understanding of the conversation.
+
+ Attributes:
+ id: Unique identifier for the message.
+ content: List of content blocks containing system instructions.
+ type: Message type identifier, always "system".
+ """
+
+ id: str
+ content: list[types.ContentBlock]
+ type: Literal["system"] = "system"
+
+ def __init__(
+ self, content: Union[str, list[types.ContentBlock]], *, id: Optional[str] = None
+ ):
+ """Initialize a system message.
+
+ Args:
+ content: System instructions as string or list of content blocks.
+ id: Optional unique identifier for the message.
+ """
+ self.id = _ensure_id(id)
+ if isinstance(content, str):
+ self.content = [{"type": "text", "text": content}]
+ else:
+ self.content = content
+
+ def text(self) -> str:
+ """Extract all text content from the system message."""
+ return "".join(
+ block["text"] for block in self.content if block["type"] == "text"
+ )
+
+
+@dataclass
+class ToolMessage:
+ """A message containing the result of a tool execution.
+
+ Represents the output from executing a tool or function call,
+ including the result data and execution status.
+
+ Attributes:
+ id: Unique identifier for the message.
+ tool_call_id: ID of the tool call this message responds to.
+ content: The result content from tool execution.
+ artifact: Optional app-side payload not intended for the model.
+ status: Execution status ("success" or "error").
+ type: Message type identifier, always "tool".
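+
+    Example (illustrative values):
+
+        .. code-block:: python
+
+            msg = ToolMessage(
+                id="msg-1",
+                tool_call_id="call-123",
+                content=[{"type": "text", "text": "42"}],
+            )
+            msg.text  # -> "42"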
+ """
+
+ id: str
+ tool_call_id: str
+ content: list[dict[str, Any]]
+ artifact: Optional[Any] = None # App-side payload not for the model
+ status: Literal["success", "error"] = "success"
+ type: Literal["tool"] = "tool"
+
+ @property
+ def text(self) -> str:
+ """Extract all text content from the tool message."""
+ return "".join(
+ block["text"] for block in self.content if block["type"] == "text"
+ )
+
+ def __post_init__(self) -> None:
+ """Initialize computed fields after dataclass creation.
+
+ Ensures the tool message has a valid ID.
+ """
+ self.id = _ensure_id(self.id)
+
+
+# Alias for a message type that can be any of the defined message types
+MessageV1 = Union[
+ AIMessage,
+ AIMessageChunk,
+ HumanMessage,
+ SystemMessage,
+ ToolMessage,
+]
+MessageV1Types = get_args(MessageV1)
diff --git a/libs/core/langchain_core/output_parsers/transform.py b/libs/core/langchain_core/output_parsers/transform.py
index 876e66b5556..0c864805b93 100644
--- a/libs/core/langchain_core/output_parsers/transform.py
+++ b/libs/core/langchain_core/output_parsers/transform.py
@@ -32,7 +32,7 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
def _transform(
self,
- input: Iterator[Union[str, BaseMessage]], # noqa: A002
+ input: Iterator[Union[str, BaseMessage]],
) -> Iterator[T]:
for chunk in input:
if isinstance(chunk, BaseMessage):
@@ -42,7 +42,7 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
async def _atransform(
self,
- input: AsyncIterator[Union[str, BaseMessage]], # noqa: A002
+ input: AsyncIterator[Union[str, BaseMessage]],
) -> AsyncIterator[T]:
async for chunk in input:
if isinstance(chunk, BaseMessage):
diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py
index 5061cd3c5d0..e205afd4d8c 100644
--- a/libs/core/langchain_core/runnables/base.py
+++ b/libs/core/langchain_core/runnables/base.py
@@ -731,7 +731,7 @@ class Runnable(ABC, Generic[Input, Output]):
@abstractmethod
def invoke(
self,
- input: Input, # noqa: A002
+ input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Output:
@@ -751,7 +751,7 @@ class Runnable(ABC, Generic[Input, Output]):
async def ainvoke(
self,
- input: Input, # noqa: A002
+ input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Output:
@@ -999,7 +999,7 @@ class Runnable(ABC, Generic[Input, Output]):
def stream(
self,
- input: Input, # noqa: A002
+ input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Iterator[Output]:
@@ -1019,7 +1019,7 @@ class Runnable(ABC, Generic[Input, Output]):
async def astream(
self,
- input: Input, # noqa: A002
+ input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> AsyncIterator[Output]:
@@ -1073,7 +1073,7 @@ class Runnable(ABC, Generic[Input, Output]):
async def astream_log(
self,
- input: Any, # noqa: A002
+ input: Any,
config: Optional[RunnableConfig] = None,
*,
diff: bool = True,
@@ -1144,7 +1144,7 @@ class Runnable(ABC, Generic[Input, Output]):
async def astream_events(
self,
- input: Any, # noqa: A002
+ input: Any,
config: Optional[RunnableConfig] = None,
*,
version: Literal["v1", "v2"] = "v2",
@@ -1410,7 +1410,7 @@ class Runnable(ABC, Generic[Input, Output]):
def transform(
self,
- input: Iterator[Input], # noqa: A002
+ input: Iterator[Input],
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Iterator[Output]:
@@ -1452,7 +1452,7 @@ class Runnable(ABC, Generic[Input, Output]):
async def atransform(
self,
- input: AsyncIterator[Input], # noqa: A002
+ input: AsyncIterator[Input],
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> AsyncIterator[Output]:
diff --git a/libs/core/langchain_core/runnables/config.py b/libs/core/langchain_core/runnables/config.py
index 4ac7bda7b46..cc36622b914 100644
--- a/libs/core/langchain_core/runnables/config.py
+++ b/libs/core/langchain_core/runnables/config.py
@@ -402,7 +402,7 @@ def call_func_with_variable_args(
Callable[[Input, CallbackManagerForChainRun], Output],
Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output],
],
- input: Input, # noqa: A002
+ input: Input,
config: RunnableConfig,
run_manager: Optional[CallbackManagerForChainRun] = None,
**kwargs: Any,
@@ -439,7 +439,7 @@ def acall_func_with_variable_args(
Awaitable[Output],
],
],
- input: Input, # noqa: A002
+ input: Input,
config: RunnableConfig,
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
**kwargs: Any,
diff --git a/libs/core/langchain_core/runnables/graph.py b/libs/core/langchain_core/runnables/graph.py
index 3e22494bad7..20a841d51a8 100644
--- a/libs/core/langchain_core/runnables/graph.py
+++ b/libs/core/langchain_core/runnables/graph.py
@@ -114,7 +114,7 @@ class Node(NamedTuple):
def copy(
self,
*,
- id: Optional[str] = None, # noqa: A002
+ id: Optional[str] = None,
name: Optional[str] = None,
) -> Node:
"""Return a copy of the node with optional new id and name.
@@ -187,7 +187,7 @@ class MermaidDrawMethod(Enum):
def node_data_str(
- id: str, # noqa: A002
+ id: str,
data: Union[type[BaseModel], RunnableType, None],
) -> str:
"""Convert the data of a node to a string.
@@ -328,7 +328,7 @@ class Graph:
def add_node(
self,
data: Union[type[BaseModel], RunnableType, None],
- id: Optional[str] = None, # noqa: A002
+ id: Optional[str] = None,
*,
metadata: Optional[dict[str, Any]] = None,
) -> Node:
diff --git a/libs/core/langchain_core/utils/function_calling.py b/libs/core/langchain_core/utils/function_calling.py
index d7059fded47..b69e2c331fa 100644
--- a/libs/core/langchain_core/utils/function_calling.py
+++ b/libs/core/langchain_core/utils/function_calling.py
@@ -616,7 +616,7 @@ def convert_to_json_schema(
@beta()
def tool_example_to_messages(
- input: str, # noqa: A002
+ input: str,
tool_calls: list[BaseModel],
tool_outputs: Optional[list[str]] = None,
*,
diff --git a/libs/core/pyproject.toml b/libs/core/pyproject.toml
index 9b1686659ef..66a4c5bc3b4 100644
--- a/libs/core/pyproject.toml
+++ b/libs/core/pyproject.toml
@@ -86,6 +86,7 @@ ignore = [
"FIX002", # Line contains TODO
"ISC001", # Messes with the formatter
"PERF203", # Rarely useful
+ "PLC0414", # Enable re-export
"PLR09", # Too many something (arg, statements, etc)
"RUF012", # Doesn't play well with Pydantic
"TC001", # Doesn't play well with Pydantic
@@ -105,6 +106,7 @@ unfixable = ["PLW1510",]
flake8-annotations.allow-star-arg-any = true
flake8-annotations.mypy-init-return = true
+flake8-builtins.ignorelist = ["id", "input", "type"]
flake8-type-checking.runtime-evaluated-base-classes = ["pydantic.BaseModel","langchain_core.load.serializable.Serializable","langchain_core.runnables.base.RunnableSerializable"]
pep8-naming.classmethod-decorators = [ "classmethod", "langchain_core.utils.pydantic.pre_init", "pydantic.field_validator", "pydantic.v1.root_validator",]
pydocstyle.convention = "google"
diff --git a/libs/core/tests/unit_tests/fake/test_fake_chat_model.py b/libs/core/tests/unit_tests/fake/test_fake_chat_model.py
index ce262797535..4d17f3b8cae 100644
--- a/libs/core/tests/unit_tests/fake/test_fake_chat_model.py
+++ b/libs/core/tests/unit_tests/fake/test_fake_chat_model.py
@@ -15,6 +15,8 @@ from langchain_core.language_models import (
ParrotFakeChatModel,
)
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage, HumanMessage
+from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1
+from langchain_core.messages.v1 import MessageV1
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
from tests.unit_tests.stubs import (
_any_id_ai_message,
@@ -157,13 +159,13 @@ async def test_callback_handlers() -> None:
"""Verify that model is implemented correctly with handlers working."""
class MyCustomAsyncHandler(AsyncCallbackHandler):
- def __init__(self, store: list[str]) -> None:
+ def __init__(self, store: list[Union[str, AIMessageChunkV1]]) -> None:
self.store = store
async def on_chat_model_start(
self,
serialized: dict[str, Any],
- messages: list[list[BaseMessage]],
+ messages: Union[list[list[BaseMessage]], list[list[MessageV1]]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
@@ -178,9 +180,11 @@ async def test_callback_handlers() -> None:
@override
async def on_llm_new_token(
self,
- token: str,
+ token: Union[str, AIMessageChunkV1],
*,
- chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
+ chunk: Optional[
+ Union[GenerationChunk, ChatGenerationChunk, AIMessageChunkV1]
+ ] = None,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[list[str]] = None,
@@ -194,7 +198,7 @@ async def test_callback_handlers() -> None:
]
)
model = GenericFakeChatModel(messages=infinite_cycle)
- tokens: list[str] = []
+ tokens: list[Union[str, AIMessageChunkV1]] = []
# New model
results = [
chunk
diff --git a/libs/core/tests/unit_tests/language_models/chat_models/test_cache.py b/libs/core/tests/unit_tests/language_models/chat_models/test_cache.py
index 1cceb0a146b..d611fc36c4e 100644
--- a/libs/core/tests/unit_tests/language_models/chat_models/test_cache.py
+++ b/libs/core/tests/unit_tests/language_models/chat_models/test_cache.py
@@ -300,8 +300,9 @@ def test_llm_representation_for_serializable() -> None:
assert chat._get_llm_string() == (
'{"id": ["tests", "unit_tests", "language_models", "chat_models", '
'"test_cache", "CustomChat"], "kwargs": {"messages": {"id": '
- '["builtins", "list_iterator"], "lc": 1, "type": "not_implemented"}}, "lc": '
- '1, "name": "CustomChat", "type": "constructor"}---[(\'stop\', None)]'
+ '["builtins", "list_iterator"], "lc": 1, "type": "not_implemented"}, '
+ '"output_version": "v0"}, "lc": 1, "name": "CustomChat", "type": '
+ "\"constructor\"}---[('stop', None)]"
)
diff --git a/libs/core/tests/unit_tests/messages/test_imports.py b/libs/core/tests/unit_tests/messages/test_imports.py
index ff9fbf92fc7..9fda5493244 100644
--- a/libs/core/tests/unit_tests/messages/test_imports.py
+++ b/libs/core/tests/unit_tests/messages/test_imports.py
@@ -5,22 +5,40 @@ EXPECTED_ALL = [
"_message_from_dict",
"AIMessage",
"AIMessageChunk",
+ "Annotation",
"AnyMessage",
+ "AudioContentBlock",
"BaseMessage",
"BaseMessageChunk",
+ "ContentBlock",
"ChatMessage",
"ChatMessageChunk",
+ "Citation",
+ "CodeInterpreterCall",
+ "CodeInterpreterOutput",
+ "CodeInterpreterResult",
+ "DataContentBlock",
+ "FileContentBlock",
"FunctionMessage",
"FunctionMessageChunk",
"HumanMessage",
"HumanMessageChunk",
+ "ImageContentBlock",
"InvalidToolCall",
+ "NonStandardAnnotation",
+ "NonStandardContentBlock",
+ "PlainTextContentBlock",
"SystemMessage",
"SystemMessageChunk",
+ "TextContentBlock",
"ToolCall",
"ToolCallChunk",
"ToolMessage",
"ToolMessageChunk",
+ "VideoContentBlock",
+ "WebSearchCall",
+ "WebSearchResult",
+ "ReasoningContentBlock",
"RemoveMessage",
"convert_to_messages",
"get_buffer_string",
diff --git a/libs/core/tests/unit_tests/messages/test_utils.py b/libs/core/tests/unit_tests/messages/test_utils.py
index bedd518589e..f9f1c9c9ff0 100644
--- a/libs/core/tests/unit_tests/messages/test_utils.py
+++ b/libs/core/tests/unit_tests/messages/test_utils.py
@@ -1221,15 +1221,30 @@ def test_convert_to_openai_messages_multimodal() -> None:
{"type": "text", "text": "Text message"},
{
"type": "image",
- "source_type": "url",
"url": "https://example.com/test.png",
},
+ {
+ "type": "image",
+ "source_type": "url", # backward compatibility
+ "url": "https://example.com/test.png",
+ },
+ {
+ "type": "image",
+ "base64": "",
+ "mime_type": "image/png",
+ },
{
"type": "image",
"source_type": "base64",
"data": "",
"mime_type": "image/png",
},
+ {
+ "type": "file",
+ "base64": "",
+ "mime_type": "application/pdf",
+ "filename": "test.pdf",
+ },
{
"type": "file",
"source_type": "base64",
@@ -1244,11 +1259,20 @@ def test_convert_to_openai_messages_multimodal() -> None:
"file_data": "data:application/pdf;base64,",
},
},
+ {
+ "type": "file",
+ "file_id": "file-abc123",
+ },
{
"type": "file",
"source_type": "id",
"id": "file-abc123",
},
+ {
+ "type": "audio",
+ "base64": "",
+ "mime_type": "audio/wav",
+ },
{
"type": "audio",
"source_type": "base64",
@@ -1268,7 +1292,7 @@ def test_convert_to_openai_messages_multimodal() -> None:
result = convert_to_openai_messages(messages, text_format="block")
assert len(result) == 1
message = result[0]
- assert len(message["content"]) == 8
+ assert len(message["content"]) == 13
# Test adding filename
messages = [
@@ -1276,8 +1300,7 @@ def test_convert_to_openai_messages_multimodal() -> None:
content=[
{
"type": "file",
- "source_type": "base64",
- "data": "",
+ "base64": "",
"mime_type": "application/pdf",
},
]
diff --git a/libs/core/tests/unit_tests/prompts/__snapshots__/test_chat.ambr b/libs/core/tests/unit_tests/prompts/__snapshots__/test_chat.ambr
index 7c07416fe5d..7d844b6beca 100644
--- a/libs/core/tests/unit_tests/prompts/__snapshots__/test_chat.ambr
+++ b/libs/core/tests/unit_tests/prompts/__snapshots__/test_chat.ambr
@@ -785,6 +785,7 @@
'args',
'id',
'error',
+ 'type',
]),
'title': 'InvalidToolCall',
'type': 'object',
@@ -1015,6 +1016,10 @@
]),
'title': 'Id',
}),
+ 'index': dict({
+ 'title': 'Index',
+ 'type': 'integer',
+ }),
'name': dict({
'title': 'Name',
'type': 'string',
@@ -1029,6 +1034,7 @@
'name',
'args',
'id',
+ 'type',
]),
'title': 'ToolCall',
'type': 'object',
@@ -2217,6 +2223,7 @@
'args',
'id',
'error',
+ 'type',
]),
'title': 'InvalidToolCall',
'type': 'object',
@@ -2447,6 +2454,10 @@
]),
'title': 'Id',
}),
+ 'index': dict({
+ 'title': 'Index',
+ 'type': 'integer',
+ }),
'name': dict({
'title': 'Name',
'type': 'string',
@@ -2461,6 +2472,7 @@
'name',
'args',
'id',
+ 'type',
]),
'title': 'ToolCall',
'type': 'object',
diff --git a/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr b/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr
index a788c425fce..7b8f1b570a5 100644
--- a/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr
+++ b/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr
@@ -1188,6 +1188,7 @@
'args',
'id',
'error',
+ 'type',
]),
'title': 'InvalidToolCall',
'type': 'object',
@@ -1418,6 +1419,10 @@
]),
'title': 'Id',
}),
+ 'index': dict({
+ 'title': 'Index',
+ 'type': 'integer',
+ }),
'name': dict({
'title': 'Name',
'type': 'string',
@@ -1432,6 +1437,7 @@
'name',
'args',
'id',
+ 'type',
]),
'title': 'ToolCall',
'type': 'object',
diff --git a/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr b/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr
index 079e4909061..cc9f7aab15d 100644
--- a/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr
+++ b/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr
@@ -2732,6 +2732,7 @@
'args',
'id',
'error',
+ 'type',
]),
'title': 'InvalidToolCall',
'type': 'object',
@@ -2960,6 +2961,10 @@
]),
'title': 'Id',
}),
+ 'index': dict({
+ 'title': 'Index',
+ 'type': 'integer',
+ }),
'name': dict({
'title': 'Name',
'type': 'string',
@@ -2973,6 +2978,7 @@
'name',
'args',
'id',
+ 'type',
]),
'title': 'ToolCall',
'type': 'object',
@@ -4208,6 +4214,7 @@
'args',
'id',
'error',
+ 'type',
]),
'title': 'InvalidToolCall',
'type': 'object',
@@ -4455,6 +4462,10 @@
]),
'title': 'Id',
}),
+ 'index': dict({
+ 'title': 'Index',
+ 'type': 'integer',
+ }),
'name': dict({
'title': 'Name',
'type': 'string',
@@ -4468,6 +4479,7 @@
'name',
'args',
'id',
+ 'type',
]),
'title': 'ToolCall',
'type': 'object',
@@ -5715,6 +5727,7 @@
'args',
'id',
'error',
+ 'type',
]),
'title': 'InvalidToolCall',
'type': 'object',
@@ -5962,6 +5975,10 @@
]),
'title': 'Id',
}),
+ 'index': dict({
+ 'title': 'Index',
+ 'type': 'integer',
+ }),
'name': dict({
'title': 'Name',
'type': 'string',
@@ -5975,6 +5992,7 @@
'name',
'args',
'id',
+ 'type',
]),
'title': 'ToolCall',
'type': 'object',
@@ -7097,6 +7115,7 @@
'args',
'id',
'error',
+ 'type',
]),
'title': 'InvalidToolCall',
'type': 'object',
@@ -7325,6 +7344,10 @@
]),
'title': 'Id',
}),
+ 'index': dict({
+ 'title': 'Index',
+ 'type': 'integer',
+ }),
'name': dict({
'title': 'Name',
'type': 'string',
@@ -7338,6 +7361,7 @@
'name',
'args',
'id',
+ 'type',
]),
'title': 'ToolCall',
'type': 'object',
@@ -8615,6 +8639,7 @@
'args',
'id',
'error',
+ 'type',
]),
'title': 'InvalidToolCall',
'type': 'object',
@@ -8862,6 +8887,10 @@
]),
'title': 'Id',
}),
+ 'index': dict({
+ 'title': 'Index',
+ 'type': 'integer',
+ }),
'name': dict({
'title': 'Name',
'type': 'string',
@@ -8875,6 +8904,7 @@
'name',
'args',
'id',
+ 'type',
]),
'title': 'ToolCall',
'type': 'object',
@@ -10042,6 +10072,7 @@
'args',
'id',
'error',
+ 'type',
]),
'title': 'InvalidToolCall',
'type': 'object',
@@ -10270,6 +10301,10 @@
]),
'title': 'Id',
}),
+ 'index': dict({
+ 'title': 'Index',
+ 'type': 'integer',
+ }),
'name': dict({
'title': 'Name',
'type': 'string',
@@ -10283,6 +10318,7 @@
'name',
'args',
'id',
+ 'type',
]),
'title': 'ToolCall',
'type': 'object',
@@ -11468,6 +11504,7 @@
'args',
'id',
'error',
+ 'type',
]),
'title': 'InvalidToolCall',
'type': 'object',
@@ -11726,6 +11763,10 @@
]),
'title': 'Id',
}),
+ 'index': dict({
+ 'title': 'Index',
+ 'type': 'integer',
+ }),
'name': dict({
'title': 'Name',
'type': 'string',
@@ -11739,6 +11780,7 @@
'name',
'args',
'id',
+ 'type',
]),
'title': 'ToolCall',
'type': 'object',
@@ -12936,6 +12978,7 @@
'args',
'id',
'error',
+ 'type',
]),
'title': 'InvalidToolCall',
'type': 'object',
@@ -13183,6 +13226,10 @@
]),
'title': 'Id',
}),
+ 'index': dict({
+ 'title': 'Index',
+ 'type': 'integer',
+ }),
'name': dict({
'title': 'Name',
'type': 'string',
@@ -13196,6 +13243,7 @@
'name',
'args',
'id',
+ 'type',
]),
'title': 'ToolCall',
'type': 'object',
diff --git a/libs/core/tests/unit_tests/test_messages.py b/libs/core/tests/unit_tests/test_messages.py
index 58822433d7a..75bbb804c2e 100644
--- a/libs/core/tests/unit_tests/test_messages.py
+++ b/libs/core/tests/unit_tests/test_messages.py
@@ -30,9 +30,11 @@ from langchain_core.messages import (
messages_from_dict,
messages_to_dict,
)
+from langchain_core.messages.content_blocks import KNOWN_BLOCK_TYPES
from langchain_core.messages.tool import invalid_tool_call as create_invalid_tool_call
from langchain_core.messages.tool import tool_call as create_tool_call
from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
+from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1
from langchain_core.utils._merge import merge_lists
@@ -195,6 +197,116 @@ def test_message_chunks() -> None:
assert (meaningful_id + default_id).id == "msg_def456"
+def test_message_chunks_v2() -> None:
+ left = AIMessageChunkV1("foo ", id="abc")
+ right = AIMessageChunkV1("bar")
+ expected = AIMessageChunkV1("foo bar", id="abc")
+ assert left + right == expected
+
+ # Test tool calls
+ one = AIMessageChunkV1(
+ [],
+ tool_call_chunks=[
+ create_tool_call_chunk(name="tool1", args="", id="1", index=0)
+ ],
+ )
+ two = AIMessageChunkV1(
+ [],
+ tool_call_chunks=[
+ create_tool_call_chunk(name=None, args='{"arg1": "val', id=None, index=0)
+ ],
+ )
+ three = AIMessageChunkV1(
+ [],
+ tool_call_chunks=[
+ create_tool_call_chunk(name=None, args='ue}"', id=None, index=0)
+ ],
+ )
+ expected = AIMessageChunkV1(
+ [],
+ tool_call_chunks=[
+ create_tool_call_chunk(
+ name="tool1", args='{"arg1": "value}"', id="1", index=0
+ )
+ ],
+ )
+ assert one + two + three == expected
+
+ assert (
+ AIMessageChunkV1(
+ [],
+ tool_call_chunks=[
+ create_tool_call_chunk(name="tool1", args="", id="1", index=0)
+ ],
+ )
+ + AIMessageChunkV1(
+ [],
+ tool_call_chunks=[
+ create_tool_call_chunk(name="tool1", args="a", id=None, index=1)
+ ],
+ )
+ # Don't merge if `index` field does not match.
+ ) == AIMessageChunkV1(
+ [],
+ tool_call_chunks=[
+ create_tool_call_chunk(name="tool1", args="", id="1", index=0),
+ create_tool_call_chunk(name="tool1", args="a", id=None, index=1),
+ ],
+ )
+
+ ai_msg_chunk = AIMessageChunkV1([])
+ tool_calls_msg_chunk = AIMessageChunkV1(
+ [],
+ tool_call_chunks=[
+ create_tool_call_chunk(name="tool1", args="a", id=None, index=1)
+ ],
+ )
+ assert ai_msg_chunk + tool_calls_msg_chunk == tool_calls_msg_chunk
+ assert tool_calls_msg_chunk + ai_msg_chunk == tool_calls_msg_chunk
+
+ ai_msg_chunk = AIMessageChunkV1(
+ [],
+ tool_call_chunks=[
+ create_tool_call_chunk(name="tool1", args="", id="1", index=0)
+ ],
+ )
+ assert ai_msg_chunk.tool_calls == [create_tool_call(name="tool1", args={}, id="1")]
+
+ # Test token usage
+ left = AIMessageChunkV1(
+ [],
+ usage_metadata={"input_tokens": 1, "output_tokens": 2, "total_tokens": 3},
+ )
+ right = AIMessageChunkV1(
+ [],
+ usage_metadata={"input_tokens": 4, "output_tokens": 5, "total_tokens": 9},
+ )
+ assert left + right == AIMessageChunkV1(
+ content=[],
+ usage_metadata={"input_tokens": 5, "output_tokens": 7, "total_tokens": 12},
+ )
+ assert AIMessageChunkV1(content=[]) + left == left
+ assert right + AIMessageChunkV1(content=[]) == right
+
+ # Test ID order of precedence
+ null_id = AIMessageChunkV1(content=[], id=None)
+ default_id = AIMessageChunkV1(
+ content=[], id="run-abc123"
+ ) # LangChain-assigned run ID
+ meaningful_id = AIMessageChunkV1(
+ content=[], id="msg_def456"
+ ) # provider-assigned ID
+
+ assert (null_id + default_id).id == "run-abc123"
+ assert (default_id + null_id).id == "run-abc123"
+
+ assert (null_id + meaningful_id).id == "msg_def456"
+ assert (meaningful_id + null_id).id == "msg_def456"
+
+ assert (default_id + meaningful_id).id == "msg_def456"
+ assert (meaningful_id + default_id).id == "msg_def456"
+
+
def test_chat_message_chunks() -> None:
assert ChatMessageChunk(role="User", content="I am", id="ai4") + ChatMessageChunk(
role="User", content=" indeed."
@@ -1111,23 +1223,20 @@ def test_is_data_content_block() -> None:
assert is_data_content_block(
{
"type": "image",
- "source_type": "url",
"url": "https://...",
}
)
assert is_data_content_block(
{
"type": "image",
- "source_type": "base64",
- "data": "",
+ "base64": "",
"mime_type": "image/jpeg",
}
)
assert is_data_content_block(
{
"type": "image",
- "source_type": "base64",
- "data": "",
+ "base64": "",
"mime_type": "image/jpeg",
"cache_control": {"type": "ephemeral"},
}
@@ -1135,13 +1244,17 @@ def test_is_data_content_block() -> None:
assert is_data_content_block(
{
"type": "image",
- "source_type": "base64",
- "data": "",
+ "base64": "",
"mime_type": "image/jpeg",
"metadata": {"cache_control": {"type": "ephemeral"}},
}
)
-
+ assert is_data_content_block(
+ {
+ "type": "image",
+ "source_type": "base64", # backward compatibility
+ }
+ )
assert not is_data_content_block(
{
"type": "text",
@@ -1154,12 +1267,6 @@ def test_is_data_content_block() -> None:
"image_url": {"url": "https://..."},
}
)
- assert not is_data_content_block(
- {
- "type": "image",
- "source_type": "base64",
- }
- )
assert not is_data_content_block(
{
"type": "image",
@@ -1169,31 +1276,65 @@ def test_is_data_content_block() -> None:
def test_convert_to_openai_image_block() -> None:
- input_block = {
- "type": "image",
- "source_type": "url",
- "url": "https://...",
- "cache_control": {"type": "ephemeral"},
- }
- expected = {
- "type": "image_url",
- "image_url": {"url": "https://..."},
- }
- result = convert_to_openai_image_block(input_block)
- assert result == expected
-
- input_block = {
- "type": "image",
- "source_type": "base64",
- "data": "",
- "mime_type": "image/jpeg",
- "cache_control": {"type": "ephemeral"},
- }
- expected = {
- "type": "image_url",
- "image_url": {
- "url": "data:image/jpeg;base64,",
+ for input_block in [
+ {
+ "type": "image",
+ "url": "https://...",
+ "cache_control": {"type": "ephemeral"},
},
- }
- result = convert_to_openai_image_block(input_block)
- assert result == expected
+ {
+ "type": "image",
+ "source_type": "url",
+ "url": "https://...",
+ "cache_control": {"type": "ephemeral"},
+ },
+ ]:
+ expected = {
+ "type": "image_url",
+ "image_url": {"url": "https://..."},
+ }
+ result = convert_to_openai_image_block(input_block)
+ assert result == expected
+
+ for input_block in [
+ {
+ "type": "image",
+ "base64": "",
+ "mime_type": "image/jpeg",
+ "cache_control": {"type": "ephemeral"},
+ },
+ {
+ "type": "image",
+ "source_type": "base64",
+ "data": "",
+ "mime_type": "image/jpeg",
+ "cache_control": {"type": "ephemeral"},
+ },
+ ]:
+ expected = {
+ "type": "image_url",
+ "image_url": {
+ "url": "data:image/jpeg;base64,",
+ },
+ }
+ result = convert_to_openai_image_block(input_block)
+ assert result == expected
+
+
+def test_known_block_types() -> None:
+ assert {
+ "text",
+ "text-plain",
+ "tool_call",
+ "reasoning",
+ "non_standard",
+ "image",
+ "audio",
+ "file",
+ "video",
+ "code_interpreter_call",
+ "code_interpreter_output",
+ "code_interpreter_result",
+ "web_search_call",
+ "web_search_result",
+ } == KNOWN_BLOCK_TYPES
diff --git a/libs/langchain/langchain/agents/output_parsers/tools.py b/libs/langchain/langchain/agents/output_parsers/tools.py
index 10ebd3c3dfc..62759707b38 100644
--- a/libs/langchain/langchain/agents/output_parsers/tools.py
+++ b/libs/langchain/langchain/agents/output_parsers/tools.py
@@ -47,7 +47,12 @@ def parse_ai_message_to_tool_action(
try:
args = json.loads(function["arguments"] or "{}")
tool_calls.append(
- ToolCall(name=function_name, args=args, id=tool_call["id"]),
+ ToolCall(
+ name=function_name,
+ args=args,
+ id=tool_call["id"],
+ type="tool_call",
+ )
)
except JSONDecodeError as e:
msg = (
diff --git a/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_openai_tools.py b/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_openai_tools.py
index 04d33d12f4e..089f31f47f7 100644
--- a/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_openai_tools.py
+++ b/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_openai_tools.py
@@ -53,7 +53,12 @@ def test_calls_convert_agent_action_to_messages() -> None:
message4 = AIMessage(
content="",
tool_calls=[
- ToolCall(name="exponentiate", args={"a": 3, "b": 5}, id="call_abc02468"),
+ ToolCall(
+ name="exponentiate",
+ args={"a": 3, "b": 5},
+ id="call_abc02468",
+ type="tool_call",
+ ),
],
)
actions4 = parse_ai_message_to_openai_tool_action(message4)
diff --git a/libs/langchain/tests/unit_tests/agents/test_agent.py b/libs/langchain/tests/unit_tests/agents/test_agent.py
index ea6c455db69..d1bc7e6c2a0 100644
--- a/libs/langchain/tests/unit_tests/agents/test_agent.py
+++ b/libs/langchain/tests/unit_tests/agents/test_agent.py
@@ -1008,7 +1008,7 @@ def _make_tools_invocation(name_to_arguments: dict[str, dict[str, Any]]) -> AIMe
for idx, (name, arguments) in enumerate(name_to_arguments.items())
]
tool_calls = [
- ToolCall(name=name, args=args, id=str(idx))
+ ToolCall(name=name, args=args, id=str(idx), type="tool_call")
for idx, (name, args) in enumerate(name_to_arguments.items())
]
return AIMessage(
diff --git a/libs/langchain/tests/unit_tests/chat_models/test_base.py b/libs/langchain/tests/unit_tests/chat_models/test_base.py
index 18279c95e67..9b6a6b44321 100644
--- a/libs/langchain/tests/unit_tests/chat_models/test_base.py
+++ b/libs/langchain/tests/unit_tests/chat_models/test_base.py
@@ -205,6 +205,7 @@ def test_configurable_with_default() -> None:
"name": None,
"bound": {
"name": None,
+ "output_version": "v0",
"disable_streaming": False,
"model": "claude-3-sonnet-20240229",
"mcp_servers": None,
diff --git a/libs/partners/openai/langchain_openai/chat_models/_compat.py b/libs/partners/openai/langchain_openai/chat_models/_compat.py
index 25ff3eb607c..a35a88884dd 100644
--- a/libs/partners/openai/langchain_openai/chat_models/_compat.py
+++ b/libs/partners/openai/langchain_openai/chat_models/_compat.py
@@ -1,7 +1,10 @@
"""
-This module converts between AIMessage output formats for the Responses API.
+This module converts between AIMessage output formats, which are governed by the
+``output_version`` attribute on ChatOpenAI. Supported values are ``"v0"``,
+``"responses/v1"``, and ``"v1"``.
-ChatOpenAI v0.3 stores reasoning and tool outputs in AIMessage.additional_kwargs:
+``"v0"`` corresponds to the format as of ChatOpenAI v0.3. For the Responses API, it
+stores reasoning and tool outputs in AIMessage.additional_kwargs:
.. code-block:: python
@@ -28,8 +31,9 @@ ChatOpenAI v0.3 stores reasoning and tool outputs in AIMessage.additional_kwargs
id="msg_123",
)
-To retain information about response item sequencing (and to accommodate multiple
-reasoning items), ChatOpenAI now stores these items in the content sequence:
+``"responses/v1"`` is only applicable to the Responses API. It retains information
+about response item sequencing and accommodates multiple reasoning items by
+representing these items in the content sequence:
.. code-block:: python
@@ -56,19 +60,23 @@ reasoning items), ChatOpenAI now stores these items in the content sequence:
There are other, small improvements as well-- e.g., we store message IDs on text
content blocks, rather than on the AIMessage.id, which now stores the response ID.
+``"v1"`` represents LangChain's cross-provider standard format.
+
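+For illustration only, a ``"v1"`` message might look like the following. The shape
+mirrors the conversion helpers below; IDs and values are placeholders:
+
+.. code-block:: python
+
+    AIMessage(
+        content=[
+            {"type": "reasoning", "reasoning": "Reasoning summary", "id": "rs_123"},
+            {
+                "type": "tool_call",
+                "id": "call_123",
+                "name": "get_weather",
+                "arguments": '{"location": "Tokyo"}',
+            },
+            {
+                "type": "text",
+                "text": "It is sunny.",
+                "annotations": [
+                    {"type": "url_citation", "url": "https://example.com"}
+                ],
+            },
+        ],
+        id="resp_123",
+    )
+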
For backwards compatibility, this module provides functions to convert between the
old and new formats. The functions are used internally by ChatOpenAI.
""" # noqa: E501
import json
-from typing import Union
+from collections.abc import Iterable, Iterator
+from typing import Any, Literal, Union, cast
-from langchain_core.messages import AIMessage
+from langchain_core.messages import AIMessage, AIMessageChunk, is_data_content_block
_FUNCTION_CALL_IDS_MAP_KEY = "__openai_function_call_ids__"
+# v0.3 / Responses
def _convert_to_v03_ai_message(
message: AIMessage, has_reasoning: bool = False
) -> AIMessage:
@@ -253,3 +261,423 @@ def _convert_from_v03_ai_message(message: AIMessage) -> AIMessage:
},
deep=False,
)
+
+
+# v1 / Chat Completions
+def _convert_to_v1_from_chat_completions(message: AIMessage) -> AIMessage:
+ """Mutate a Chat Completions message to v1 format."""
+ if isinstance(message.content, str):
+ if message.content:
+ message.content = [{"type": "text", "text": message.content}]
+ else:
+ message.content = []
+
+ for tool_call in message.tool_calls:
+ if id_ := tool_call.get("id"):
+ message.content.append({"type": "tool_call", "id": id_})
+
+ if "tool_calls" in message.additional_kwargs:
+ _ = message.additional_kwargs.pop("tool_calls")
+
+ if "token_usage" in message.response_metadata:
+ _ = message.response_metadata.pop("token_usage")
+
+ return message
+
+
+def _convert_to_v1_from_chat_completions_chunk(chunk: AIMessageChunk) -> AIMessageChunk:
+ result = _convert_to_v1_from_chat_completions(cast(AIMessage, chunk))
+ return cast(AIMessageChunk, result)
+
+
+def _convert_from_v1_to_chat_completions(message: AIMessage) -> AIMessage:
+ """Convert a v1 message to the Chat Completions format."""
+ if isinstance(message.content, list):
+ new_content: list = []
+ for block in message.content:
+ if isinstance(block, dict):
+ block_type = block.get("type")
+ if block_type == "text":
+ # Strip annotations
+ new_content.append({"type": "text", "text": block["text"]})
+ elif block_type in ("reasoning", "tool_call"):
+ pass
+ else:
+ new_content.append(block)
+ else:
+ new_content.append(block)
+ return message.model_copy(update={"content": new_content})
+
+ return message
+
+
+# v1 / Responses
+def _convert_annotation_to_v1(annotation: dict[str, Any]) -> dict[str, Any]:
+ annotation_type = annotation.get("type")
+
+ if annotation_type == "url_citation":
+ url_citation = {}
+ for field in ("end_index", "start_index", "title"):
+ if field in annotation:
+ url_citation[field] = annotation[field]
+ url_citation["type"] = "url_citation"
+ url_citation["url"] = annotation["url"]
+ return url_citation
+
+ elif annotation_type == "file_citation":
+ document_citation = {"type": "document_citation"}
+ if "filename" in annotation:
+ document_citation["title"] = annotation["filename"]
+ for field in ("file_id", "index"): # OpenAI-specific
+ if field in annotation:
+ document_citation[field] = annotation[field]
+ return document_citation
+
+ # TODO: standardise container_file_citation?
+ else:
+ non_standard_annotation = {
+ "type": "non_standard_annotation",
+ "value": annotation,
+ }
+ return non_standard_annotation
+
+
+def _explode_reasoning(block: dict[str, Any]) -> Iterable[dict[str, Any]]:
+ if block.get("type") != "reasoning" or "summary" not in block:
+ yield block
+ return
+
+ if not block["summary"]:
+ _ = block.pop("summary", None)
+ yield block
+ return
+
+ # Common part for every exploded line, except 'summary'
+ common = {k: v for k, v in block.items() if k != "summary"}
+
+ # Optional keys that must appear only in the first exploded item
+ first_only = {
+ k: common.pop(k) for k in ("encrypted_content", "status") if k in common
+ }
+
+ for idx, part in enumerate(block["summary"]):
+ new_block = dict(common)
+ new_block["reasoning"] = part.get("text", "")
+ if idx == 0:
+ new_block.update(first_only)
+ yield new_block
+
+
+def _convert_to_v1_from_responses(message: AIMessage) -> AIMessage:
+ """Mutate a Responses message to v1 format."""
+ if not isinstance(message.content, list):
+ return message
+
+ def _iter_blocks() -> Iterable[dict[str, Any]]:
+ for block in message.content:
+ if not isinstance(block, dict):
+ continue
+ block_type = block.get("type")
+
+ if block_type == "text":
+ if "annotations" in block:
+ block["annotations"] = [
+ _convert_annotation_to_v1(a) for a in block["annotations"]
+ ]
+ yield block
+
+ elif block_type == "reasoning":
+ yield from _explode_reasoning(block)
+
+ elif block_type == "image_generation_call" and (
+ result := block.get("result")
+ ):
+ new_block = {"type": "image", "base64": result}
+ if output_format := block.get("output_format"):
+ new_block["mime_type"] = f"image/{output_format}"
+ for extra_key in (
+ "id",
+ "index",
+ "status",
+ "background",
+ "output_format",
+ "quality",
+ "revised_prompt",
+ "size",
+ ):
+ if extra_key in block:
+ new_block[extra_key] = block[extra_key]
+ yield new_block
+
+ elif block_type == "function_call":
+ new_block = {"type": "tool_call", "id": block.get("call_id", "")}
+ if "id" in block:
+ new_block["item_id"] = block["id"]
+ for extra_key in ("arguments", "name", "index"):
+ if extra_key in block:
+ new_block[extra_key] = block[extra_key]
+ yield new_block
+
+ elif block_type == "web_search_call":
+ web_search_call = {"type": "web_search_call", "id": block["id"]}
+ if "index" in block:
+ web_search_call["index"] = block["index"]
+ if (
+ "action" in block
+ and isinstance(block["action"], dict)
+ and block["action"].get("type") == "search"
+ and "query" in block["action"]
+ ):
+ web_search_call["query"] = block["action"]["query"]
+ for key in block:
+ if key not in ("type", "id"):
+ web_search_call[key] = block[key]
+
+ web_search_result = {"type": "web_search_result", "id": block["id"]}
+ if "index" in block:
+ web_search_result["index"] = block["index"] + 1
+ yield web_search_call
+ yield web_search_result
+
+ elif block_type == "code_interpreter_call":
+ code_interpreter_call = {
+ "type": "code_interpreter_call",
+ "id": block["id"],
+ }
+ if "code" in block:
+ code_interpreter_call["code"] = block["code"]
+ if "container_id" in block:
+ code_interpreter_call["container_id"] = block["container_id"]
+ if "index" in block:
+ code_interpreter_call["index"] = block["index"]
+
+ code_interpreter_result = {
+ "type": "code_interpreter_result",
+ "id": block["id"],
+ }
+ if "outputs" in block:
+ code_interpreter_result["outputs"] = block["outputs"]
+ for output in block["outputs"]:
+ if (
+ isinstance(output, dict)
+ and (output_type := output.get("type"))
+ and output_type == "logs"
+ ):
+ if "output" not in code_interpreter_result:
+ code_interpreter_result["output"] = []
+ code_interpreter_result["output"].append(
+ {
+ "type": "code_interpreter_output",
+ "stdout": output.get("logs", ""),
+ }
+ )
+
+ if "status" in block:
+ code_interpreter_result["status"] = block["status"]
+ if "index" in block:
+ code_interpreter_result["index"] = block["index"] + 1
+
+ yield code_interpreter_call
+ yield code_interpreter_result
+
+ else:
+ new_block = {"type": "non_standard", "value": block}
+ if "index" in new_block["value"]:
+ new_block["index"] = new_block["value"].pop("index")
+ yield new_block
+
+ # Replace the list with the fully converted one
+ message.content = list(_iter_blocks())
+
+ return message
+
+
+def _convert_annotation_from_v1(annotation: dict[str, Any]) -> dict[str, Any]:
+ annotation_type = annotation.get("type")
+
+ if annotation_type == "document_citation":
+ new_ann: dict[str, Any] = {"type": "file_citation"}
+
+ if "title" in annotation:
+ new_ann["filename"] = annotation["title"]
+
+ for fld in ("file_id", "index"):
+ if fld in annotation:
+ new_ann[fld] = annotation[fld]
+
+ return new_ann
+
+ elif annotation_type == "non_standard_annotation":
+ return annotation["value"]
+
+ else:
+ return dict(annotation)
+
+
+def _implode_reasoning_blocks(blocks: list[dict[str, Any]]) -> Iterable[dict[str, Any]]:
+ i = 0
+ n = len(blocks)
+
+ while i < n:
+ block = blocks[i]
+
+ # Skip non-reasoning blocks or blocks already in Responses format
+ if block.get("type") != "reasoning" or "summary" in block:
+ yield dict(block)
+ i += 1
+ continue
+ elif "reasoning" not in block and "summary" not in block:
+ # {"type": "reasoning", "id": "rs_..."}
+ oai_format = {**block, "summary": []}
+ oai_format["type"] = oai_format.pop("type", "reasoning")
+ yield oai_format
+ i += 1
+ continue
+
+ summary: list[dict[str, str]] = [
+ {"type": "summary_text", "text": block.get("reasoning", "")}
+ ]
+ # 'common' is every field except the exploded 'reasoning'
+ common = {k: v for k, v in block.items() if k != "reasoning"}
+
+ i += 1
+ while i < n:
+ next_ = blocks[i]
+ if next_.get("type") == "reasoning" and "reasoning" in next_:
+ summary.append(
+ {"type": "summary_text", "text": next_.get("reasoning", "")}
+ )
+ i += 1
+ else:
+ break
+
+ merged = dict(common)
+ merged["summary"] = summary
+ merged["type"] = merged.pop("type", "reasoning")
+ yield merged
+
+
+def _consolidate_calls(
+ items: Iterable[dict[str, Any]],
+ call_name: Literal["web_search_call", "code_interpreter_call"],
+ result_name: Literal["web_search_result", "code_interpreter_result"],
+) -> Iterator[dict[str, Any]]:
+    """Merge paired call/result blocks emitted for v1 content.
+
+    Walks through *items* and, whenever it meets the pair
+
+ {"type": "web_search_call", "id": X, ...}
+ {"type": "web_search_result", "id": X}
+
+ merges them into
+
+ {"id": X,
+ "action": …,
+ "status": …,
+ "type": "web_search_call"}
+
+ keeping every other element untouched.
+ """
+ items = iter(items) # make sure we have a true iterator
+ for current in items:
+ # Only a call can start a pair worth collapsing
+ if current.get("type") != call_name:
+ yield current
+ continue
+
+ try:
+ nxt = next(items) # look-ahead one element
+        except StopIteration:  # no "result"; just yield the call back
+ yield current
+ break
+
+        # If this really is the matching "result", collapse the pair
+ if nxt.get("type") == result_name and nxt.get("id") == current.get("id"):
+ if call_name == "web_search_call":
+ collapsed = {
+ "id": current["id"],
+ "status": current["status"],
+ "type": "web_search_call",
+ }
+ if "action" in current:
+ collapsed["action"] = current["action"]
+
+ if call_name == "code_interpreter_call":
+ collapsed = {"id": current["id"]}
+ for key in ("code", "container_id"):
+ if key in current:
+ collapsed[key] = current[key]
+
+ for key in ("outputs", "status"):
+ if key in nxt:
+ collapsed[key] = nxt[key]
+ collapsed["type"] = "code_interpreter_call"
+
+ yield collapsed
+
+ else:
+ # Not a matching pair – emit both, in original order
+ yield current
+ yield nxt
+
+
+def _convert_from_v1_to_responses(message: AIMessage) -> AIMessage:
+ if not isinstance(message.content, list):
+ return message
+
+ new_content: list = []
+ for block in message.content:
+ if isinstance(block, dict):
+ block_type = block.get("type")
+ if block_type == "text" and "annotations" in block:
+                # Need a copy because we're changing the annotations list
+ new_block = dict(block)
+ new_block["annotations"] = [
+ _convert_annotation_from_v1(a) for a in block["annotations"]
+ ]
+ new_content.append(new_block)
+ elif block_type == "tool_call":
+ new_block = {"type": "function_call", "call_id": block["id"]}
+ if "item_id" in block:
+ new_block["id"] = block["item_id"]
+ if "name" in block and "arguments" in block:
+ new_block["name"] = block["name"]
+ new_block["arguments"] = block["arguments"]
+ else:
+ tool_call = next(
+ call for call in message.tool_calls if call["id"] == block["id"]
+ )
+ if "name" not in block:
+ new_block["name"] = tool_call["name"]
+ if "arguments" not in block:
+ new_block["arguments"] = json.dumps(tool_call["args"])
+ new_content.append(new_block)
+ elif (
+ is_data_content_block(block)
+ and block["type"] == "image"
+ and "base64" in block
+ ):
+ new_block = {"type": "image_generation_call", "result": block["base64"]}
+ for extra_key in ("id", "status"):
+ if extra_key in block:
+ new_block[extra_key] = block[extra_key]
+ new_content.append(new_block)
+ elif block_type == "non_standard" and "value" in block:
+ new_content.append(block["value"])
+ else:
+ new_content.append(block)
+ else:
+ new_content.append(block)
+
+ new_content = list(_implode_reasoning_blocks(new_content))
+ new_content = list(
+ _consolidate_calls(new_content, "web_search_call", "web_search_result")
+ )
+ new_content = list(
+ _consolidate_calls(
+ new_content, "code_interpreter_call", "code_interpreter_result"
+ )
+ )
+
+ return message.model_copy(update={"content": new_content})
diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py
index 1bc9b66d880..ca4118fb6a3 100644
--- a/libs/partners/openai/langchain_openai/chat_models/base.py
+++ b/libs/partners/openai/langchain_openai/chat_models/base.py
@@ -108,7 +108,12 @@ from langchain_openai.chat_models._client_utils import (
)
from langchain_openai.chat_models._compat import (
_convert_from_v03_ai_message,
+ _convert_from_v1_to_chat_completions,
+ _convert_from_v1_to_responses,
_convert_to_v03_ai_message,
+ _convert_to_v1_from_chat_completions,
+ _convert_to_v1_from_chat_completions_chunk,
+ _convert_to_v1_from_responses,
)
if TYPE_CHECKING:
@@ -669,7 +674,7 @@ class BaseChatOpenAI(BaseChatModel):
.. versionadded:: 0.3.9
"""
- output_version: Literal["v0", "responses/v1"] = "v0"
+ output_version: str = "v0"
"""Version of AIMessage output format to use.
This field is used to roll-out new output formats for chat model AIMessages
@@ -680,6 +685,7 @@ class BaseChatOpenAI(BaseChatModel):
- ``'v0'``: AIMessage format as of langchain-openai 0.3.x.
- ``'responses/v1'``: Formats Responses API output
items into AIMessage content blocks.
+    - ``'v1'``: v1 of LangChain's cross-provider standard format.
Currently only impacts the Responses API. ``output_version='responses/v1'`` is
recommended.
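+
+    For example, to opt in to the v1 format (the model name below is illustrative):
+
+    .. code-block:: python
+
+        from langchain_openai import ChatOpenAI
+
+        llm = ChatOpenAI(model="gpt-4o-mini", output_version="v1")
+        response = llm.invoke("Hello")
+        # ``response.content`` is a list of standard content blocks.
+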
@@ -869,6 +875,10 @@ class BaseChatOpenAI(BaseChatModel):
message=default_chunk_class(content="", usage_metadata=usage_metadata),
generation_info=base_generation_info,
)
+ if self.output_version == "v1":
+ generation_chunk.message = _convert_to_v1_from_chat_completions_chunk(
+ cast(AIMessageChunk, generation_chunk.message)
+ )
return generation_chunk
choice = choices[0]
@@ -896,6 +906,20 @@ class BaseChatOpenAI(BaseChatModel):
if usage_metadata and isinstance(message_chunk, AIMessageChunk):
message_chunk.usage_metadata = usage_metadata
+ if self.output_version == "v1":
+ message_chunk = cast(AIMessageChunk, message_chunk)
+ # Convert to v1 format
+ if isinstance(message_chunk.content, str):
+ message_chunk = _convert_to_v1_from_chat_completions_chunk(
+ message_chunk
+ )
+ if message_chunk.content:
+ message_chunk.content[0]["index"] = 0 # type: ignore[index]
+ else:
+ message_chunk = _convert_to_v1_from_chat_completions_chunk(
+ message_chunk
+ )
+
generation_chunk = ChatGenerationChunk(
message=message_chunk, generation_info=generation_info or None
)
@@ -1188,7 +1212,12 @@ class BaseChatOpenAI(BaseChatModel):
else:
payload = _construct_responses_api_payload(messages, payload)
else:
- payload["messages"] = [_convert_message_to_dict(m) for m in messages]
+ payload["messages"] = [
+ _convert_message_to_dict(_convert_from_v1_to_chat_completions(m))
+ if isinstance(m, AIMessage)
+ else _convert_message_to_dict(m)
+ for m in messages
+ ]
return payload
def _create_chat_result(
@@ -1254,6 +1283,11 @@ class BaseChatOpenAI(BaseChatModel):
if hasattr(message, "refusal"):
generations[0].message.additional_kwargs["refusal"] = message.refusal
+ if self.output_version == "v1":
+ _ = llm_output.pop("token_usage", None)
+ generations[0].message = _convert_to_v1_from_chat_completions(
+ cast(AIMessage, generations[0].message)
+ )
return ChatResult(generations=generations, llm_output=llm_output)
async def _astream(
@@ -3577,6 +3611,7 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
for lc_msg in messages:
if isinstance(lc_msg, AIMessage):
lc_msg = _convert_from_v03_ai_message(lc_msg)
+ lc_msg = _convert_from_v1_to_responses(lc_msg)
msg = _convert_message_to_dict(lc_msg)
# "name" parameter unsupported
if "name" in msg:
@@ -3720,7 +3755,7 @@ def _construct_lc_result_from_responses_api(
response: Response,
schema: Optional[type[_BM]] = None,
metadata: Optional[dict] = None,
- output_version: Literal["v0", "responses/v1"] = "v0",
+ output_version: str = "v0",
) -> ChatResult:
"""Construct ChatResponse from OpenAI Response API response."""
if response.error:
@@ -3859,6 +3894,27 @@ def _construct_lc_result_from_responses_api(
)
if output_version == "v0":
message = _convert_to_v03_ai_message(message)
+ elif output_version == "v1":
+ message = _convert_to_v1_from_responses(message)
+ if response.tools and any(
+ tool.type == "image_generation" for tool in response.tools
+ ):
+ # Get mime_time from tool definition and add to image generations
+ # if missing (primarily for tracing purposes).
+ image_generation_call = next(
+ tool for tool in response.tools if tool.type == "image_generation"
+ )
+ if image_generation_call.output_format:
+ mime_type = f"image/{image_generation_call.output_format}"
+ for content_block in message.content:
+ # OK to mutate output message
+ if (
+ isinstance(content_block, dict)
+ and content_block.get("type") == "image"
+ and "base64" in content_block
+                            and "mime_type" not in content_block
+                        ):
+                            content_block["mime_type"] = mime_type
else:
pass
return ChatResult(generations=[ChatGeneration(message=message)])
@@ -3872,7 +3928,7 @@ def _convert_responses_chunk_to_generation_chunk(
schema: Optional[type[_BM]] = None,
metadata: Optional[dict] = None,
has_reasoning: bool = False,
- output_version: Literal["v0", "responses/v1"] = "v0",
+ output_version: str = "v0",
) -> tuple[int, int, int, Optional[ChatGenerationChunk]]:
def _advance(output_idx: int, sub_idx: Optional[int] = None) -> None:
"""Advance indexes tracked during streaming.
@@ -3938,9 +3994,29 @@ def _convert_responses_chunk_to_generation_chunk(
annotation = chunk.annotation
else:
annotation = chunk.annotation.model_dump(exclude_none=True, mode="json")
- content.append({"annotations": [annotation], "index": current_index})
+ if output_version == "v1":
+ content.append(
+ {
+ "type": "text",
+ "text": "",
+ "annotations": [annotation],
+ "index": current_index,
+ }
+ )
+ else:
+ content.append({"annotations": [annotation], "index": current_index})
elif chunk.type == "response.output_text.done":
- content.append({"id": chunk.item_id, "index": current_index})
+ if output_version == "v1":
+ content.append(
+ {
+ "type": "text",
+ "text": "",
+ "id": chunk.item_id,
+ "index": current_index,
+ }
+ )
+ else:
+ content.append({"id": chunk.item_id, "index": current_index})
elif chunk.type == "response.created":
id = chunk.response.id
response_metadata["id"] = chunk.response.id # Backwards compatibility
@@ -4016,21 +4092,34 @@ def _convert_responses_chunk_to_generation_chunk(
content.append({"type": "refusal", "refusal": chunk.refusal})
elif chunk.type == "response.output_item.added" and chunk.item.type == "reasoning":
_advance(chunk.output_index)
+ current_sub_index = 0
reasoning = chunk.item.model_dump(exclude_none=True, mode="json")
reasoning["index"] = current_index
content.append(reasoning)
elif chunk.type == "response.reasoning_summary_part.added":
- _advance(chunk.output_index)
- content.append(
- {
- # langchain-core uses the `index` key to aggregate text blocks.
- "summary": [
- {"index": chunk.summary_index, "type": "summary_text", "text": ""}
- ],
- "index": current_index,
- "type": "reasoning",
- }
- )
+ if output_version in ("v0", "responses/v1"):
+ _advance(chunk.output_index)
+ content.append(
+ {
+ # langchain-core uses the `index` key to aggregate text blocks.
+ "summary": [
+ {
+ "index": chunk.summary_index,
+ "type": "summary_text",
+ "text": "",
+ }
+ ],
+ "index": current_index,
+ "type": "reasoning",
+ }
+ )
+ else:
+ block: dict = {"type": "reasoning", "reasoning": ""}
+ if chunk.summary_index > 0:
+ _advance(chunk.output_index, chunk.summary_index)
+ block["id"] = chunk.item_id
+ block["index"] = current_index
+ content.append(block)
elif chunk.type == "response.image_generation_call.partial_image":
# Partial images are not supported yet.
pass
@@ -4065,6 +4154,15 @@ def _convert_responses_chunk_to_generation_chunk(
AIMessageChunk,
_convert_to_v03_ai_message(message, has_reasoning=has_reasoning),
)
+ elif output_version == "v1":
+ message = cast(AIMessageChunk, _convert_to_v1_from_responses(message))
+ for content_block in message.content:
+ if (
+ isinstance(content_block, dict)
+ and content_block.get("index", -1) > current_index
+ ):
+ # blocks were added for v1
+ current_index = content_block["index"]
else:
pass
return (
diff --git a/libs/partners/openai/tests/cassettes/test_function_calling.yaml.gz b/libs/partners/openai/tests/cassettes/test_function_calling.yaml.gz
new file mode 100644
index 00000000000..197a8402cf6
Binary files /dev/null and b/libs/partners/openai/tests/cassettes/test_function_calling.yaml.gz differ
diff --git a/libs/partners/openai/tests/cassettes/test_parsed_pydantic_schema.yaml.gz b/libs/partners/openai/tests/cassettes/test_parsed_pydantic_schema.yaml.gz
new file mode 100644
index 00000000000..13c0b8896de
Binary files /dev/null and b/libs/partners/openai/tests/cassettes/test_parsed_pydantic_schema.yaml.gz differ
diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py
index 4b02676e2a8..0417bbc32ed 100644
--- a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py
+++ b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py
@@ -473,6 +473,7 @@ def test_manual_tool_call_msg(use_responses_api: bool) -> None:
name="GenerateUsername",
args={"name": "Sally", "hair_color": "green"},
id="foo",
+ type="tool_call",
)
],
),
@@ -494,6 +495,7 @@ def test_manual_tool_call_msg(use_responses_api: bool) -> None:
name="GenerateUsername",
args={"name": "Sally", "hair_color": "green"},
id="bar",
+ type="tool_call",
)
],
),
diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_responses_api.py b/libs/partners/openai/tests/integration_tests/chat_models/test_responses_api.py
index 0e23d0e3f06..527eece1241 100644
--- a/libs/partners/openai/tests/integration_tests/chat_models/test_responses_api.py
+++ b/libs/partners/openai/tests/integration_tests/chat_models/test_responses_api.py
@@ -52,9 +52,11 @@ def _check_response(response: Optional[BaseMessage]) -> None:
assert response.response_metadata["service_tier"]
+@pytest.mark.default_cassette("test_web_search.yaml.gz")
@pytest.mark.vcr
-def test_web_search() -> None:
- llm = ChatOpenAI(model=MODEL_NAME, output_version="responses/v1")
+@pytest.mark.parametrize("output_version", ["responses/v1", "v1"])
+def test_web_search(output_version: Literal["responses/v1", "v1"]) -> None:
+ llm = ChatOpenAI(model=MODEL_NAME, output_version=output_version)
first_response = llm.invoke(
"What was a positive news story from today?",
tools=[{"type": "web_search_preview"}],
@@ -110,7 +112,10 @@ def test_web_search() -> None:
for msg in [first_response, full, response]:
assert isinstance(msg, AIMessage)
block_types = [block["type"] for block in msg.content] # type: ignore[index]
- assert block_types == ["web_search_call", "text"]
+ if output_version == "responses/v1":
+ assert block_types == ["web_search_call", "text"]
+ else:
+ assert block_types == ["web_search_call", "web_search_result", "text"]
@pytest.mark.flaky(retries=3, delay=1)
@@ -141,13 +146,15 @@ async def test_web_search_async() -> None:
assert tool_output["type"] == "web_search_call"
-@pytest.mark.flaky(retries=3, delay=1)
-def test_function_calling() -> None:
+@pytest.mark.default_cassette("test_function_calling.yaml.gz")
+@pytest.mark.vcr
+@pytest.mark.parametrize("output_version", ["v0", "responses/v1", "v1"])
+def test_function_calling(output_version: Literal["v0", "responses/v1", "v1"]) -> None:
def multiply(x: int, y: int) -> int:
"""return x * y"""
return x * y
- llm = ChatOpenAI(model=MODEL_NAME)
+ llm = ChatOpenAI(model=MODEL_NAME, output_version=output_version)
bound_llm = llm.bind_tools([multiply, {"type": "web_search_preview"}])
ai_msg = cast(AIMessage, bound_llm.invoke("whats 5 * 4"))
assert len(ai_msg.tool_calls) == 1
@@ -174,8 +181,15 @@ class FooDict(TypedDict):
response: str
-def test_parsed_pydantic_schema() -> None:
- llm = ChatOpenAI(model=MODEL_NAME, use_responses_api=True)
+@pytest.mark.default_cassette("test_parsed_pydantic_schema.yaml.gz")
+@pytest.mark.vcr
+@pytest.mark.parametrize("output_version", ["v0", "responses/v1", "v1"])
+def test_parsed_pydantic_schema(
+ output_version: Literal["v0", "responses/v1", "v1"],
+) -> None:
+ llm = ChatOpenAI(
+ model=MODEL_NAME, use_responses_api=True, output_version=output_version
+ )
response = llm.invoke("how are ya", response_format=Foo)
parsed = Foo(**json.loads(response.text()))
assert parsed == response.additional_kwargs["parsed"]
@@ -297,8 +311,8 @@ def test_function_calling_and_structured_output() -> None:
@pytest.mark.default_cassette("test_reasoning.yaml.gz")
@pytest.mark.vcr
-@pytest.mark.parametrize("output_version", ["v0", "responses/v1"])
-def test_reasoning(output_version: Literal["v0", "responses/v1"]) -> None:
+@pytest.mark.parametrize("output_version", ["v0", "responses/v1", "v1"])
+def test_reasoning(output_version: Literal["v0", "responses/v1", "v1"]) -> None:
llm = ChatOpenAI(
model="o4-mini", use_responses_api=True, output_version=output_version
)
@@ -358,27 +372,32 @@ def test_computer_calls() -> None:
def test_file_search() -> None:
pytest.skip() # TODO: set up infra
- llm = ChatOpenAI(model=MODEL_NAME)
+ llm = ChatOpenAI(model=MODEL_NAME, use_responses_api=True)
tool = {
"type": "file_search",
"vector_store_ids": [os.environ["OPENAI_VECTOR_STORE_ID"]],
}
- response = llm.invoke("What is deep research by OpenAI?", tools=[tool])
+
+ input_message = {"role": "user", "content": "What is deep research by OpenAI?"}
+ response = llm.invoke([input_message], tools=[tool])
_check_response(response)
full: Optional[BaseMessageChunk] = None
- for chunk in llm.stream("What is deep research by OpenAI?", tools=[tool]):
+ for chunk in llm.stream([input_message], tools=[tool]):
assert isinstance(chunk, AIMessageChunk)
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessageChunk)
_check_response(full)
+ next_message = {"role": "user", "content": "Thank you."}
+ _ = llm.invoke([input_message, full, next_message])
+
@pytest.mark.default_cassette("test_stream_reasoning_summary.yaml.gz")
@pytest.mark.vcr
-@pytest.mark.parametrize("output_version", ["v0", "responses/v1"])
+@pytest.mark.parametrize("output_version", ["v0", "responses/v1", "v1"])
def test_stream_reasoning_summary(
- output_version: Literal["v0", "responses/v1"],
+ output_version: Literal["v0", "responses/v1", "v1"],
) -> None:
llm = ChatOpenAI(
model="o4-mini",
@@ -398,20 +417,39 @@ def test_stream_reasoning_summary(
if output_version == "v0":
reasoning = response_1.additional_kwargs["reasoning"]
assert set(reasoning.keys()) == {"id", "type", "summary"}
- else:
+ summary = reasoning["summary"]
+ assert isinstance(summary, list)
+ for block in summary:
+ assert isinstance(block, dict)
+ assert isinstance(block["type"], str)
+ assert isinstance(block["text"], str)
+ assert block["text"]
+ elif output_version == "responses/v1":
reasoning = next(
block
for block in response_1.content
if block["type"] == "reasoning" # type: ignore[index]
)
assert set(reasoning.keys()) == {"id", "type", "summary", "index"}
- summary = reasoning["summary"]
- assert isinstance(summary, list)
- for block in summary:
- assert isinstance(block, dict)
- assert isinstance(block["type"], str)
- assert isinstance(block["text"], str)
- assert block["text"]
+ summary = reasoning["summary"]
+ assert isinstance(summary, list)
+ for block in summary:
+ assert isinstance(block, dict)
+ assert isinstance(block["type"], str)
+ assert isinstance(block["text"], str)
+ assert block["text"]
+ else:
+ # v1
+ total_reasoning_blocks = 0
+ for block in response_1.content:
+ if block["type"] == "reasoning":
+ total_reasoning_blocks += 1
+ assert isinstance(block["id"], str) and block["id"].startswith("rs_")
+ assert isinstance(block["reasoning"], str)
+ assert isinstance(block["index"], int)
+ assert (
+ total_reasoning_blocks > 1
+ ) # This query typically generates multiple reasoning blocks
# Check we can pass back summaries
message_2 = {"role": "user", "content": "Thank you."}
@@ -419,9 +457,13 @@ def test_stream_reasoning_summary(
assert isinstance(response_2, AIMessage)
+@pytest.mark.default_cassette("test_code_interpreter.yaml.gz")
@pytest.mark.vcr
-def test_code_interpreter() -> None:
- llm = ChatOpenAI(model="o4-mini", use_responses_api=True)
+@pytest.mark.parametrize("output_version", ["v0", "responses/v1", "v1"])
+def test_code_interpreter(output_version: Literal["v0", "responses/v1", "v1"]) -> None:
+ llm = ChatOpenAI(
+ model="o4-mini", use_responses_api=True, output_version=output_version
+ )
llm_with_tools = llm.bind_tools(
[{"type": "code_interpreter", "container": {"type": "auto"}}]
)
@@ -430,15 +472,38 @@ def test_code_interpreter() -> None:
"content": "Write and run code to answer the question: what is 3^3?",
}
response = llm_with_tools.invoke([input_message])
+ assert isinstance(response, AIMessage)
_check_response(response)
- tool_outputs = response.additional_kwargs["tool_outputs"]
- assert tool_outputs
- assert any(output["type"] == "code_interpreter_call" for output in tool_outputs)
+ if output_version == "v0":
+ tool_outputs = [
+ item
+ for item in response.additional_kwargs["tool_outputs"]
+ if item["type"] == "code_interpreter_call"
+ ]
+ elif output_version == "responses/v1":
+ tool_outputs = [
+ item
+ for item in response.content
+ if isinstance(item, dict) and item["type"] == "code_interpreter_call"
+ ]
+ else:
+ # v1
+ tool_outputs = [
+ item
+ for item in response.content
+ if isinstance(item, dict) and item["type"] == "code_interpreter_call"
+ ]
+ code_interpreter_result = next(
+ item
+ for item in response.content
+ if isinstance(item, dict) and item["type"] == "code_interpreter_result"
+ )
+ assert tool_outputs
+ assert code_interpreter_result
+ assert len(tool_outputs) == 1
# Test streaming
# Use same container
- tool_outputs = response.additional_kwargs["tool_outputs"]
- assert len(tool_outputs) == 1
container_id = tool_outputs[0]["container_id"]
llm_with_tools = llm.bind_tools(
[{"type": "code_interpreter", "container": container_id}]
@@ -449,9 +514,32 @@ def test_code_interpreter() -> None:
assert isinstance(chunk, AIMessageChunk)
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessageChunk)
- tool_outputs = full.additional_kwargs["tool_outputs"]
+ if output_version == "v0":
+ tool_outputs = [
+ item
+ for item in response.additional_kwargs["tool_outputs"]
+ if item["type"] == "code_interpreter_call"
+ ]
+ elif output_version == "responses/v1":
+ tool_outputs = [
+ item
+ for item in response.content
+ if isinstance(item, dict) and item["type"] == "code_interpreter_call"
+ ]
+ else:
+ code_interpreter_call = next(
+ item
+ for item in response.content
+ if isinstance(item, dict) and item["type"] == "code_interpreter_call"
+ )
+ code_interpreter_result = next(
+ item
+ for item in response.content
+ if isinstance(item, dict) and item["type"] == "code_interpreter_result"
+ )
+ assert code_interpreter_call
+ assert code_interpreter_result
assert tool_outputs
- assert any(output["type"] == "code_interpreter_call" for output in tool_outputs)
# Test we can pass back in
next_message = {"role": "user", "content": "Please add more comments to the code."}
@@ -546,10 +634,14 @@ def test_mcp_builtin_zdr() -> None:
_ = llm_with_tools.invoke([input_message, full, approval_message])
-@pytest.mark.vcr()
-def test_image_generation_streaming() -> None:
+@pytest.mark.default_cassette("test_image_generation_streaming.yaml.gz")
+@pytest.mark.vcr
+@pytest.mark.parametrize("output_version", ["v0", "responses/v1", "v1"])
+def test_image_generation_streaming(output_version: str) -> None:
"""Test image generation streaming."""
- llm = ChatOpenAI(model="gpt-4.1", use_responses_api=True)
+ llm = ChatOpenAI(
+ model="gpt-4.1", use_responses_api=True, output_version=output_version
+ )
tool = {
"type": "image_generation",
# For testing purposes let's keep the quality low, so the test runs faster.
@@ -596,15 +688,37 @@ def test_image_generation_streaming() -> None:
# At the moment, the streaming API does not pick up annotations fully.
# So the following check is commented out.
# _check_response(complete_ai_message)
- tool_output = complete_ai_message.additional_kwargs["tool_outputs"][0]
- assert set(tool_output.keys()).issubset(expected_keys)
+ if output_version == "v0":
+ assert complete_ai_message.additional_kwargs["tool_outputs"]
+ tool_output = complete_ai_message.additional_kwargs["tool_outputs"][0]
+ assert set(tool_output.keys()).issubset(expected_keys)
+ elif output_version == "responses/v1":
+ tool_output = next(
+ block
+ for block in complete_ai_message.content
+ if isinstance(block, dict) and block["type"] == "image_generation_call"
+ )
+ assert set(tool_output.keys()).issubset(expected_keys)
+ else:
+ # v1
+ standard_keys = {"type", "base64", "id", "status", "index"}
+ tool_output = next(
+ block
+ for block in complete_ai_message.content
+ if isinstance(block, dict) and block["type"] == "image"
+ )
+ assert set(standard_keys).issubset(tool_output.keys())
-@pytest.mark.vcr()
-def test_image_generation_multi_turn() -> None:
+@pytest.mark.default_cassette("test_image_generation_multi_turn.yaml.gz")
+@pytest.mark.vcr
+@pytest.mark.parametrize("output_version", ["v0", "responses/v1", "v1"])
+def test_image_generation_multi_turn(output_version: str) -> None:
"""Test multi-turn editing of image generation by passing in history."""
# Test multi-turn
- llm = ChatOpenAI(model="gpt-4.1", use_responses_api=True)
+ llm = ChatOpenAI(
+ model="gpt-4.1", use_responses_api=True, output_version=output_version
+ )
# Test invocation
tool = {
"type": "image_generation",
@@ -620,10 +734,41 @@ def test_image_generation_multi_turn() -> None:
{"role": "user", "content": "Draw a random short word in green font."}
]
ai_message = llm_with_tools.invoke(chat_history)
+ assert isinstance(ai_message, AIMessage)
_check_response(ai_message)
- tool_output = ai_message.additional_kwargs["tool_outputs"][0]
- # Example tool output for an image
+ expected_keys = {
+ "id",
+ "background",
+ "output_format",
+ "quality",
+ "result",
+ "revised_prompt",
+ "size",
+ "status",
+ "type",
+ }
+
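+    # The generated image surfaces differently depending on output_version.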
+ if output_version == "v0":
+ tool_output = ai_message.additional_kwargs["tool_outputs"][0]
+ assert set(tool_output.keys()).issubset(expected_keys)
+ elif output_version == "responses/v1":
+ tool_output = next(
+ block
+ for block in ai_message.content
+ if isinstance(block, dict) and block["type"] == "image_generation_call"
+ )
+ assert set(tool_output.keys()).issubset(expected_keys)
+ else:
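+        # v1: the image is a standard "image" content block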
+ standard_keys = {"type", "base64", "id", "status"}
+ tool_output = next(
+ block
+ for block in ai_message.content
+ if isinstance(block, dict) and block["type"] == "image"
+ )
+ assert set(standard_keys).issubset(tool_output.keys())
+
+ # Example tool output for an image (v0)
# {
# "background": "opaque",
# "id": "ig_683716a8ddf0819888572b20621c7ae4029ec8c11f8dacf8",
@@ -639,20 +784,6 @@ def test_image_generation_multi_turn() -> None:
# "result": # base64 encode image data
# }
- expected_keys = {
- "id",
- "background",
- "output_format",
- "quality",
- "result",
- "revised_prompt",
- "size",
- "status",
- "type",
- }
-
- assert set(tool_output.keys()).issubset(expected_keys)
-
chat_history.extend(
[
# AI message with tool output
@@ -669,6 +800,24 @@ def test_image_generation_multi_turn() -> None:
)
ai_message2 = llm_with_tools.invoke(chat_history)
+ assert isinstance(ai_message2, AIMessage)
_check_response(ai_message2)
- tool_output2 = ai_message2.additional_kwargs["tool_outputs"][0]
- assert set(tool_output2.keys()).issubset(expected_keys)
+
+ if output_version == "v0":
+ tool_output = ai_message2.additional_kwargs["tool_outputs"][0]
+ assert set(tool_output.keys()).issubset(expected_keys)
+ elif output_version == "responses/v1":
+ tool_output = next(
+ block
+ for block in ai_message2.content
+ if isinstance(block, dict) and block["type"] == "image_generation_call"
+ )
+ assert set(tool_output.keys()).issubset(expected_keys)
+ else:
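+        # v1: the follow-up turn should also produce a standard "image" block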
+ standard_keys = {"type", "base64", "id", "status"}
+ tool_output = next(
+ block
+ for block in ai_message2.content
+ if isinstance(block, dict) and block["type"] == "image"
+ )
+ assert set(standard_keys).issubset(tool_output.keys())
diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py
index c4176711482..6a1f94c60db 100644
--- a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py
+++ b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py
@@ -51,7 +51,11 @@ from langchain_openai import ChatOpenAI
from langchain_openai.chat_models._compat import (
_FUNCTION_CALL_IDS_MAP_KEY,
_convert_from_v03_ai_message,
+ _convert_from_v1_to_chat_completions,
+ _convert_from_v1_to_responses,
_convert_to_v03_ai_message,
+ _convert_to_v1_from_chat_completions,
+ _convert_to_v1_from_responses,
)
from langchain_openai.chat_models.base import (
_construct_lc_result_from_responses_api,
@@ -2297,7 +2301,7 @@ def test_mcp_tracing() -> None:
assert payload["tools"][0]["headers"]["Authorization"] == "Bearer PLACEHOLDER"
-def test_compat() -> None:
+def test_compat_responses_v1() -> None:
# Check compatibility with v0.3 message format
message_v03 = AIMessage(
content=[
@@ -2358,6 +2362,411 @@ def test_compat() -> None:
assert message_v03_output is not message_v03
+@pytest.mark.parametrize(
+ "message_v1, expected",
+ [
+ (
+ AIMessage(
+ [
+ {"type": "reasoning", "reasoning": "Reasoning text"},
+ {"type": "tool_call", "id": "call_123"},
+ {
+ "type": "text",
+ "text": "Hello, world!",
+ "annotations": [
+ {"type": "url_citation", "url": "https://example.com"}
+ ],
+ },
+ ],
+ tool_calls=[
+ {
+ "type": "tool_call",
+ "id": "call_123",
+ "name": "get_weather",
+ "args": {"location": "San Francisco"},
+ }
+ ],
+ id="chatcmpl-123",
+ response_metadata={"foo": "bar"},
+ ),
+ AIMessage(
+ [{"type": "text", "text": "Hello, world!"}],
+ tool_calls=[
+ {
+ "type": "tool_call",
+ "id": "call_123",
+ "name": "get_weather",
+ "args": {"location": "San Francisco"},
+ }
+ ],
+ id="chatcmpl-123",
+ response_metadata={"foo": "bar"},
+ ),
+ )
+ ],
+)
+def test_convert_from_v1_to_chat_completions(
+ message_v1: AIMessage, expected: AIMessage
+) -> None:
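+    """v1 content should collapse to Chat Completions-style text plus tool calls."""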
+ result = _convert_from_v1_to_chat_completions(message_v1)
+ assert result == expected
+
+    # Check the input was not mutated in place (it should still differ from result)
+ assert message_v1 != result
+
+
+@pytest.mark.parametrize(
+ "message_chat_completions, expected",
+ [
+ (
+ AIMessage(
+ "Hello, world!", id="chatcmpl-123", response_metadata={"foo": "bar"}
+ ),
+ AIMessage(
+ [{"type": "text", "text": "Hello, world!"}],
+ id="chatcmpl-123",
+ response_metadata={"foo": "bar"},
+ ),
+ ),
+ (
+ AIMessage(
+ [{"type": "text", "text": "Hello, world!"}],
+ tool_calls=[
+ {
+ "type": "tool_call",
+ "id": "call_123",
+ "name": "get_weather",
+ "args": {"location": "San Francisco"},
+ }
+ ],
+ id="chatcmpl-123",
+ response_metadata={"foo": "bar"},
+ ),
+ AIMessage(
+ [
+ {"type": "text", "text": "Hello, world!"},
+ {"type": "tool_call", "id": "call_123"},
+ ],
+ tool_calls=[
+ {
+ "type": "tool_call",
+ "id": "call_123",
+ "name": "get_weather",
+ "args": {"location": "San Francisco"},
+ }
+ ],
+ id="chatcmpl-123",
+ response_metadata={"foo": "bar"},
+ ),
+ ),
+ (
+ AIMessage(
+ "",
+ tool_calls=[
+ {
+ "type": "tool_call",
+ "id": "call_123",
+ "name": "get_weather",
+ "args": {"location": "San Francisco"},
+ }
+ ],
+ id="chatcmpl-123",
+ response_metadata={"foo": "bar"},
+ additional_kwargs={"tool_calls": [{"foo": "bar"}]},
+ ),
+ AIMessage(
+ [{"type": "tool_call", "id": "call_123"}],
+ tool_calls=[
+ {
+ "type": "tool_call",
+ "id": "call_123",
+ "name": "get_weather",
+ "args": {"location": "San Francisco"},
+ }
+ ],
+ id="chatcmpl-123",
+ response_metadata={"foo": "bar"},
+ ),
+ ),
+ ],
+)
+def test_convert_to_v1_from_chat_completions(
+ message_chat_completions: AIMessage, expected: AIMessage
+) -> None:
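+    """Chat Completions messages should normalize to v1 content blocks."""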
+ result = _convert_to_v1_from_chat_completions(message_chat_completions)
+ assert result == expected
+
+
+@pytest.mark.parametrize(
+ "message_v1, expected",
+ [
+ (
+ AIMessage(
+ [
+ {"type": "reasoning", "id": "abc123"},
+ {"type": "reasoning", "id": "abc234", "reasoning": "foo "},
+ {"type": "reasoning", "id": "abc234", "reasoning": "bar"},
+ {"type": "tool_call", "id": "call_123"},
+ {
+ "type": "tool_call",
+ "id": "call_234",
+ "name": "get_weather_2",
+ "arguments": '{"location": "New York"}',
+ "item_id": "fc_123",
+ },
+ {"type": "text", "text": "Hello "},
+ {
+ "type": "text",
+ "text": "world",
+ "annotations": [
+ {"type": "url_citation", "url": "https://example.com"},
+ {
+ "type": "document_citation",
+ "title": "my doc",
+ "index": 1,
+ "file_id": "file_123",
+ },
+ {
+ "type": "non_standard_annotation",
+ "value": {"bar": "baz"},
+ },
+ ],
+ },
+ {"type": "image", "base64": "...", "id": "img_123"},
+ {
+ "type": "non_standard",
+ "value": {"type": "something_else", "foo": "bar"},
+ },
+ ],
+ tool_calls=[
+ {
+ "type": "tool_call",
+ "id": "call_123",
+ "name": "get_weather",
+ "args": {"location": "San Francisco"},
+ },
+ {
+ # Make values different to check we pull from content when
+ # available
+ "type": "tool_call",
+ "id": "call_234",
+ "name": "get_weather_3",
+ "args": {"location": "Boston"},
+ },
+ ],
+ id="resp123",
+ response_metadata={"foo": "bar"},
+ ),
+ AIMessage(
+ [
+ {"type": "reasoning", "id": "abc123", "summary": []},
+ {
+ "type": "reasoning",
+ "id": "abc234",
+ "summary": [
+ {"type": "summary_text", "text": "foo "},
+ {"type": "summary_text", "text": "bar"},
+ ],
+ },
+ {
+ "type": "function_call",
+ "call_id": "call_123",
+ "name": "get_weather",
+ "arguments": '{"location": "San Francisco"}',
+ },
+ {
+ "type": "function_call",
+ "call_id": "call_234",
+ "name": "get_weather_2",
+ "arguments": '{"location": "New York"}',
+ "id": "fc_123",
+ },
+ {"type": "text", "text": "Hello "},
+ {
+ "type": "text",
+ "text": "world",
+ "annotations": [
+ {"type": "url_citation", "url": "https://example.com"},
+ {
+ "type": "file_citation",
+ "filename": "my doc",
+ "index": 1,
+ "file_id": "file_123",
+ },
+ {"bar": "baz"},
+ ],
+ },
+ {"type": "image_generation_call", "id": "img_123", "result": "..."},
+ {"type": "something_else", "foo": "bar"},
+ ],
+ tool_calls=[
+ {
+ "type": "tool_call",
+ "id": "call_123",
+ "name": "get_weather",
+ "args": {"location": "San Francisco"},
+ },
+ {
+ # Make values different to check we pull from content when
+ # available
+ "type": "tool_call",
+ "id": "call_234",
+ "name": "get_weather_3",
+ "args": {"location": "Boston"},
+ },
+ ],
+ id="resp123",
+ response_metadata={"foo": "bar"},
+ ),
+ )
+ ],
+)
+def test_convert_from_v1_to_responses(
+ message_v1: AIMessage, expected: AIMessage
+) -> None:
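+    """v1 content blocks should map back onto Responses API output items."""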
+ result = _convert_from_v1_to_responses(message_v1)
+ assert result == expected
+
+    # Check the input was not mutated in place (it should still differ from result)
+ assert message_v1 != result
+
+
+@pytest.mark.parametrize(
+ "message_responses, expected",
+ [
+ (
+ AIMessage(
+ [
+ {"type": "reasoning", "id": "abc123"},
+ {
+ "type": "reasoning",
+ "id": "abc234",
+ "summary": [
+ {"type": "summary_text", "text": "foo "},
+ {"type": "summary_text", "text": "bar"},
+ ],
+ },
+ {
+ "type": "function_call",
+ "call_id": "call_123",
+ "name": "get_weather",
+ "arguments": '{"location": "San Francisco"}',
+ },
+ {
+ "type": "function_call",
+ "call_id": "call_234",
+ "name": "get_weather_2",
+ "arguments": '{"location": "New York"}',
+ "id": "fc_123",
+ },
+ {"type": "text", "text": "Hello "},
+ {
+ "type": "text",
+ "text": "world",
+ "annotations": [
+ {"type": "url_citation", "url": "https://example.com"},
+ {
+ "type": "file_citation",
+ "filename": "my doc",
+ "index": 1,
+ "file_id": "file_123",
+ },
+ {"bar": "baz"},
+ ],
+ },
+ {"type": "image_generation_call", "id": "img_123", "result": "..."},
+ {"type": "something_else", "foo": "bar"},
+ ],
+ tool_calls=[
+ {
+ "type": "tool_call",
+ "id": "call_123",
+ "name": "get_weather",
+ "args": {"location": "San Francisco"},
+ },
+ {
+ # Make values different to check we pull from content when
+ # available
+ "type": "tool_call",
+ "id": "call_234",
+ "name": "get_weather_3",
+ "args": {"location": "Boston"},
+ },
+ ],
+ id="resp123",
+ response_metadata={"foo": "bar"},
+ ),
+ AIMessage(
+ [
+ {"type": "reasoning", "id": "abc123"},
+ {"type": "reasoning", "id": "abc234", "reasoning": "foo "},
+ {"type": "reasoning", "id": "abc234", "reasoning": "bar"},
+ {
+ "type": "tool_call",
+ "id": "call_123",
+ "name": "get_weather",
+ "arguments": '{"location": "San Francisco"}',
+ },
+ {
+ "type": "tool_call",
+ "id": "call_234",
+ "name": "get_weather_2",
+ "arguments": '{"location": "New York"}',
+ "item_id": "fc_123",
+ },
+ {"type": "text", "text": "Hello "},
+ {
+ "type": "text",
+ "text": "world",
+ "annotations": [
+ {"type": "url_citation", "url": "https://example.com"},
+ {
+ "type": "document_citation",
+ "title": "my doc",
+ "index": 1,
+ "file_id": "file_123",
+ },
+ {
+ "type": "non_standard_annotation",
+ "value": {"bar": "baz"},
+ },
+ ],
+ },
+ {"type": "image", "base64": "...", "id": "img_123"},
+ {
+ "type": "non_standard",
+ "value": {"type": "something_else", "foo": "bar"},
+ },
+ ],
+ tool_calls=[
+ {
+ "type": "tool_call",
+ "id": "call_123",
+ "name": "get_weather",
+ "args": {"location": "San Francisco"},
+ },
+ {
+ # Make values different to check we pull from content when
+ # available
+ "type": "tool_call",
+ "id": "call_234",
+ "name": "get_weather_3",
+ "args": {"location": "Boston"},
+ },
+ ],
+ id="resp123",
+ response_metadata={"foo": "bar"},
+ ),
+ )
+ ],
+)
+def test_convert_to_v1_from_responses(
+ message_responses: AIMessage, expected: AIMessage
+) -> None:
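+    """Responses API output items should convert to v1 standard content blocks."""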
+ result = _convert_to_v1_from_responses(message_responses)
+ assert result == expected
+
+
def test_get_last_messages() -> None:
messages: list[BaseMessage] = [HumanMessage("Hello")]
last_messages, previous_response_id = _get_last_messages(messages)
diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_responses_stream.py b/libs/partners/openai/tests/unit_tests/chat_models/test_responses_stream.py
index 370adcd1f1a..bee9e6fe095 100644
--- a/libs/partners/openai/tests/unit_tests/chat_models/test_responses_stream.py
+++ b/libs/partners/openai/tests/unit_tests/chat_models/test_responses_stream.py
@@ -1,6 +1,7 @@
from typing import Any, Optional
from unittest.mock import MagicMock, patch
+import pytest
from langchain_core.messages import AIMessageChunk, BaseMessageChunk
from openai.types.responses import (
ResponseCompletedEvent,
@@ -610,8 +611,97 @@ def _strip_none(obj: Any) -> Any:
return obj
-def test_responses_stream() -> None:
- llm = ChatOpenAI(model="o4-mini", output_version="responses/v1")
+@pytest.mark.parametrize(
+ "output_version, expected_content",
+ [
+ (
+ "responses/v1",
+ [
+ {
+ "id": "rs_123",
+ "summary": [
+ {
+ "index": 0,
+ "type": "summary_text",
+ "text": "reasoning block one",
+ },
+ {
+ "index": 1,
+ "type": "summary_text",
+ "text": "another reasoning block",
+ },
+ ],
+ "type": "reasoning",
+ "index": 0,
+ },
+ {"type": "text", "text": "text block one", "index": 1, "id": "msg_123"},
+ {
+ "type": "text",
+ "text": "another text block",
+ "index": 2,
+ "id": "msg_123",
+ },
+ {
+ "id": "rs_234",
+ "summary": [
+ {"index": 0, "type": "summary_text", "text": "more reasoning"},
+ {
+ "index": 1,
+ "type": "summary_text",
+ "text": "still more reasoning",
+ },
+ ],
+ "type": "reasoning",
+ "index": 3,
+ },
+ {"type": "text", "text": "more", "index": 4, "id": "msg_234"},
+ {"type": "text", "text": "text", "index": 5, "id": "msg_234"},
+ ],
+ ),
+ (
+ "v1",
+ [
+ {
+ "type": "reasoning",
+ "reasoning": "reasoning block one",
+ "id": "rs_123",
+ "index": 0,
+ },
+ {
+ "type": "reasoning",
+ "reasoning": "another reasoning block",
+ "id": "rs_123",
+ "index": 1,
+ },
+ {"type": "text", "text": "text block one", "index": 2, "id": "msg_123"},
+ {
+ "type": "text",
+ "text": "another text block",
+ "index": 3,
+ "id": "msg_123",
+ },
+ {
+ "type": "reasoning",
+ "reasoning": "more reasoning",
+ "id": "rs_234",
+ "index": 4,
+ },
+ {
+ "type": "reasoning",
+ "reasoning": "still more reasoning",
+ "id": "rs_234",
+ "index": 5,
+ },
+ {"type": "text", "text": "more", "index": 6, "id": "msg_234"},
+ {"type": "text", "text": "text", "index": 7, "id": "msg_234"},
+ ],
+ ),
+ ],
+)
+def test_responses_stream(output_version: str, expected_content: list[dict]) -> None:
+ llm = ChatOpenAI(
+ model="o4-mini", use_responses_api=True, output_version=output_version
+ )
mock_client = MagicMock()
def mock_create(*args: Any, **kwargs: Any) -> MockSyncContextManager:
@@ -620,36 +710,14 @@ def test_responses_stream() -> None:
mock_client.responses.create = mock_create
full: Optional[BaseMessageChunk] = None
+ chunks = []
with patch.object(llm, "root_client", mock_client):
for chunk in llm.stream("test"):
assert isinstance(chunk, AIMessageChunk)
full = chunk if full is None else full + chunk
- assert isinstance(full, AIMessageChunk)
+ chunks.append(chunk)
- expected_content = [
- {
- "id": "rs_123",
- "summary": [
- {"index": 0, "type": "summary_text", "text": "reasoning block one"},
- {"index": 1, "type": "summary_text", "text": "another reasoning block"},
- ],
- "type": "reasoning",
- "index": 0,
- },
- {"type": "text", "text": "text block one", "index": 1, "id": "msg_123"},
- {"type": "text", "text": "another text block", "index": 2, "id": "msg_123"},
- {
- "id": "rs_234",
- "summary": [
- {"index": 0, "type": "summary_text", "text": "more reasoning"},
- {"index": 1, "type": "summary_text", "text": "still more reasoning"},
- ],
- "type": "reasoning",
- "index": 3,
- },
- {"type": "text", "text": "more", "index": 4, "id": "msg_234"},
- {"type": "text", "text": "text", "index": 5, "id": "msg_234"},
- ]
+ assert isinstance(full, AIMessageChunk)
assert full.content == expected_content
assert full.additional_kwargs == {}
assert full.id == "resp_123"