From ef9b5a9e186ef4085d66519e241f1aed5d9724db Mon Sep 17 00:00:00 2001 From: Mason Daugherty Date: Mon, 28 Jul 2025 10:47:26 -0400 Subject: [PATCH] add back standard_outputs --- libs/core/langchain_core/callbacks/base.py | 41 +- libs/core/langchain_core/callbacks/manager.py | 58 +- .../langchain_core/language_models/_utils.py | 6 +- .../langchain_core/language_models/base.py | 5 +- .../language_models/chat_models.py | 33 + .../language_models/v1/__init__.py | 1 + .../language_models/v1/chat_models.py | 909 ++++++++++++++++++ libs/core/langchain_core/messages/__init__.py | 54 ++ .../langchain_core/messages/content_blocks.py | 812 ++++++++++++++-- libs/core/langchain_core/messages/modifier.py | 2 +- libs/core/langchain_core/messages/tool.py | 92 +- libs/core/langchain_core/messages/utils.py | 213 +++- libs/core/langchain_core/messages/v1.py | 568 +++++++++++ .../output_parsers/transform.py | 4 +- libs/core/langchain_core/runnables/base.py | 16 +- libs/core/langchain_core/runnables/config.py | 4 +- libs/core/langchain_core/runnables/graph.py | 6 +- .../langchain_core/utils/function_calling.py | 2 +- libs/core/pyproject.toml | 2 + .../unit_tests/fake/test_fake_chat_model.py | 14 +- .../language_models/chat_models/test_cache.py | 5 +- .../tests/unit_tests/messages/test_imports.py | 18 + .../tests/unit_tests/messages/test_utils.py | 31 +- .../prompts/__snapshots__/test_chat.ambr | 12 + .../runnables/__snapshots__/test_graph.ambr | 6 + .../__snapshots__/test_runnable.ambr | 48 + libs/core/tests/unit_tests/test_messages.py | 223 ++++- .../langchain/agents/output_parsers/tools.py | 7 +- .../format_scratchpad/test_openai_tools.py | 7 +- .../tests/unit_tests/agents/test_agent.py | 2 +- .../tests/unit_tests/chat_models/test_base.py | 1 + .../langchain_openai/chat_models/_compat.py | 440 ++++++++- .../langchain_openai/chat_models/base.py | 132 ++- .../cassettes/test_function_calling.yaml.gz | Bin 0 -> 7912 bytes .../test_parsed_pydantic_schema.yaml.gz | Bin 0 -> 4616 bytes .../chat_models/test_base.py | 2 + .../chat_models/test_responses_api.py | 265 +++-- .../tests/unit_tests/chat_models/test_base.py | 411 +++++++- .../chat_models/test_responses_stream.py | 122 ++- 39 files changed, 4201 insertions(+), 373 deletions(-) create mode 100644 libs/core/langchain_core/language_models/v1/__init__.py create mode 100644 libs/core/langchain_core/language_models/v1/chat_models.py create mode 100644 libs/core/langchain_core/messages/v1.py create mode 100644 libs/partners/openai/tests/cassettes/test_function_calling.yaml.gz create mode 100644 libs/partners/openai/tests/cassettes/test_parsed_pydantic_schema.yaml.gz diff --git a/libs/core/langchain_core/callbacks/base.py b/libs/core/langchain_core/callbacks/base.py index 5365fcb9ef1..f9c6f75f307 100644 --- a/libs/core/langchain_core/callbacks/base.py +++ b/libs/core/langchain_core/callbacks/base.py @@ -7,6 +7,8 @@ from typing import TYPE_CHECKING, Any, Optional, Union from typing_extensions import Self +from langchain_core.messages.v1 import AIMessage, AIMessageChunk, MessageV1 + if TYPE_CHECKING: from collections.abc import Sequence from uuid import UUID @@ -64,9 +66,11 @@ class LLMManagerMixin: def on_llm_new_token( self, - token: str, + token: Union[str, AIMessageChunk], *, - chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, + chunk: Optional[ + Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk] + ] = None, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any, @@ -75,8 +79,8 @@ class LLMManagerMixin: Args: token (str): The 
new token. - chunk (GenerationChunk | ChatGenerationChunk): The new generated chunk, - containing content and other information. + chunk (GenerationChunk | ChatGenerationChunk | AIMessageChunk): The new + generated chunk, containing content and other information. run_id (UUID): The run ID. This is the ID of the current run. parent_run_id (UUID): The parent run ID. This is the ID of the parent run. kwargs (Any): Additional keyword arguments. @@ -84,7 +88,7 @@ class LLMManagerMixin: def on_llm_end( self, - response: LLMResult, + response: Union[LLMResult, AIMessage], *, run_id: UUID, parent_run_id: Optional[UUID] = None, @@ -93,7 +97,7 @@ class LLMManagerMixin: """Run when LLM ends running. Args: - response (LLMResult): The response which was generated. + response (LLMResult | AIMessage): The response which was generated. run_id (UUID): The run ID. This is the ID of the current run. parent_run_id (UUID): The parent run ID. This is the ID of the parent run. kwargs (Any): Additional keyword arguments. @@ -261,7 +265,7 @@ class CallbackManagerMixin: def on_chat_model_start( self, serialized: dict[str, Any], - messages: list[list[BaseMessage]], + messages: Union[list[list[BaseMessage]], list[list[MessageV1]]], *, run_id: UUID, parent_run_id: Optional[UUID] = None, @@ -439,6 +443,9 @@ class BaseCallbackHandler( run_inline: bool = False """Whether to run the callback inline.""" + accepts_new_messages: bool = False + """Whether the callback accepts new message format.""" + @property def ignore_llm(self) -> bool: """Whether to ignore LLM callbacks.""" @@ -509,7 +516,7 @@ class AsyncCallbackHandler(BaseCallbackHandler): async def on_chat_model_start( self, serialized: dict[str, Any], - messages: list[list[BaseMessage]], + messages: Union[list[list[BaseMessage]], list[list[MessageV1]]], *, run_id: UUID, parent_run_id: Optional[UUID] = None, @@ -538,9 +545,11 @@ class AsyncCallbackHandler(BaseCallbackHandler): async def on_llm_new_token( self, - token: str, + token: Union[str, AIMessageChunk], *, - chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, + chunk: Optional[ + Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk] + ] = None, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[list[str]] = None, @@ -550,8 +559,8 @@ class AsyncCallbackHandler(BaseCallbackHandler): Args: token (str): The new token. - chunk (GenerationChunk | ChatGenerationChunk): The new generated chunk, - containing content and other information. + chunk (GenerationChunk | ChatGenerationChunk | AIMessageChunk): The new + generated chunk, containing content and other information. run_id (UUID): The run ID. This is the ID of the current run. parent_run_id (UUID): The parent run ID. This is the ID of the parent run. tags (Optional[list[str]]): The tags. @@ -560,7 +569,7 @@ class AsyncCallbackHandler(BaseCallbackHandler): async def on_llm_end( self, - response: LLMResult, + response: Union[LLMResult, AIMessage], *, run_id: UUID, parent_run_id: Optional[UUID] = None, @@ -570,7 +579,7 @@ class AsyncCallbackHandler(BaseCallbackHandler): """Run when LLM ends running. Args: - response (LLMResult): The response which was generated. + response (LLMResult | AIMessage): The response which was generated. run_id (UUID): The run ID. This is the ID of the current run. parent_run_id (UUID): The parent run ID. This is the ID of the parent run. tags (Optional[list[str]]): The tags. @@ -594,8 +603,8 @@ class AsyncCallbackHandler(BaseCallbackHandler): parent_run_id: The parent run ID. 
This is the ID of the parent run. tags: The tags. kwargs (Any): Additional keyword arguments. - - response (LLMResult): The response which was generated before - the error occurred. + - response (LLMResult | AIMessage): The response which was generated + before the error occurred. """ async def on_chain_start( diff --git a/libs/core/langchain_core/callbacks/manager.py b/libs/core/langchain_core/callbacks/manager.py index 56fc1bb67ba..bca2d800696 100644 --- a/libs/core/langchain_core/callbacks/manager.py +++ b/libs/core/langchain_core/callbacks/manager.py @@ -11,6 +11,7 @@ from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor from contextlib import asynccontextmanager, contextmanager from contextvars import copy_context +from dataclasses import is_dataclass from typing import ( TYPE_CHECKING, Any, @@ -37,6 +38,8 @@ from langchain_core.callbacks.base import ( ) from langchain_core.callbacks.stdout import StdOutCallbackHandler from langchain_core.messages import BaseMessage, get_buffer_string +from langchain_core.messages.v1 import AIMessage, AIMessageChunk +from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, LLMResult from langchain_core.tracers.schemas import Run from langchain_core.utils.env import env_var_is_set @@ -47,7 +50,7 @@ if TYPE_CHECKING: from langchain_core.agents import AgentAction, AgentFinish from langchain_core.documents import Document - from langchain_core.outputs import ChatGenerationChunk, GenerationChunk, LLMResult + from langchain_core.outputs import GenerationChunk from langchain_core.runnables.config import RunnableConfig logger = logging.getLogger(__name__) @@ -243,6 +246,22 @@ def shielded(func: Func) -> Func: return cast("Func", wrapped) +def _convert_llm_events( + event_name: str, args: tuple[Any, ...], kwargs: dict[str, Any] +) -> None: + if event_name == "on_chat_model_start" and isinstance(args[1], list): + for idx, item in enumerate(args[1]): + if is_dataclass(item): + args[1][idx] = item # convert to old message + elif event_name == "on_llm_new_token" and is_dataclass(args[0]): + kwargs["chunk"] = ChatGenerationChunk(text=args[0].text, message=args[0]) + args[0] = args[0].text + elif event_name == "on_llm_end" and is_dataclass(args[0]): + args[0] = LLMResult( + generations=[[ChatGeneration(text=args[0].text, message=args[0])]] + ) + + def handle_event( handlers: list[BaseCallbackHandler], event_name: str, @@ -271,6 +290,8 @@ def handle_event( if ignore_condition_name is None or not getattr( handler, ignore_condition_name ): + if not handler.accepts_new_messages: + _convert_llm_events(event_name, args, kwargs) event = getattr(handler, event_name)(*args, **kwargs) if asyncio.iscoroutine(event): coros.append(event) @@ -365,6 +386,8 @@ async def _ahandle_event_for_handler( ) -> None: try: if ignore_condition_name is None or not getattr(handler, ignore_condition_name): + if not handler.accepts_new_messages: + _convert_llm_events(event_name, args, kwargs) event = getattr(handler, event_name) if asyncio.iscoroutinefunction(event): await event(*args, **kwargs) @@ -672,9 +695,11 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin): def on_llm_new_token( self, - token: str, + token: Union[str, AIMessageChunk], *, - chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, + chunk: Optional[ + Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk] + ] = None, **kwargs: Any, ) -> None: """Run when LLM generates a new token. 
@@ -699,11 +724,11 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin): **kwargs, ) - def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + def on_llm_end(self, response: Union[LLMResult, AIMessage], **kwargs: Any) -> None: """Run when LLM ends running. Args: - response (LLMResult): The LLM result. + response (LLMResult | AIMessage): The LLM result. **kwargs (Any): Additional keyword arguments. """ if not self.handlers: @@ -729,8 +754,8 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin): Args: error (Exception or KeyboardInterrupt): The error. kwargs (Any): Additional keyword arguments. - - response (LLMResult): The response which was generated before - the error occurred. + - response (LLMResult | AIMessage): The response which was generated + before the error occurred. """ if not self.handlers: return @@ -768,9 +793,11 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin): async def on_llm_new_token( self, - token: str, + token: Union[str, AIMessageChunk], *, - chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, + chunk: Optional[ + Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk] + ] = None, **kwargs: Any, ) -> None: """Run when LLM generates a new token. @@ -796,11 +823,13 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin): ) @shielded - async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + async def on_llm_end( + self, response: Union[LLMResult, AIMessage], **kwargs: Any + ) -> None: """Run when LLM ends running. Args: - response (LLMResult): The LLM result. + response (LLMResult | AIMessage): The LLM result. **kwargs (Any): Additional keyword arguments. """ if not self.handlers: @@ -827,11 +856,8 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin): Args: error (Exception or KeyboardInterrupt): The error. kwargs (Any): Additional keyword arguments. - - response (LLMResult): The response which was generated before - the error occurred. - - - + - response (LLMResult | AIMessage): The response which was generated + before the error occurred. 
""" if not self.handlers: return diff --git a/libs/core/langchain_core/language_models/_utils.py b/libs/core/langchain_core/language_models/_utils.py index 883f8c855ea..cc5d39a18c3 100644 --- a/libs/core/langchain_core/language_models/_utils.py +++ b/libs/core/langchain_core/language_models/_utils.py @@ -1,3 +1,4 @@ +import copy import re from collections.abc import Sequence from typing import Optional @@ -128,7 +129,10 @@ def _normalize_messages(messages: Sequence[BaseMessage]) -> list[BaseMessage]: and _is_openai_data_block(block) ): if formatted_message is message: - formatted_message = message.model_copy() + if isinstance(message, BaseMessage): + formatted_message = message.model_copy() + else: + formatted_message = copy.copy(message) # Also shallow-copy content formatted_message.content = list(formatted_message.content) diff --git a/libs/core/langchain_core/language_models/base.py b/libs/core/langchain_core/language_models/base.py index a9e7e4e64cc..1fef01a6859 100644 --- a/libs/core/langchain_core/language_models/base.py +++ b/libs/core/langchain_core/language_models/base.py @@ -28,6 +28,7 @@ from langchain_core.messages import ( MessageLikeRepresentation, get_buffer_string, ) +from langchain_core.messages.v1 import AIMessage as AIMessageV1 from langchain_core.prompt_values import PromptValue from langchain_core.runnables import Runnable, RunnableSerializable from langchain_core.utils import get_pydantic_field_names @@ -85,7 +86,9 @@ def _get_token_ids_default_method(text: str) -> list[int]: LanguageModelInput = Union[PromptValue, str, Sequence[MessageLikeRepresentation]] LanguageModelOutput = Union[BaseMessage, str] LanguageModelLike = Runnable[LanguageModelInput, LanguageModelOutput] -LanguageModelOutputVar = TypeVar("LanguageModelOutputVar", BaseMessage, str) +LanguageModelOutputVar = TypeVar( + "LanguageModelOutputVar", BaseMessage, str, AIMessageV1 +) def _get_verbosity() -> bool: diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py index d37ca232052..e96003d26fe 100644 --- a/libs/core/langchain_core/language_models/chat_models.py +++ b/libs/core/langchain_core/language_models/chat_models.py @@ -58,7 +58,9 @@ from langchain_core.messages import ( is_data_content_block, message_chunk_to_message, ) +from langchain_core.messages import content_blocks as types from langchain_core.messages.ai import _LC_ID_PREFIX +from langchain_core.messages.v1 import AIMessage as AIMessageV1 from langchain_core.outputs import ( ChatGeneration, ChatGenerationChunk, @@ -220,6 +222,23 @@ def _format_ls_structured_output(ls_structured_output_format: Optional[dict]) -> return ls_structured_output_format_dict +def _convert_to_v1(message: AIMessage) -> AIMessageV1: + """Best-effort conversion of a V0 AIMessage to V1.""" + if isinstance(message.content, str): + content: list[types.ContentBlock] = [] + if message.content: + content = [{"type": "text", "text": message.content}] + + for tool_call in message.tool_calls: + content.append(tool_call) + + return AIMessageV1( + content=content, + usage_metadata=message.usage_metadata, + response_metadata=message.response_metadata, + ) + + class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): """Base class for chat models. @@ -328,6 +347,20 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): does not properly support streaming. """ + output_version: str = "v0" + """Version of AIMessage output format to use. 
+ + This field is used to roll-out new output formats for chat model AIMessages + in a backwards-compatible way. + + ``'v1'`` standardizes output format using a list of typed ContentBlock dicts. We + recommend this for new applications. + + All chat models currently support the default of ``"v0"``. + + .. versionadded:: 0.4 + """ + @model_validator(mode="before") @classmethod def raise_deprecation(cls, values: dict) -> Any: diff --git a/libs/core/langchain_core/language_models/v1/__init__.py b/libs/core/langchain_core/language_models/v1/__init__.py new file mode 100644 index 00000000000..eaffaa213ba --- /dev/null +++ b/libs/core/langchain_core/language_models/v1/__init__.py @@ -0,0 +1 @@ +"""LangChain v1.0 chat models.""" diff --git a/libs/core/langchain_core/language_models/v1/chat_models.py b/libs/core/langchain_core/language_models/v1/chat_models.py new file mode 100644 index 00000000000..437fb876496 --- /dev/null +++ b/libs/core/langchain_core/language_models/v1/chat_models.py @@ -0,0 +1,909 @@ +"""Chat models for conversational AI.""" + +from __future__ import annotations + +import copy +import typing +from abc import ABC, abstractmethod +from collections.abc import AsyncIterator, Iterator, Sequence +from operator import itemgetter +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Literal, + Optional, + Union, + cast, +) + +from pydantic import ( + BaseModel, + ConfigDict, + Field, +) +from typing_extensions import override + +from langchain_core.callbacks import ( + AsyncCallbackManager, + AsyncCallbackManagerForLLMRun, + CallbackManager, + CallbackManagerForLLMRun, +) +from langchain_core.language_models._utils import _normalize_messages +from langchain_core.language_models.base import ( + BaseLanguageModel, + LangSmithParams, + LanguageModelInput, +) +from langchain_core.messages import ( + AIMessage, + BaseMessage, + convert_to_openai_image_block, + is_data_content_block, +) +from langchain_core.messages.utils import convert_to_messages_v1 +from langchain_core.messages.v1 import AIMessage as AIMessageV1 +from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1 +from langchain_core.messages.v1 import HumanMessage as HumanMessageV1 +from langchain_core.messages.v1 import MessageV1, add_ai_message_chunks +from langchain_core.outputs import ( + ChatGeneration, + ChatGenerationChunk, +) +from langchain_core.prompt_values import PromptValue +from langchain_core.rate_limiters import BaseRateLimiter +from langchain_core.runnables import RunnableMap, RunnablePassthrough +from langchain_core.runnables.config import ensure_config, run_in_executor +from langchain_core.tracers._streaming import _StreamingCallbackHandler +from langchain_core.utils.function_calling import ( + convert_to_json_schema, + convert_to_openai_tool, +) +from langchain_core.utils.pydantic import TypeBaseModel, is_basemodel_subclass + +if TYPE_CHECKING: + from langchain_core.output_parsers.base import OutputParserLike + from langchain_core.runnables import Runnable, RunnableConfig + from langchain_core.tools import BaseTool + + +def _generate_response_from_error(error: BaseException) -> list[AIMessageV1]: + if hasattr(error, "response"): + response = error.response + metadata: dict = {} + if hasattr(response, "headers"): + try: + metadata["headers"] = dict(response.headers) + except Exception: + metadata["headers"] = None + if hasattr(response, "status_code"): + metadata["status_code"] = response.status_code + if hasattr(error, "request_id"): + metadata["request_id"] = error.request_id + 
generations = [AIMessageV1(content=[], response_metadata=metadata)] + else: + generations = [] + + return generations + + +def _format_for_tracing(messages: list[MessageV1]) -> list[MessageV1]: + """Format messages for tracing in on_chat_model_start. + + - Update image content blocks to OpenAI Chat Completions format (backward + compatibility). + - Add "type" key to content blocks that have a single key. + + Args: + messages: List of messages to format. + + Returns: + List of messages formatted for tracing. + """ + messages_to_trace = [] + for message in messages: + message_to_trace = message + for idx, block in enumerate(message.content): + # Update image content blocks to OpenAI # Chat Completions format. + if ( + block["type"] == "image" + and is_data_content_block(block) + and block.get("source_type") != "id" + ): + if message_to_trace is message: + # Shallow copy + message_to_trace = copy.copy(message) + message_to_trace.content = list(message_to_trace.content) + + message_to_trace.content[idx] = convert_to_openai_image_block(block) + else: + pass + messages_to_trace.append(message_to_trace) + + return messages_to_trace + + +def generate_from_stream(stream: Iterator[AIMessageChunkV1]) -> AIMessageV1: + """Generate from a stream. + + Args: + stream: Iterator of AIMessageChunkV1. + + Returns: + AIMessageV1: aggregated message. + """ + generation = next(stream, None) + if generation: + generation += list(stream) + if generation is None: + msg = "No generations found in stream." + raise ValueError(msg) + return generation.to_message() + + +async def agenerate_from_stream( + stream: AsyncIterator[AIMessageChunkV1], +) -> AIMessageV1: + """Async generate from a stream. + + Args: + stream: Iterator of AIMessageChunkV1. + + Returns: + AIMessageV1: aggregated message. + """ + chunks = [chunk async for chunk in stream] + return await run_in_executor(None, generate_from_stream, iter(chunks)) + + +def _format_ls_structured_output(ls_structured_output_format: Optional[dict]) -> dict: + if ls_structured_output_format: + try: + ls_structured_output_format_dict = { + "ls_structured_output_format": { + "kwargs": ls_structured_output_format.get("kwargs", {}), + "schema": convert_to_json_schema( + ls_structured_output_format["schema"] + ), + } + } + except ValueError: + ls_structured_output_format_dict = {} + else: + ls_structured_output_format_dict = {} + + return ls_structured_output_format_dict + + +class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC): + """Base class for chat models. + + Key imperative methods: + Methods that actually call the underlying model. + + +---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+ + | Method | Input | Output | Description | + +===========================+================================================================+=====================================================================+==================================================================================================+ + | `invoke` | str | list[dict | tuple | BaseMessage] | PromptValue | BaseMessage | A single chat model call. 
| + +---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+ + | `ainvoke` | ''' | BaseMessage | Defaults to running invoke in an async executor. | + +---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+ + | `stream` | ''' | Iterator[BaseMessageChunk] | Defaults to yielding output of invoke. | + +---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+ + | `astream` | ''' | AsyncIterator[BaseMessageChunk] | Defaults to yielding output of ainvoke. | + +---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+ + | `astream_events` | ''' | AsyncIterator[StreamEvent] | Event types: 'on_chat_model_start', 'on_chat_model_stream', 'on_chat_model_end'. | + +---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+ + | `batch` | list['''] | list[BaseMessage] | Defaults to running invoke in concurrent threads. | + +---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+ + | `abatch` | list['''] | list[BaseMessage] | Defaults to running ainvoke in concurrent threads. | + +---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+ + | `batch_as_completed` | list['''] | Iterator[tuple[int, Union[BaseMessage, Exception]]] | Defaults to running invoke in concurrent threads. | + +---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+ + | `abatch_as_completed` | list['''] | AsyncIterator[tuple[int, Union[BaseMessage, Exception]]] | Defaults to running ainvoke in concurrent threads. | + +---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+ + + This table provides a brief overview of the main imperative methods. Please see the base Runnable reference for full documentation. + + Key declarative methods: + Methods for creating another Runnable using the ChatModel. 
+ + +----------------------------------+-----------------------------------------------------------------------------------------------------------+ + | Method | Description | + +==================================+===========================================================================================================+ + | `bind_tools` | Create ChatModel that can call tools. | + +----------------------------------+-----------------------------------------------------------------------------------------------------------+ + | `with_structured_output` | Create wrapper that structures model output using schema. | + +----------------------------------+-----------------------------------------------------------------------------------------------------------+ + | `with_retry` | Create wrapper that retries model calls on failure. | + +----------------------------------+-----------------------------------------------------------------------------------------------------------+ + | `with_fallbacks` | Create wrapper that falls back to other models on failure. | + +----------------------------------+-----------------------------------------------------------------------------------------------------------+ + | `configurable_fields` | Specify init args of the model that can be configured at runtime via the RunnableConfig. | + +----------------------------------+-----------------------------------------------------------------------------------------------------------+ + | `configurable_alternatives` | Specify alternative models which can be swapped in at runtime via the RunnableConfig. | + +----------------------------------+-----------------------------------------------------------------------------------------------------------+ + + This table provides a brief overview of the main declarative methods. Please see the reference for each method for full documentation. + + Creating custom chat model: + Custom chat model implementations should inherit from this class. + Please reference the table below for information about which + methods and properties are required or optional for implementations. + + +----------------------------------+--------------------------------------------------------------------+-------------------+ + | Method/Property | Description | Required/Optional | + +==================================+====================================================================+===================+ + | `_generate` | Use to generate a chat result from a prompt | Required | + +----------------------------------+--------------------------------------------------------------------+-------------------+ + | `_llm_type` (property) | Used to uniquely identify the type of the model. Used for logging. | Required | + +----------------------------------+--------------------------------------------------------------------+-------------------+ + | `_identifying_params` (property) | Represent model parameterization for tracing purposes. 
| Optional | + +----------------------------------+--------------------------------------------------------------------+-------------------+ + | `_stream` | Use to implement streaming | Optional | + +----------------------------------+--------------------------------------------------------------------+-------------------+ + | `_agenerate` | Use to implement a native async method | Optional | + +----------------------------------+--------------------------------------------------------------------+-------------------+ + | `_astream` | Use to implement async version of `_stream` | Optional | + +----------------------------------+--------------------------------------------------------------------+-------------------+ + + Follow the guide for more information on how to implement a custom Chat Model: + [Guide](https://python.langchain.com/docs/how_to/custom_chat_model/). + + """ # noqa: E501 + + rate_limiter: Optional[BaseRateLimiter] = Field(default=None, exclude=True) + "An optional rate limiter to use for limiting the number of requests." + + disable_streaming: Union[bool, Literal["tool_calling"]] = False + """Whether to disable streaming for this model. + + If streaming is bypassed, then ``stream()``/``astream()``/``astream_events()`` will + defer to ``invoke()``/``ainvoke()``. + + - If True, will always bypass streaming case. + - If ``'tool_calling'``, will bypass streaming case only when the model is called + with a ``tools`` keyword argument. In other words, LangChain will automatically + switch to non-streaming behavior (``invoke()``) only when the tools argument is + provided. This offers the best of both worlds. + - If False (default), will always use streaming case if available. + + The main reason for this flag is that code might be written using ``.stream()`` and + a user may want to swap out a given model for another model whose the implementation + does not properly support streaming. + """ + + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) + + # --- Runnable methods --- + + @property + @override + def OutputType(self) -> Any: + """Get the output type for this runnable.""" + return AIMessageV1 + + def _convert_input(self, model_input: LanguageModelInput) -> list[MessageV1]: + if isinstance(model_input, PromptValue): + return model_input.to_messages(output_version="v1") + if isinstance(model_input, str): + return [HumanMessageV1(content=model_input)] + if isinstance(model_input, Sequence): + return convert_to_messages_v1(model_input) + msg = ( + f"Invalid input type {type(model_input)}. " + "Must be a PromptValue, str, or list of BaseMessages." + ) + raise ValueError(msg) + + def _should_stream( + self, + *, + async_api: bool, + run_manager: Optional[ + Union[CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun] + ] = None, + **kwargs: Any, + ) -> bool: + """Determine if a given model call should hit the streaming API.""" + sync_not_implemented = type(self)._stream == BaseChatModelV1._stream # noqa: SLF001 + async_not_implemented = type(self)._astream == BaseChatModelV1._astream # noqa: SLF001 + + # Check if streaming is implemented. + if (not async_api) and sync_not_implemented: + return False + # Note, since async falls back to sync we check both here. + if async_api and async_not_implemented and sync_not_implemented: + return False + + # Check if streaming has been disabled on this instance. + if self.disable_streaming is True: + return False + # We assume tools are passed in via "tools" kwarg in all models. 
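+        # "tool_calling" therefore only bypasses streaming when a "tools" kwarg is
+        # actually supplied on this call; otherwise the checks below decide.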
+ if self.disable_streaming == "tool_calling" and kwargs.get("tools"): + return False + + # Check if a runtime streaming flag has been passed in. + if "stream" in kwargs: + return kwargs["stream"] + + # Check if any streaming callback handlers have been passed in. + handlers = run_manager.handlers if run_manager else [] + return any(isinstance(h, _StreamingCallbackHandler) for h in handlers) + + @override + def invoke( + self, + input: LanguageModelInput, + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> AIMessageV1: + config = ensure_config(config) + messages = self._convert_input(input) + ls_structured_output_format = kwargs.pop( + "ls_structured_output_format", None + ) or kwargs.pop("structured_output_format", None) + ls_structured_output_format_dict = _format_ls_structured_output( + ls_structured_output_format + ) + + params = self._get_invocation_params(**kwargs) + options = {**kwargs, **ls_structured_output_format_dict} + inheritable_metadata = { + **(config.get("metadata") or {}), + **self._get_ls_params(**kwargs), + } + callback_manager = CallbackManager.configure( + config.get("callbacks"), + self.callbacks, + self.verbose, + config.get("tags"), + self.tags, + inheritable_metadata, + self.metadata, + ) + (run_manager,) = callback_manager.on_chat_model_start( + {}, + [_format_for_tracing(messages)], + invocation_params=params, + options=options, + name=config.get("run_name"), + run_id=config.pop("run_id", None), + batch_size=1, + ) + + if self.rate_limiter: + self.rate_limiter.acquire(blocking=True) + + input_messages = _normalize_messages(messages) + + if self._should_stream(async_api=False, **kwargs): + chunks: list[AIMessageChunkV1] = [] + try: + for msg in self._stream(input_messages, **kwargs): + run_manager.on_llm_new_token(msg) + chunks.append(msg) + except BaseException as e: + run_manager.on_llm_error(e, response=_generate_response_from_error(e)) + raise + msg = add_ai_message_chunks(chunks[0], *chunks[1:]) + else: + try: + msg = self._invoke(input_messages, **kwargs) + except BaseException as e: + run_manager.on_llm_error(e) + raise + + run_manager.on_llm_end(msg) + return msg + + @override + async def ainvoke( + self, + input: LanguageModelInput, + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> AIMessage: + config = ensure_config(config) + messages = self._convert_input(input) + ls_structured_output_format = kwargs.pop( + "ls_structured_output_format", None + ) or kwargs.pop("structured_output_format", None) + ls_structured_output_format_dict = _format_ls_structured_output( + ls_structured_output_format + ) + + params = self._get_invocation_params(**kwargs) + options = {**kwargs, **ls_structured_output_format_dict} + inheritable_metadata = { + **(config.get("metadata") or {}), + **self._get_ls_params(**kwargs), + } + callback_manager = AsyncCallbackManager.configure( + config.get("callbacks"), + self.callbacks, + self.verbose, + config.get("tags"), + self.tags, + inheritable_metadata, + self.metadata, + ) + (run_manager,) = await callback_manager.on_chat_model_start( + {}, + [_format_for_tracing(messages)], + invocation_params=params, + options=options, + name=config.get("run_name"), + run_id=config.pop("run_id", None), + batch_size=1, + ) + + if self.rate_limiter: + await self.rate_limiter.aacquire(blocking=True) + + # TODO: type openai image, audio, file types and permit in MessageV1 + input_messages = _normalize_messages(messages) # type: ignore[arg-type] + + if self._should_stream(async_api=True, **kwargs): + chunks: 
list[AIMessageChunkV1] = [] + try: + async for msg in self._astream(input_messages, **kwargs): + await run_manager.on_llm_new_token(msg) + chunks.append(msg) + except BaseException as e: + await run_manager.on_llm_error( + e, response=_generate_response_from_error(e) + ) + raise + msg = add_ai_message_chunks(chunks[0], *chunks[1:]) + else: + try: + msg = await self._ainvoke(input_messages, **kwargs) + except BaseException as e: + await run_manager.on_llm_error( + e, response=_generate_response_from_error(e) + ) + raise + + await run_manager.on_llm_end(msg.to_message()) + return msg + + @override + def stream( + self, + input: LanguageModelInput, + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> Iterator[AIMessageChunkV1]: + if not self._should_stream(async_api=False, **{**kwargs, "stream": True}): + # model doesn't implement streaming, so use default implementation + yield cast( + "AIMessageChunkV1", + self.invoke(input, config=config, **kwargs), + ) + else: + config = ensure_config(config) + messages = self._convert_input(input) + ls_structured_output_format = kwargs.pop( + "ls_structured_output_format", None + ) or kwargs.pop("structured_output_format", None) + ls_structured_output_format_dict = _format_ls_structured_output( + ls_structured_output_format + ) + + params = self._get_invocation_params(**kwargs) + options = {**kwargs, **ls_structured_output_format_dict} + inheritable_metadata = { + **(config.get("metadata") or {}), + **self._get_ls_params(**kwargs), + } + callback_manager = CallbackManager.configure( + config.get("callbacks"), + self.callbacks, + self.verbose, + config.get("tags"), + self.tags, + inheritable_metadata, + self.metadata, + ) + (run_manager,) = callback_manager.on_chat_model_start( + {}, + [_format_for_tracing(messages)], + invocation_params=params, + options=options, + name=config.get("run_name"), + run_id=config.pop("run_id", None), + batch_size=1, + ) + + chunks: list[AIMessageChunkV1] = [] + + if self.rate_limiter: + self.rate_limiter.acquire(blocking=True) + + try: + # TODO: replace this with something for new messages + input_messages = _normalize_messages(messages) + for msg in self._stream(input_messages, **kwargs): + run_manager.on_llm_new_token(msg) + chunks.append(msg) + yield msg + except BaseException as e: + run_manager.on_llm_error(e, response=_generate_response_from_error(e)) + raise + + msg = add_ai_message_chunks(chunks[0], *chunks[1:]) + run_manager.on_llm_end(msg) + + @override + async def astream( + self, + input: LanguageModelInput, + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> AsyncIterator[AIMessageChunkV1]: + if not self._should_stream(async_api=True, **{**kwargs, "stream": True}): + # No async or sync stream is implemented, so fall back to ainvoke + yield cast( + "AIMessageChunkV1", + await self.ainvoke(input, config=config, **kwargs), + ) + return + + config = ensure_config(config) + messages = self._convert_input(input) + + ls_structured_output_format = kwargs.pop( + "ls_structured_output_format", None + ) or kwargs.pop("structured_output_format", None) + ls_structured_output_format_dict = _format_ls_structured_output( + ls_structured_output_format + ) + + params = self._get_invocation_params(**kwargs) + options = {**kwargs, **ls_structured_output_format_dict} + inheritable_metadata = { + **(config.get("metadata") or {}), + **self._get_ls_params(**kwargs), + } + callback_manager = AsyncCallbackManager.configure( + config.get("callbacks"), + self.callbacks, + self.verbose, + config.get("tags"), + 
self.tags, + inheritable_metadata, + self.metadata, + ) + (run_manager,) = await callback_manager.on_chat_model_start( + {}, + [_format_for_tracing(messages)], + invocation_params=params, + options=options, + name=config.get("run_name"), + run_id=config.pop("run_id", None), + batch_size=1, + ) + + if self.rate_limiter: + await self.rate_limiter.aacquire(blocking=True) + + chunks: list[AIMessageChunkV1] = [] + + try: + input_messages = _normalize_messages(messages) + async for msg in self._astream( + input_messages, + **kwargs, + ): + await run_manager.on_llm_new_token(msg) + chunks.append(msg) + yield msg + except BaseException as e: + await run_manager.on_llm_error(e, response=_generate_response_from_error(e)) + raise + + msg = add_ai_message_chunks(chunks[0], *chunks[1:]) + await run_manager.on_llm_end(msg) + + # --- Custom methods --- + + def _combine_llm_outputs(self, llm_outputs: list[Optional[dict]]) -> dict: # noqa: ARG002 + return {} + + def _get_invocation_params( + self, + stop: Optional[list[str]] = None, + **kwargs: Any, + ) -> dict: + params = self.dict() + params["stop"] = stop + return {**params, **kwargs} + + def _get_ls_params( + self, + stop: Optional[list[str]] = None, + **kwargs: Any, + ) -> LangSmithParams: + """Get standard params for tracing.""" + # get default provider from class name + default_provider = self.__class__.__name__ + if default_provider.startswith("Chat"): + default_provider = default_provider[4:].lower() + elif default_provider.endswith("Chat"): + default_provider = default_provider[:-4] + default_provider = default_provider.lower() + + ls_params = LangSmithParams(ls_provider=default_provider, ls_model_type="chat") + if stop: + ls_params["ls_stop"] = stop + + # model + if hasattr(self, "model") and isinstance(self.model, str): + ls_params["ls_model_name"] = self.model + elif hasattr(self, "model_name") and isinstance(self.model_name, str): + ls_params["ls_model_name"] = self.model_name + + # temperature + if "temperature" in kwargs and isinstance(kwargs["temperature"], float): + ls_params["ls_temperature"] = kwargs["temperature"] + elif hasattr(self, "temperature") and isinstance(self.temperature, float): + ls_params["ls_temperature"] = self.temperature + + # max_tokens + if "max_tokens" in kwargs and isinstance(kwargs["max_tokens"], int): + ls_params["ls_max_tokens"] = kwargs["max_tokens"] + elif hasattr(self, "max_tokens") and isinstance(self.max_tokens, int): + ls_params["ls_max_tokens"] = self.max_tokens + + return ls_params + + def _get_llm_string(self, stop: Optional[list[str]] = None, **kwargs: Any) -> str: + params = self._get_invocation_params(stop=stop, **kwargs) + params = {**params, **kwargs} + return str(sorted(params.items())) + + def _invoke( + self, + messages: list[MessageV1], + **kwargs: Any, + ) -> AIMessage: + raise NotImplementedError + + async def _ainvoke( + self, + messages: list[MessageV1], + **kwargs: Any, + ) -> AIMessage: + return await run_in_executor( + None, + self._invoke, + messages, + **kwargs, + ) + + def _stream( + self, + messages: list[MessageV1], + **kwargs: Any, + ) -> Iterator[AIMessageChunkV1]: + raise NotImplementedError + + async def _astream( + self, + messages: list[MessageV1], + **kwargs: Any, + ) -> AsyncIterator[AIMessageChunkV1]: + iterator = await run_in_executor( + None, + self._stream, + messages, + **kwargs, + ) + done = object() + while True: + item = await run_in_executor( + None, + next, + iterator, + done, + ) + if item is done: + break + yield item # type: ignore[misc] + + @property + 
@abstractmethod + def _llm_type(self) -> str: + """Return type of chat model.""" + + @override + def dict(self, **kwargs: Any) -> dict: + """Return a dictionary of the LLM.""" + starter_dict = dict(self._identifying_params) + starter_dict["_type"] = self._llm_type + return starter_dict + + def bind_tools( + self, + tools: Sequence[ + Union[typing.Dict[str, Any], type, Callable, BaseTool] # noqa: UP006 + ], + *, + tool_choice: Optional[Union[str]] = None, + **kwargs: Any, + ) -> Runnable[LanguageModelInput, BaseMessage]: + """Bind tools to the model. + + Args: + tools: Sequence of tools to bind to the model. + tool_choice: The tool to use. If "any" then any tool can be used. + + Returns: + A Runnable that returns a message. + """ + raise NotImplementedError + + def with_structured_output( + self, + schema: Union[typing.Dict, type], # noqa: UP006 + *, + include_raw: bool = False, + **kwargs: Any, + ) -> Runnable[LanguageModelInput, Union[typing.Dict, BaseModel]]: # noqa: UP006 + """Model wrapper that returns outputs formatted to match the given schema. + + Args: + schema: + The output schema. Can be passed in as: + - an OpenAI function/tool schema, + - a JSON Schema, + - a TypedDict class, + - or a Pydantic class. + If ``schema`` is a Pydantic class then the model output will be a + Pydantic instance of that class, and the model-generated fields will be + validated by the Pydantic class. Otherwise the model output will be a + dict and will not be validated. See :meth:`langchain_core.utils.function_calling.convert_to_openai_tool` + for more on how to properly specify types and descriptions of + schema fields when specifying a Pydantic or TypedDict class. + + include_raw: + If False then only the parsed structured output is returned. If + an error occurs during model output parsing it will be raised. If True + then both the raw model response (a BaseMessage) and the parsed model + response will be returned. If an error occurs during output parsing it + will be caught and returned as well. The final output is always a dict + with keys "raw", "parsed", and "parsing_error". + + Returns: + A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`. + + If ``include_raw`` is False and ``schema`` is a Pydantic class, Runnable outputs + an instance of ``schema`` (i.e., a Pydantic object). + + Otherwise, if ``include_raw`` is False then Runnable outputs a dict. + + If ``include_raw`` is True, then Runnable outputs a dict with keys: + - ``"raw"``: BaseMessage + - ``"parsed"``: None if there was a parsing error, otherwise the type depends on the ``schema`` as described above. + - ``"parsing_error"``: Optional[BaseException] + + Example: Pydantic schema (include_raw=False): + .. code-block:: python + + from pydantic import BaseModel + + class AnswerWithJustification(BaseModel): + '''An answer to the user question along with justification for the answer.''' + answer: str + justification: str + + llm = ChatModel(model="model-name", temperature=0) + structured_llm = llm.with_structured_output(AnswerWithJustification) + + structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers") + + # -> AnswerWithJustification( + # answer='They weigh the same', + # justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.' + # ) + + Example: Pydantic schema (include_raw=True): + .. 
code-block:: python + + from pydantic import BaseModel + + class AnswerWithJustification(BaseModel): + '''An answer to the user question along with justification for the answer.''' + answer: str + justification: str + + llm = ChatModel(model="model-name", temperature=0) + structured_llm = llm.with_structured_output(AnswerWithJustification, include_raw=True) + + structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers") + # -> { + # 'raw': AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_Ao02pnFYXD6GN1yzc0uXPsvF', 'function': {'arguments': '{"answer":"They weigh the same.","justification":"Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ."}', 'name': 'AnswerWithJustification'}, 'type': 'function'}]}), + # 'parsed': AnswerWithJustification(answer='They weigh the same.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'), + # 'parsing_error': None + # } + + Example: Dict schema (include_raw=False): + .. code-block:: python + + from pydantic import BaseModel + from langchain_core.utils.function_calling import convert_to_openai_tool + + class AnswerWithJustification(BaseModel): + '''An answer to the user question along with justification for the answer.''' + answer: str + justification: str + + dict_schema = convert_to_openai_tool(AnswerWithJustification) + llm = ChatModel(model="model-name", temperature=0) + structured_llm = llm.with_structured_output(dict_schema) + + structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers") + # -> { + # 'answer': 'They weigh the same', + # 'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume and density of the two substances differ.' + # } + + .. versionchanged:: 0.2.26 + + Added support for TypedDict class. + """ # noqa: E501 + _ = kwargs.pop("method", None) + _ = kwargs.pop("strict", None) + if kwargs: + msg = f"Received unsupported arguments {kwargs}" + raise ValueError(msg) + + from langchain_core.output_parsers.openai_tools import ( + JsonOutputKeyToolsParser, + PydanticToolsParser, + ) + + if type(self).bind_tools is BaseChatModelV1.bind_tools: + msg = "with_structured_output is not implemented for this model." 
+ raise NotImplementedError(msg) + + llm = self.bind_tools( + [schema], + tool_choice="any", + ls_structured_output_format={ + "kwargs": {"method": "function_calling"}, + "schema": schema, + }, + ) + if isinstance(schema, type) and is_basemodel_subclass(schema): + output_parser: OutputParserLike = PydanticToolsParser( + tools=[cast("TypeBaseModel", schema)], first_tool_only=True + ) + else: + key_name = convert_to_openai_tool(schema)["function"]["name"] + output_parser = JsonOutputKeyToolsParser( + key_name=key_name, first_tool_only=True + ) + if include_raw: + parser_assign = RunnablePassthrough.assign( + parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None + ) + parser_none = RunnablePassthrough.assign(parsed=lambda _: None) + parser_with_fallback = parser_assign.with_fallbacks( + [parser_none], exception_key="parsing_error" + ) + return RunnableMap(raw=llm) | parser_with_fallback + return llm | output_parser + + +def _gen_info_and_msg_metadata( + generation: Union[ChatGeneration, ChatGenerationChunk], +) -> dict: + return { + **(generation.generation_info or {}), + **generation.message.response_metadata, + } diff --git a/libs/core/langchain_core/messages/__init__.py b/libs/core/langchain_core/messages/__init__.py index fe87e964af2..0faf0447295 100644 --- a/libs/core/langchain_core/messages/__init__.py +++ b/libs/core/langchain_core/messages/__init__.py @@ -33,6 +33,24 @@ if TYPE_CHECKING: ) from langchain_core.messages.chat import ChatMessage, ChatMessageChunk from langchain_core.messages.content_blocks import ( + Annotation, + AudioContentBlock, + Citation, + CodeInterpreterCall, + CodeInterpreterOutput, + CodeInterpreterResult, + ContentBlock, + DataContentBlock, + FileContentBlock, + ImageContentBlock, + NonStandardAnnotation, + NonStandardContentBlock, + PlainTextContentBlock, + ReasoningContentBlock, + TextContentBlock, + VideoContentBlock, + WebSearchCall, + WebSearchResult, convert_to_openai_data_block, convert_to_openai_image_block, is_data_content_block, @@ -65,24 +83,42 @@ if TYPE_CHECKING: __all__ = ( "AIMessage", "AIMessageChunk", + "Annotation", "AnyMessage", + "AudioContentBlock", "BaseMessage", "BaseMessageChunk", "ChatMessage", "ChatMessageChunk", + "Citation", + "CodeInterpreterCall", + "CodeInterpreterOutput", + "CodeInterpreterResult", + "ContentBlock", + "DataContentBlock", + "FileContentBlock", "FunctionMessage", "FunctionMessageChunk", "HumanMessage", "HumanMessageChunk", + "ImageContentBlock", "InvalidToolCall", "MessageLikeRepresentation", + "NonStandardAnnotation", + "NonStandardContentBlock", + "PlainTextContentBlock", + "ReasoningContentBlock", "RemoveMessage", "SystemMessage", "SystemMessageChunk", + "TextContentBlock", "ToolCall", "ToolCallChunk", "ToolMessage", "ToolMessageChunk", + "VideoContentBlock", + "WebSearchCall", + "WebSearchResult", "_message_from_dict", "convert_to_messages", "convert_to_openai_data_block", @@ -103,25 +139,43 @@ __all__ = ( _dynamic_imports = { "AIMessage": "ai", "AIMessageChunk": "ai", + "Annotation": "content_blocks", + "AudioContentBlock": "content_blocks", "BaseMessage": "base", "BaseMessageChunk": "base", "merge_content": "base", "message_to_dict": "base", "messages_to_dict": "base", + "Citation": "content_blocks", + "ContentBlock": "content_blocks", "ChatMessage": "chat", "ChatMessageChunk": "chat", + "CodeInterpreterCall": "content_blocks", + "CodeInterpreterOutput": "content_blocks", + "CodeInterpreterResult": "content_blocks", + "DataContentBlock": "content_blocks", + "FileContentBlock": 
"content_blocks", "FunctionMessage": "function", "FunctionMessageChunk": "function", "HumanMessage": "human", "HumanMessageChunk": "human", + "NonStandardAnnotation": "content_blocks", + "NonStandardContentBlock": "content_blocks", + "PlainTextContentBlock": "content_blocks", + "ReasoningContentBlock": "content_blocks", "RemoveMessage": "modifier", "SystemMessage": "system", "SystemMessageChunk": "system", + "WebSearchCall": "content_blocks", + "WebSearchResult": "content_blocks", + "ImageContentBlock": "content_blocks", "InvalidToolCall": "tool", + "TextContentBlock": "content_blocks", "ToolCall": "tool", "ToolCallChunk": "tool", "ToolMessage": "tool", "ToolMessageChunk": "tool", + "VideoContentBlock": "content_blocks", "AnyMessage": "utils", "MessageLikeRepresentation": "utils", "_message_from_dict": "utils", diff --git a/libs/core/langchain_core/messages/content_blocks.py b/libs/core/langchain_core/messages/content_blocks.py index 83a66fb123a..a2efd0fef9c 100644 --- a/libs/core/langchain_core/messages/content_blocks.py +++ b/libs/core/langchain_core/messages/content_blocks.py @@ -1,110 +1,782 @@ -"""Types for content blocks.""" +"""Standard, multimodal content blocks for Large Language Model I/O. + +.. warning:: + This module is under active development. The API is unstable and subject to + change in future releases. + +This module provides a standardized data structure for representing inputs to and +outputs from Large Language Models. The core abstraction is the **Content Block**, a +``TypedDict`` that can represent a piece of text, an image, a tool call, or other +structured data. + +Data **not yet mapped** to a standard block may be represented using the +``NonStandardContentBlock``, which allows for provider-specific data to be included +without losing the benefits of type checking and validation. + +Furthermore, provider-specific fields *within* a standard block will be allowed as extra +keys on the TypedDict per `PEP 728 `__. This allows +for flexibility in the data structure while maintaining a consistent interface. + +**Example using ``extra_items=Any``:** + +.. code-block:: python + from langchain_core.messages.content_blocks import TextContentBlock + from typing import Any + + my_block: TextContentBlock = { + "type": "text", + "text": "Hello, world!", + "extra_field": "This is allowed", + "another_field": 42, # Any type is allowed + } + + # A type checker that supports PEP 728 would validate the object above. + # Accessing the provider-specific key is possible, and its type is 'Any'. + block_extra_field = my_block["extra_field"] + +.. warning:: + Type checkers such as MyPy do not yet support `PEP 728 `__, + so you may see type errors when using provider-specific fields. These are safe to + ignore, as the fields are still validated at runtime. + +**Rationale** + +Different LLM providers use distinct and incompatible API schemas. This module +introduces a unified, provider-agnostic format to standardize these interactions. A +message to or from a model is simply a `list` of `ContentBlock` objects, allowing for +the natural interleaving of text, images, and other content in a single, ordered +sequence. + +An adapter for a specific provider is responsible for translating this standard list of +blocks into the format required by its API. + +**Key Block Types** + +The module defines several types of content blocks, including: + +- **``TextContentBlock``**: Standard text. +- **``ImageContentBlock``**, **``AudioContentBlock``**, **``VideoContentBlock``**: For + multimodal data. 
+- **``ToolCallContentBlock``**, **``ToolOutputContentBlock``**: For function calling. +- **``ReasoningContentBlock``**: To capture a model's thought process. +- **``Citation``**: For annotations that link generated text to a source document. + +**Example Usage** + +.. code-block:: python + + from langchain_core.messages.content_blocks import TextContentBlock, ImageContentBlock + + multimodal_message: AIMessage = [ + TextContentBlock(type="text", text="What is shown in this image?"), + ImageContentBlock( + type="image", + url="https://www.langchain.com/images/brand/langchain_logo_text_w_white.png", + mime_type="image/png", + ), + ] +""" # noqa: E501 import warnings -from typing import Any, Literal, Union +from typing import Any, Literal, Optional, Union -from pydantic import TypeAdapter, ValidationError -from typing_extensions import NotRequired, TypedDict +from typing_extensions import NotRequired, TypedDict, get_args, get_origin + +# --- Text and annotations --- -class BaseDataContentBlock(TypedDict, total=False): - """Base class for data content blocks.""" +class Citation(TypedDict): + """Annotation for citing data from a document. + + .. note:: + ``start/end`` indices refer to the **response text**, + not the source text. This means that the indices are relative to the model's + response, not the original document (as specified in the ``url``). + """ + + type: Literal["citation"] + """Type of the content block.""" + + id: NotRequired[str] + """Content block identifier. Either: + + - Generated by the provider (e.g., OpenAI's file ID) + - Generated by LangChain upon creation (as ``UUID4``) + """ + + url: NotRequired[str] + """URL of the document source.""" + + # For future consideration, if needed: + # provenance: NotRequired[str] + # """Provenance of the document, e.g., "Wikipedia", "arXiv", etc. + + # Included for future compatibility; not currently implemented. + # """ + + title: NotRequired[str] + """Source document title. + + For example, the page title for a web page or the title of a paper. + """ + + start_index: NotRequired[int] + """Start index of the **response text** (``TextContentBlock.text``) for which the + annotation applies.""" + + end_index: NotRequired[int] + """End index of the **response text** (``TextContentBlock.text``) for which the + annotation applies.""" + + cited_text: NotRequired[str] + """Excerpt of source text being cited.""" + + # NOTE: not including spans for the raw document text (such as `text_start_index` + # and `text_end_index`) as this is not currently supported by any provider. The + # thinking is that the `cited_text` should be sufficient for most use cases, and it + # is difficult to reliably extract spans from the raw document text across file + # formats or encoding schemes. + + +class NonStandardAnnotation(TypedDict): + """Provider-specific annotation format.""" + + type: Literal["non_standard_annotation"] + """Type of the content block.""" + + id: NotRequired[str] + """Content block identifier. Either: + + - Generated by the provider (e.g., OpenAI's file ID) + - Generated by LangChain upon creation (as ``UUID4``) + """ + + value: dict[str, Any] + """Provider-specific annotation data.""" + + +Annotation = Union[Citation, NonStandardAnnotation] + + +class TextContentBlock(TypedDict): + """Content block for text output. + + This typically represents the main text content of a message, such as the response + from a language model or the text of a user message. 
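+
+    Example (illustrative values only, not tied to any particular provider):
+
+    .. code-block:: python
+
+        block: TextContentBlock = {
+            "type": "text",
+            "text": "LangChain standardizes message content.",
+            "annotations": [
+                {
+                    "type": "citation",
+                    "url": "https://docs.langchain.com",
+                    "cited_text": "standardizes message content",
+                }
+            ],
+        }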
+ """ + + type: Literal["text"] + """Type of the content block.""" + + id: NotRequired[str] + """Content block identifier. Either: + + - Generated by the provider (e.g., OpenAI's file ID) + - Generated by LangChain upon creation (as ``UUID4``) + """ + + text: str + """Block text.""" + + annotations: NotRequired[list[Annotation]] + """Citations and other annotations.""" + + index: NotRequired[int] + """Index of block in aggregate response. Used during streaming.""" + + +# --- Tool calls --- +class ToolCall(TypedDict): + """Represents a request to call a tool. + + Example: + + .. code-block:: python + + { + "name": "foo", + "args": {"a": 1}, + "id": "123" + } + + This represents a request to call the tool named "foo" with arguments {"a": 1} + and an identifier of "123". + """ + + name: str + """The name of the tool to be called.""" + args: dict[str, Any] + """The arguments to the tool call.""" + id: Optional[str] + """An identifier associated with the tool call. + + An identifier is needed to associate a tool call request with a tool + call result in events when multiple concurrent tool calls are made. + """ + index: NotRequired[int] + """Index of block in aggregate response. Used during streaming.""" + type: Literal["tool_call"] + + +class InvalidToolCall(TypedDict): + """Allowance for errors made by LLM. + + Here we add an `error` key to surface errors made during generation + (e.g., invalid JSON arguments.) + """ + + name: Optional[str] + """The name of the tool to be called.""" + args: Optional[str] + """The arguments to the tool call.""" + id: Optional[str] + """An identifier associated with the tool call.""" + error: Optional[str] + """An error message associated with the tool call.""" + type: Literal["invalid_tool_call"] + + +class ToolCallChunk(TypedDict): + """A chunk of a tool call (e.g., as part of a stream). + + When merging ToolCallChunks (e.g., via AIMessageChunk.__add__), + all string attributes are concatenated. Chunks are only merged if their + values of `index` are equal and not None. + + Example: + + .. code-block:: python + + left_chunks = [ToolCallChunk(name="foo", args='{"a":', index=0)] + right_chunks = [ToolCallChunk(name=None, args='1}', index=0)] + + ( + AIMessageChunk(content="", tool_call_chunks=left_chunks) + + AIMessageChunk(content="", tool_call_chunks=right_chunks) + ).tool_call_chunks == [ToolCallChunk(name='foo', args='{"a":1}', index=0)] + """ + + name: Optional[str] + """The name of the tool to be called.""" + args: Optional[str] + """The arguments to the tool call.""" + id: Optional[str] + """An identifier associated with the tool call.""" + index: Optional[int] + """The index of the tool call in a sequence.""" + type: NotRequired[Literal["tool_call_chunk"]] + + +# --- Provider tool calls (built-in tools) --- +# Note: These are not standard tool calls, but rather provider-specific built-in tools. + + +# Web search +class WebSearchCall(TypedDict): + """Content block for a built-in web search tool call.""" + + type: Literal["web_search_call"] + """Type of the content block.""" + + id: NotRequired[str] + """Content block identifier. Either: + + - Generated by the provider (e.g., OpenAI's file ID) + - Generated by LangChain upon creation (as ``UUID4``) + """ + + query: NotRequired[str] + """The search query used in the web search tool call.""" + + index: NotRequired[int] + """Index of block in aggregate response. 
Used during streaming.""" + + +class WebSearchResult(TypedDict): + """Content block for the result of a built-in web search tool call.""" + + type: Literal["web_search_result"] + """Type of the content block.""" + + id: NotRequired[str] + """Content block identifier. Either: + + - Generated by the provider (e.g., OpenAI's file ID) + - Generated by LangChain upon creation (as ``UUID4``) + """ + + urls: NotRequired[list[str]] + """List of URLs returned by the web search tool call.""" + + index: NotRequired[int] + """Index of block in aggregate response. Used during streaming.""" + + +# Code interpreter + + +# Call +class CodeInterpreterCall(TypedDict): + """Content block for a built-in code interpreter tool call.""" + + type: Literal["code_interpreter_call"] + """Type of the content block.""" + + id: NotRequired[str] + """Content block identifier. Either: + + - Generated by the provider (e.g., OpenAI's file ID) + - Generated by LangChain upon creation (as ``UUID4``) + """ + + language: NotRequired[str] + """The programming language used in the code interpreter tool call.""" + + code: NotRequired[str] + """The code to be executed by the code interpreter.""" + + index: NotRequired[int] + """Index of block in aggregate response. Used during streaming.""" + + +# Result block is CodeInterpreterResult +class CodeInterpreterOutput(TypedDict): + """Content block for the output of a singular code interpreter tool call. + + Full output of a code interpreter tool call is represented by + ``CodeInterpreterResult`` which is a list of these blocks. + """ + + type: Literal["code_interpreter_output"] + """Type of the content block.""" + + id: NotRequired[str] + """Content block identifier. Either: + + - Generated by the provider (e.g., OpenAI's file ID) + - Generated by LangChain upon creation (as ``UUID4``) + """ + + return_code: NotRequired[int] + """Return code of the executed code. + + Example: 0 for success, non-zero for failure. + """ + + stderr: NotRequired[str] + """Standard error output of the executed code.""" + + stdout: NotRequired[str] + """Standard output of the executed code.""" + + file_ids: NotRequired[list[str]] + """List of file IDs generated by the code interpreter.""" + + index: NotRequired[int] + """Index of block in aggregate response. Used during streaming.""" + + +class CodeInterpreterResult(TypedDict): + """Content block for the result of a code interpreter tool call.""" + + type: Literal["code_interpreter_result"] + """Type of the content block.""" + + id: NotRequired[str] + """Content block identifier. Either: + + - Generated by the provider (e.g., OpenAI's file ID) + - Generated by LangChain upon creation (as ``UUID4``) + """ + + output: list[CodeInterpreterOutput] + """List of outputs from the code interpreter tool call.""" + + index: NotRequired[int] + """Index of block in aggregate response. Used during streaming.""" + + +# --- Reasoning --- +class ReasoningContentBlock(TypedDict): + """Content block for reasoning output.""" + + type: Literal["reasoning"] + """Type of the content block.""" + + id: NotRequired[str] + """Content block identifier. Either: + + - Generated by the provider (e.g., OpenAI's file ID) + - Generated by LangChain upon creation (as ``UUID4``) + """ + + reasoning: NotRequired[str] + """Reasoning text. + + Either the thought summary or the raw reasoning text itself. This is often parsed + from ```` tags in the model's response. + """ + + thought_signature: NotRequired[str] + """Opaque state handle representation of the model's internal thought process. 
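A rough sketch of how the built-in code-interpreter blocks defined above are intended to fit together (a call, one or more outputs, and a result that wraps those outputs); the ids and values are purely illustrative:

.. code-block:: python

    from langchain_core.messages.content_blocks import (
        CodeInterpreterCall,
        CodeInterpreterOutput,
        CodeInterpreterResult,
    )

    call: CodeInterpreterCall = {
        "type": "code_interpreter_call",
        "id": "ci_123",  # hypothetical provider-assigned id
        "language": "python",
        "code": "print(2 + 2)",
    }

    output: CodeInterpreterOutput = {
        "type": "code_interpreter_output",
        "return_code": 0,
        "stdout": "4\n",
    }

    result: CodeInterpreterResult = {
        "type": "code_interpreter_result",
        "output": [output],  # full output is a list of CodeInterpreterOutput blocks
    }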
+ + Maintains the context of the model's thinking across multiple interactions + (e.g. multi-turn conversations) since many APIs are stateless. + + Not to be used to verify authenticity or integrity of the response (`'signature'`). + + Examples: + - https://ai.google.dev/gemini-api/docs/thinking#signatures + """ + + signature: NotRequired[str] + """Signature of the reasoning content block used to verify **authenticity**. + + Prevents from modifying or fabricating the model's reasoning process. + + Examples: + - https://docs.anthropic.com/en/docs/build-with-claude/context-windows#the-context-window-with-extended-thinking-and-tool-use + """ + + index: NotRequired[int] + """Index of block in aggregate response. Used during streaming.""" + + +# --- Multi-modal --- + + +# Note: `title` and `context` are fields that could be used to provide additional +# information about the file, such as a description or summary of its content. +# E.g. with Claude, you can provide a context for a file which is passed to the model. +class ImageContentBlock(TypedDict): + """Content block for image data.""" + + type: Literal["image"] + """Type of the content block.""" + + id: NotRequired[str] + """Content block identifier. Either: + + - Generated by the provider (e.g., OpenAI's file ID) + - Generated by LangChain upon creation (as ``UUID4``) + """ + + file_id: NotRequired[str] + """ID of the image file, e.g., from a file storage system.""" mime_type: NotRequired[str] - """MIME type of the content block (if needed).""" + """MIME type of the image. Required for base64. + `Examples from IANA `__ + """ -class URLContentBlock(BaseDataContentBlock): - """Content block for data from a URL.""" + index: NotRequired[int] + """Index of block in aggregate response. Used during streaming.""" - type: Literal["image", "audio", "file"] - """Type of the content block.""" - source_type: Literal["url"] - """Source type (url).""" - url: str - """URL for data.""" + url: NotRequired[str] + """URL of the image.""" - -class Base64ContentBlock(BaseDataContentBlock): - """Content block for inline data from a base64 string.""" - - type: Literal["image", "audio", "file"] - """Type of the content block.""" - source_type: Literal["base64"] - """Source type (base64).""" - data: str + base64: NotRequired[str] """Data as a base64 string.""" + # title: NotRequired[str] + # """Title of the image.""" -class PlainTextContentBlock(BaseDataContentBlock): - """Content block for plain text data (e.g., from a document).""" + # context: NotRequired[str] + # """Context for the image, e.g., a description or summary of the image's content.""" # noqa: E501 + + +class VideoContentBlock(TypedDict): + """Content block for video data.""" + + type: Literal["video"] + """Type of the content block.""" + + id: NotRequired[str] + """Content block identifier. Either: + + - Generated by the provider (e.g., OpenAI's file ID) + - Generated by LangChain upon creation (as ``UUID4``) + """ + + file_id: NotRequired[str] + """ID of the video file, e.g., from a file storage system.""" + + mime_type: NotRequired[str] + """MIME type of the video. Required for base64. + + `Examples from IANA `__ + """ + + index: NotRequired[int] + """Index of block in aggregate response. 
Used during streaming.""" + + url: NotRequired[str] + """URL of the video.""" + + base64: NotRequired[str] + """Data as a base64 string.""" + + # title: NotRequired[str] + # """Title of the video.""" + + # context: NotRequired[str] + # """Context for the video, e.g., a description or summary of the video's content.""" # noqa: E501 + + +class AudioContentBlock(TypedDict): + """Content block for audio data.""" + + type: Literal["audio"] + """Type of the content block.""" + + id: NotRequired[str] + """Content block identifier. Either: + + - Generated by the provider (e.g., OpenAI's file ID) + - Generated by LangChain upon creation (as ``UUID4``) + """ + + file_id: NotRequired[str] + """ID of the audio file, e.g., from a file storage system.""" + + mime_type: NotRequired[str] + """MIME type of the audio. Required for base64. + + `Examples from IANA `__ + """ + + index: NotRequired[int] + """Index of block in aggregate response. Used during streaming.""" + + url: NotRequired[str] + """URL of the audio.""" + + base64: NotRequired[str] + """Data as a base64 string.""" + + # title: NotRequired[str] + # """Title of the audio.""" + + # context: NotRequired[str] + # """Context for the audio, e.g., a description or summary of the audio's content.""" # noqa: E501 + + +class PlainTextContentBlock(TypedDict): + """Content block for plaintext data (e.g., from a document). + + .. note:: + Title and context are optional fields that may be passed to the model. See + Anthropic `example `__. + """ + + type: Literal["text-plain"] + """Type of the content block.""" + + id: NotRequired[str] + """Content block identifier. Either: + + - Generated by the provider (e.g., OpenAI's file ID) + - Generated by LangChain upon creation (as ``UUID4``) + """ + + file_id: NotRequired[str] + """ID of the plaintext file, e.g., from a file storage system.""" + + mime_type: Literal["text/plain"] + """MIME type of the file. Required for base64.""" + + index: NotRequired[int] + """Index of block in aggregate response. Used during streaming.""" + + url: NotRequired[str] + """URL of the plaintext.""" + + base64: NotRequired[str] + """Data as a base64 string.""" + + text: NotRequired[str] + """Plaintext content. This is optional if the data is provided as base64.""" + + title: NotRequired[str] + """Title of the text data, e.g., the title of a document.""" + + context: NotRequired[str] + """Context for the text, e.g., a description or summary of the text's content.""" + + +class FileContentBlock(TypedDict): + """Content block for file data. + + This block is intended for files that are not images, audio, or plaintext. For + example, it can be used for PDFs, Word documents, etc. + + If the file is an image, audio, or plaintext, you should use the corresponding + content block type (e.g., ``ImageContentBlock``, ``AudioContentBlock``, + ``PlainTextContentBlock``). + """ type: Literal["file"] """Type of the content block.""" - source_type: Literal["text"] - """Source type (text).""" - text: str - """Text data.""" + + id: NotRequired[str] + """Content block identifier. Either: + + - Generated by the provider (e.g., OpenAI's file ID) + - Generated by LangChain upon creation (as ``UUID4``) + """ + + file_id: NotRequired[str] + """ID of the file, e.g., from a file storage system.""" + + mime_type: NotRequired[str] + """MIME type of the file. Required for base64. + + `Examples from IANA `__ + """ + + index: NotRequired[int] + """Index of block in aggregate response. 
Used during streaming.""" + + url: NotRequired[str] + """URL of the file.""" + + base64: NotRequired[str] + """Data as a base64 string.""" + + # title: NotRequired[str] + # """Title of the file, e.g., the name of a document or file.""" + + # context: NotRequired[str] + # """Context for the file, e.g., a description or summary of the file's content.""" -class IDContentBlock(TypedDict): - """Content block for data specified by an identifier.""" +# Future modalities to consider: +# - 3D models +# - Tabular data - type: Literal["image", "audio", "file"] + +# Non-standard +class NonStandardContentBlock(TypedDict): + """Content block provider-specific data. + + This block contains data for which there is not yet a standard type. + + The purpose of this block should be to simply hold a provider-specific payload. + If a provider's non-standard output includes reasoning and tool calls, it should be + the adapter's job to parse that payload and emit the corresponding standard + ReasoningContentBlock and ToolCallContentBlocks. + """ + + type: Literal["non_standard"] """Type of the content block.""" - source_type: Literal["id"] - """Source type (id).""" - id: str - """Identifier for data source.""" + + id: NotRequired[str] + """Content block identifier. Either: + + - Generated by the provider (e.g., OpenAI's file ID) + - Generated by LangChain upon creation (as ``UUID4``) + """ + + value: dict[str, Any] + """Provider-specific data.""" + + index: NotRequired[int] + """Index of block in aggregate response. Used during streaming.""" +# --- Aliases --- DataContentBlock = Union[ - URLContentBlock, - Base64ContentBlock, + ImageContentBlock, + VideoContentBlock, + AudioContentBlock, PlainTextContentBlock, - IDContentBlock, + FileContentBlock, ] -_DataContentBlockAdapter: TypeAdapter[DataContentBlock] = TypeAdapter(DataContentBlock) +ToolContentBlock = Union[ + ToolCall, + CodeInterpreterCall, + CodeInterpreterOutput, + CodeInterpreterResult, + WebSearchCall, + WebSearchResult, +] + +ContentBlock = Union[ + TextContentBlock, + ToolCall, + ReasoningContentBlock, + NonStandardContentBlock, + DataContentBlock, + ToolContentBlock, +] -def is_data_content_block( - content_block: dict, -) -> bool: +def _extract_typedict_type_values(union_type: Any) -> set[str]: + """Extract the values of the 'type' field from a TypedDict union type.""" + result: set[str] = set() + for value in get_args(union_type): + annotation = value.__annotations__["type"] + if get_origin(annotation) is Literal: + result.update(get_args(annotation)) + else: + msg = f"{value} 'type' is not a Literal" + raise ValueError(msg) + return result + + +KNOWN_BLOCK_TYPES = { + bt for bt in get_args(ContentBlock) for bt in get_args(bt.__annotations__["type"]) +} + + +def is_data_content_block(block: dict) -> bool: """Check if the content block is a standard data content block. Args: - content_block: The content block to check. + block: The content block to check. Returns: True if the content block is a data content block, False otherwise. 
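Going by the predicate as rewritten in this patch (a recognized ``type`` plus at least one data-bearing key, with ``source_type`` kept for backwards compatibility), a quick sketch of the expected behavior:

.. code-block:: python

    from langchain_core.messages.content_blocks import is_data_content_block

    # New-style block: a known type plus a data-bearing key.
    assert is_data_content_block({"type": "image", "url": "https://example.com/cat.png"})

    # Old-style blocks remain recognized via the "source_type" fallback.
    assert is_data_content_block(
        {"type": "image", "source_type": "base64", "data": "...", "mime_type": "image/png"}
    )

    # Text output blocks are not *data* content blocks.
    assert not is_data_content_block({"type": "text", "text": "hello"})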
""" - try: - _ = _DataContentBlockAdapter.validate_python(content_block) - except ValidationError: - return False - else: - return True + return block.get("type") in ( + "audio", + "image", + "video", + "file", + "text-plain", + ) and any( + key in block + for key in ( + "url", + "base64", + "file_id", + "text", + "source_type", # backwards compatibility + ) + ) -def convert_to_openai_image_block(content_block: dict[str, Any]) -> dict: +def convert_to_openai_image_block(block: dict[str, Any]) -> dict: """Convert image content block to format expected by OpenAI Chat Completions API.""" - if content_block["source_type"] == "url": + if "url" in block: return { "type": "image_url", "image_url": { - "url": content_block["url"], + "url": block["url"], }, } - if content_block["source_type"] == "base64": - if "mime_type" not in content_block: + if "base64" in block or block.get("source_type") == "base64": + if "mime_type" not in block: error_message = "mime_type key is required for base64 data." raise ValueError(error_message) - mime_type = content_block["mime_type"] + mime_type = block["mime_type"] + base64_data = block["data"] if "data" in block else block["base64"] return { "type": "image_url", "image_url": { - "url": f"data:{mime_type};base64,{content_block['data']}", + "url": f"data:{mime_type};base64,{base64_data}", }, } error_message = "Unsupported source type. Only 'url' and 'base64' are supported." @@ -117,8 +789,9 @@ def convert_to_openai_data_block(block: dict) -> dict: formatted_block = convert_to_openai_image_block(block) elif block["type"] == "file": - if block["source_type"] == "base64": - file = {"file_data": f"data:{block['mime_type']};base64,{block['data']}"} + if "base64" in block or block.get("source_type") == "base64": + base64_data = block["data"] if "source_type" in block else block["base64"] + file = {"file_data": f"data:{block['mime_type']};base64,{base64_data}"} if filename := block.get("filename"): file["filename"] = filename elif (metadata := block.get("metadata")) and ("filename" in metadata): @@ -126,27 +799,28 @@ def convert_to_openai_data_block(block: dict) -> dict: else: warnings.warn( "OpenAI may require a filename for file inputs. Specify a filename " - "in the content block: {'type': 'file', 'source_type': 'base64', " - "'mime_type': 'application/pdf', 'data': '...', " - "'filename': 'my-pdf'}", + "in the content block: {'type': 'file', 'mime_type': " + "'application/pdf', 'base64': '...', 'filename': 'my-pdf'}", stacklevel=1, ) formatted_block = {"type": "file", "file": file} - elif block["source_type"] == "id": - formatted_block = {"type": "file", "file": {"file_id": block["id"]}} + elif "file_id" in block or block.get("source_type") == "id": + file_id = block["id"] if "source_type" in block else block["file_id"] + formatted_block = {"type": "file", "file": {"file_id": file_id}} else: - error_msg = "source_type base64 or id is required for file blocks." + error_msg = "Keys base64 or file_id required for file blocks." raise ValueError(error_msg) elif block["type"] == "audio": - if block["source_type"] == "base64": + if "base64" in block or block.get("source_type") == "base64": + base64_data = block["data"] if "source_type" in block else block["base64"] audio_format = block["mime_type"].split("/")[-1] formatted_block = { "type": "input_audio", - "input_audio": {"data": block["data"], "format": audio_format}, + "input_audio": {"data": base64_data, "format": audio_format}, } else: - error_msg = "source_type base64 is required for audio blocks." 
+ error_msg = "Key base64 is required for audio blocks." raise ValueError(error_msg) else: error_msg = f"Block of type {block['type']} is not supported." diff --git a/libs/core/langchain_core/messages/modifier.py b/libs/core/langchain_core/messages/modifier.py index 08b7e79b69c..5f1602a4908 100644 --- a/libs/core/langchain_core/messages/modifier.py +++ b/libs/core/langchain_core/messages/modifier.py @@ -13,7 +13,7 @@ class RemoveMessage(BaseMessage): def __init__( self, - id: str, # noqa: A002 + id: str, **kwargs: Any, ) -> None: """Create a RemoveMessage. diff --git a/libs/core/langchain_core/messages/tool.py b/libs/core/langchain_core/messages/tool.py index 1f8a519a7dc..181c80443d5 100644 --- a/libs/core/langchain_core/messages/tool.py +++ b/libs/core/langchain_core/messages/tool.py @@ -5,9 +5,12 @@ from typing import Any, Literal, Optional, Union from uuid import UUID from pydantic import Field, model_validator -from typing_extensions import NotRequired, TypedDict, override +from typing_extensions import override from langchain_core.messages.base import BaseMessage, BaseMessageChunk, merge_content +from langchain_core.messages.content_blocks import InvalidToolCall as InvalidToolCall +from langchain_core.messages.content_blocks import ToolCall as ToolCall +from langchain_core.messages.content_blocks import ToolCallChunk as ToolCallChunk from langchain_core.utils._merge import merge_dicts, merge_obj @@ -177,42 +180,11 @@ class ToolMessageChunk(ToolMessage, BaseMessageChunk): return super().__add__(other) -class ToolCall(TypedDict): - """Represents a request to call a tool. - - Example: - - .. code-block:: python - - { - "name": "foo", - "args": {"a": 1}, - "id": "123" - } - - This represents a request to call the tool named "foo" with arguments {"a": 1} - and an identifier of "123". - - """ - - name: str - """The name of the tool to be called.""" - args: dict[str, Any] - """The arguments to the tool call.""" - id: Optional[str] - """An identifier associated with the tool call. - - An identifier is needed to associate a tool call request with a tool - call result in events when multiple concurrent tool calls are made. - """ - type: NotRequired[Literal["tool_call"]] - - def tool_call( *, name: str, args: dict[str, Any], - id: Optional[str], # noqa: A002 + id: Optional[str], ) -> ToolCall: """Create a tool call. @@ -224,43 +196,11 @@ def tool_call( return ToolCall(name=name, args=args, id=id, type="tool_call") -class ToolCallChunk(TypedDict): - """A chunk of a tool call (e.g., as part of a stream). - - When merging ToolCallChunks (e.g., via AIMessageChunk.__add__), - all string attributes are concatenated. Chunks are only merged if their - values of `index` are equal and not None. - - Example: - - .. 
code-block:: python - - left_chunks = [ToolCallChunk(name="foo", args='{"a":', index=0)] - right_chunks = [ToolCallChunk(name=None, args='1}', index=0)] - - ( - AIMessageChunk(content="", tool_call_chunks=left_chunks) - + AIMessageChunk(content="", tool_call_chunks=right_chunks) - ).tool_call_chunks == [ToolCallChunk(name='foo', args='{"a":1}', index=0)] - - """ - - name: Optional[str] - """The name of the tool to be called.""" - args: Optional[str] - """The arguments to the tool call.""" - id: Optional[str] - """An identifier associated with the tool call.""" - index: Optional[int] - """The index of the tool call in a sequence.""" - type: NotRequired[Literal["tool_call_chunk"]] - - def tool_call_chunk( *, name: Optional[str] = None, args: Optional[str] = None, - id: Optional[str] = None, # noqa: A002 + id: Optional[str] = None, index: Optional[int] = None, ) -> ToolCallChunk: """Create a tool call chunk. @@ -276,29 +216,11 @@ def tool_call_chunk( ) -class InvalidToolCall(TypedDict): - """Allowance for errors made by LLM. - - Here we add an `error` key to surface errors made during generation - (e.g., invalid JSON arguments.) - """ - - name: Optional[str] - """The name of the tool to be called.""" - args: Optional[str] - """The arguments to the tool call.""" - id: Optional[str] - """An identifier associated with the tool call.""" - error: Optional[str] - """An error message associated with the tool call.""" - type: NotRequired[Literal["invalid_tool_call"]] - - def invalid_tool_call( *, name: Optional[str] = None, args: Optional[str] = None, - id: Optional[str] = None, # noqa: A002 + id: Optional[str] = None, error: Optional[str] = None, ) -> InvalidToolCall: """Create an invalid tool call. diff --git a/libs/core/langchain_core/messages/utils.py b/libs/core/langchain_core/messages/utils.py index 11c044eb438..ed9e8d39745 100644 --- a/libs/core/langchain_core/messages/utils.py +++ b/libs/core/langchain_core/messages/utils.py @@ -40,6 +40,12 @@ from langchain_core.messages.human import HumanMessage, HumanMessageChunk from langchain_core.messages.modifier import RemoveMessage from langchain_core.messages.system import SystemMessage, SystemMessageChunk from langchain_core.messages.tool import ToolCall, ToolMessage, ToolMessageChunk +from langchain_core.messages.v1 import AIMessage as AIMessageV1 +from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1 +from langchain_core.messages.v1 import HumanMessage as HumanMessageV1 +from langchain_core.messages.v1 import MessageV1, MessageV1Types +from langchain_core.messages.v1 import SystemMessage as SystemMessageV1 +from langchain_core.messages.v1 import ToolMessage as ToolMessageV1 if TYPE_CHECKING: from langchain_text_splitters import TextSplitter @@ -203,7 +209,7 @@ def message_chunk_to_message(chunk: BaseMessageChunk) -> BaseMessage: MessageLikeRepresentation = Union[ - BaseMessage, list[str], tuple[str, str], str, dict[str, Any] + BaseMessage, list[str], tuple[str, str], str, dict[str, Any], MessageV1 ] @@ -213,7 +219,7 @@ def _create_message_from_message_type( name: Optional[str] = None, tool_call_id: Optional[str] = None, tool_calls: Optional[list[dict[str, Any]]] = None, - id: Optional[str] = None, # noqa: A002 + id: Optional[str] = None, **additional_kwargs: Any, ) -> BaseMessage: """Create a message from a message type and content string. 
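For reference, the ``tool_call`` factory kept above now returns the ``ToolCall`` TypedDict that moved into ``content_blocks``; a minimal sketch of its output, with an illustrative tool name and id:

.. code-block:: python

    from langchain_core.messages.tool import tool_call

    tc = tool_call(name="get_weather", args={"city": "Paris"}, id="call_1")
    assert tc == {
        "name": "get_weather",
        "args": {"city": "Paris"},
        "id": "call_1",
        "type": "tool_call",
    }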
@@ -294,6 +300,128 @@ def _create_message_from_message_type( return message +def _create_message_from_message_type_v1( + message_type: str, + content: str, + name: Optional[str] = None, + tool_call_id: Optional[str] = None, + tool_calls: Optional[list[dict[str, Any]]] = None, + id: Optional[str] = None, + **kwargs: Any, +) -> MessageV1: + """Create a message from a message type and content string. + + Args: + message_type: (str) the type of the message (e.g., "human", "ai", etc.). + content: (str) the content string. + name: (str) the name of the message. Default is None. + tool_call_id: (str) the tool call id. Default is None. + tool_calls: (list[dict[str, Any]]) the tool calls. Default is None. + id: (str) the id of the message. Default is None. + kwargs: (dict[str, Any]) additional keyword arguments. + + Returns: + a message of the appropriate type. + + Raises: + ValueError: if the message type is not one of "human", "user", "ai", + "assistant", "tool", "system", or "developer". + """ + kwargs: dict[str, Any] = {} + if name is not None: + kwargs["name"] = name + if tool_call_id is not None: + kwargs["tool_call_id"] = tool_call_id + if kwargs and (response_metadata := kwargs.pop("response_metadata", None)): + kwargs["response_metadata"] = response_metadata + if id is not None: + kwargs["id"] = id + if tool_calls is not None: + kwargs["tool_calls"] = [] + for tool_call in tool_calls: + # Convert OpenAI-format tool call to LangChain format. + if "function" in tool_call: + args = tool_call["function"]["arguments"] + if isinstance(args, str): + args = json.loads(args, strict=False) + kwargs["tool_calls"].append( + { + "name": tool_call["function"]["name"], + "args": args, + "id": tool_call["id"], + "type": "tool_call", + } + ) + else: + kwargs["tool_calls"].append(tool_call) + if message_type in {"human", "user"}: + message = HumanMessageV1(content=content, **kwargs) + elif message_type in {"ai", "assistant"}: + message = AIMessageV1(content=content, **kwargs) + elif message_type in {"system", "developer"}: + if message_type == "developer": + kwargs["custom_role"] = "developer" + message = SystemMessageV1(content=content, **kwargs) + elif message_type == "tool": + artifact = kwargs.pop("artifact", None) + message = ToolMessageV1(content=content, artifact=artifact, **kwargs) + else: + msg = ( + f"Unexpected message type: '{message_type}'. Use one of 'human'," + f" 'user', 'ai', 'assistant', 'function', 'tool', 'system', or 'developer'." + ) + msg = create_message(message=msg, error_code=ErrorCode.MESSAGE_COERCION_FAILURE) + raise ValueError(msg) + return message + + +def _convert_from_v1_message(message: MessageV1) -> BaseMessage: + """Compatibility layer to convert v1 messages to current messages. + + Args: + message: MessageV1 instance to convert. + + Returns: + BaseMessage: Converted message instance. + """ + # type ignores here are because AIMessageV1.content is a list of dicts. + # AIMessageV0.content expects str or list[str | dict]. 
+ if isinstance(message, AIMessageV1): + return AIMessage( + content=message.content, # type: ignore[arg-type] + id=message.id, + name=message.name, + tool_calls=message.tool_calls, + response_metadata=message.response_metadata, + ) + if isinstance(message, AIMessageChunkV1): + return AIMessageChunk( + content=message.content, # type: ignore[arg-type] + id=message.id, + name=message.name, + tool_call_chunks=message.tool_call_chunks, + response_metadata=message.response_metadata, + ) + if isinstance(message, HumanMessageV1): + return HumanMessage( + content=message.content, # type: ignore[arg-type] + id=message.id, + name=message.name, + ) + if isinstance(message, SystemMessageV1): + return SystemMessage( + content=message.content, # type: ignore[arg-type] + id=message.id, + ) + if isinstance(message, ToolMessageV1): + return ToolMessage( + content=message.content, # type: ignore[arg-type] + id=message.id, + ) + message = f"Unsupported message type: {type(message)}" + raise NotImplementedError(message) + + def _convert_to_message(message: MessageLikeRepresentation) -> BaseMessage: """Instantiate a message from a variety of message formats. @@ -341,6 +469,63 @@ def _convert_to_message(message: MessageLikeRepresentation) -> BaseMessage: message_ = _create_message_from_message_type( msg_type, msg_content, **msg_kwargs ) + elif isinstance(message, MessageV1Types): + message_ = _convert_from_v1_message(message) + else: + msg = f"Unsupported message type: {type(message)}" + msg = create_message(message=msg, error_code=ErrorCode.MESSAGE_COERCION_FAILURE) + raise NotImplementedError(msg) + + return message_ + + +def _convert_to_message_v1(message: MessageLikeRepresentation) -> MessageV1: + """Instantiate a message from a variety of message formats. + + The message format can be one of the following: + + - BaseMessagePromptTemplate + - BaseMessage + - 2-tuple of (role string, template); e.g., ("human", "{user_input}") + - dict: a message dict with role and content keys + - string: shorthand for ("human", template); e.g., "{user_input}" + + Args: + message: a representation of a message in one of the supported formats. + + Returns: + an instance of a message or a message template. + + Raises: + NotImplementedError: if the message type is not supported. + ValueError: if the message dict does not contain the required keys. 
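A hedged sketch of the coercion paths listed above, exercised through the public ``convert_to_messages_v1`` wrapper added later in this file; the inputs are illustrative:

.. code-block:: python

    from langchain_core.messages.utils import convert_to_messages_v1
    from langchain_core.messages.v1 import AIMessage, HumanMessage, SystemMessage

    messages = convert_to_messages_v1(
        [
            "What is 2 + 2?",                              # bare string -> human message
            {"role": "system", "content": "Be concise."},  # dict with role/content keys
            ("ai", "2 + 2 = 4"),                           # (role, content) tuple
        ]
    )
    assert isinstance(messages[0], HumanMessage)
    assert isinstance(messages[1], SystemMessage)
    assert isinstance(messages[2], AIMessage)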
+ """ + if isinstance(message, MessageV1Types): + message_ = message + elif isinstance(message, str): + message_ = _create_message_from_message_type_v1("human", message) + elif isinstance(message, Sequence) and len(message) == 2: + # mypy doesn't realise this can't be a string given the previous branch + message_type_str, template = message # type: ignore[misc] + message_ = _create_message_from_message_type_v1(message_type_str, template) + elif isinstance(message, dict): + msg_kwargs = message.copy() + try: + try: + msg_type = msg_kwargs.pop("role") + except KeyError: + msg_type = msg_kwargs.pop("type") + # None msg content is not allowed + msg_content = msg_kwargs.pop("content") or "" + except KeyError as e: + msg = f"Message dict must contain 'role' and 'content' keys, got {message}" + msg = create_message( + message=msg, error_code=ErrorCode.MESSAGE_COERCION_FAILURE + ) + raise ValueError(msg) from e + message_ = _create_message_from_message_type_v1( + msg_type, msg_content, **msg_kwargs + ) else: msg = f"Unsupported message type: {type(message)}" msg = create_message(message=msg, error_code=ErrorCode.MESSAGE_COERCION_FAILURE) @@ -368,6 +553,25 @@ def convert_to_messages( return [_convert_to_message(m) for m in messages] +def convert_to_messages_v1( + messages: Union[Iterable[MessageLikeRepresentation], PromptValue], +) -> list[MessageV1]: + """Convert a sequence of messages to a list of messages. + + Args: + messages: Sequence of messages to convert. + + Returns: + list of messages (BaseMessages). + """ + # Import here to avoid circular imports + from langchain_core.prompt_values import PromptValue + + if isinstance(messages, PromptValue): + return messages.to_messages(output_version="v1") + return [_convert_to_message_v1(m) for m in messages] + + def _runnable_support(func: Callable) -> Callable: @overload def wrapped( @@ -1007,10 +1211,11 @@ def convert_to_openai_messages( oai_messages: list = [] - if is_single := isinstance(messages, (BaseMessage, dict, str)): + if is_single := isinstance(messages, (BaseMessage, dict, str, MessageV1Types)): messages = [messages] - messages = convert_to_messages(messages) + # TODO: resolve type ignore here + messages = convert_to_messages(messages) # type: ignore[arg-type] for i, message in enumerate(messages): oai_msg: dict = {"role": _get_message_openai_role(message)} diff --git a/libs/core/langchain_core/messages/v1.py b/libs/core/langchain_core/messages/v1.py new file mode 100644 index 00000000000..9ff2eaed4ab --- /dev/null +++ b/libs/core/langchain_core/messages/v1.py @@ -0,0 +1,568 @@ +"""LangChain 1.0 message format.""" + +import json +import uuid +from dataclasses import dataclass, field +from typing import Any, Literal, Optional, TypedDict, Union, cast, get_args + +import langchain_core.messages.content_blocks as types +from langchain_core.messages.ai import _LC_ID_PREFIX, UsageMetadata, add_usage +from langchain_core.messages.base import merge_content +from langchain_core.messages.tool import ( + ToolCallChunk, +) +from langchain_core.messages.tool import ( + invalid_tool_call as create_invalid_tool_call, +) +from langchain_core.messages.tool import tool_call as create_tool_call +from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk +from langchain_core.utils._merge import merge_dicts, merge_lists +from langchain_core.utils.json import parse_partial_json + + +def _ensure_id(id_val: Optional[str]) -> str: + """Ensure the ID is a valid string, generating a new UUID if not provided. 
+ + Args: + id_val: Optional string ID value to validate. + + Returns: + A valid string ID, either the provided value or a new UUID. + """ + return id_val or str(uuid.uuid4()) + + +class Provider(TypedDict): + """Information about the provider that generated the message. + + Contains metadata about the AI provider and model used to generate content. + + Attributes: + name: Name and version of the provider that created the content block. + model_name: Name of the model that generated the content block. + """ + + name: str + """Name and version of the provider that created the content block.""" + model_name: str + """Name of the model that generated the content block.""" + + +@dataclass +class AIMessage: + """A message generated by an AI assistant. + + Represents a response from an AI model, including text content, tool calls, + and metadata about the generation process. + + Attributes: + id: Unique identifier for the message. + type: Message type identifier, always "ai". + name: Optional human-readable name for the message. + lc_version: Encoding version for the message. + content: List of content blocks containing the message data. + tool_calls: Optional list of tool calls made by the AI. + invalid_tool_calls: Optional list of tool calls that failed validation. + usage: Optional dictionary containing usage statistics. + """ + + type: Literal["ai"] = "ai" + + name: Optional[str] = None + """An optional name for the message. + + This can be used to provide a human-readable name for the message. + + Usage of this field is optional, and whether it's used or not is up to the + model implementation. + """ + + id: Optional[str] = None + """Unique identifier for the message. + + If the provider assigns a meaningful ID, it should be used here. + """ + + lc_version: str = "v1" + """Encoding version for the message.""" + + content: list[types.ContentBlock] = field(default_factory=list) + + usage_metadata: Optional[UsageMetadata] = None + """If provided, usage metadata for a message, such as token counts.""" + + response_metadata: dict = field(default_factory=dict) + """Metadata about the response. + + This field should include non-standard data returned by the provider, such as + response headers, service tiers, or log probabilities. + """ + + def __init__( + self, + content: Union[str, list[types.ContentBlock]], + id: Optional[str] = None, + name: Optional[str] = None, + lc_version: str = "v1", + response_metadata: Optional[dict] = None, + usage_metadata: Optional[UsageMetadata] = None, + ): + """Initialize an AI message. + + Args: + content: Message content as string or list of content blocks. + id: Optional unique identifier for the message. + name: Optional human-readable name for the message. + lc_version: Encoding version for the message. + response_metadata: Optional metadata about the response. + usage_metadata: Optional metadata about token usage. 
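A small sketch of the constructor behavior documented above: a plain string is wrapped into a single text block, and the derived properties read from ``content``:

.. code-block:: python

    from langchain_core.messages.v1 import AIMessage

    msg = AIMessage("The capital of France is Paris.")
    assert msg.content == [{"type": "text", "text": "The capital of France is Paris."}]
    assert msg.text == "The capital of France is Paris."
    assert msg.tool_calls == []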
+ """ + if isinstance(content, str): + self.content = [{"type": "text", "text": content}] + else: + self.content = content + + self.id = id + self.name = name + self.lc_version = lc_version + self.usage_metadata = usage_metadata + if response_metadata is None: + self.response_metadata = {} + else: + self.response_metadata = response_metadata + + self._tool_calls: list[types.ToolCall] = [] + self._invalid_tool_calls: list[types.InvalidToolCall] = [] + + @property + def text(self) -> Optional[str]: + """Extract all text content from the AI message as a string.""" + text_blocks = [block for block in self.content if block["type"] == "text"] + if text_blocks: + return "".join(block["text"] for block in text_blocks) + return None + + @property + def tool_calls(self) -> list[types.ToolCall]: # update once we fix branch + """Get the tool calls made by the AI.""" + if self._tool_calls: + return self._tool_calls + tool_calls = [block for block in self.content if block["type"] == "tool_call"] + if tool_calls: + self._tool_calls = tool_calls + return self._tool_calls + + @tool_calls.setter + def tool_calls(self, value: list[types.ToolCall]) -> None: + """Set the tool calls for the AI message.""" + self._tool_calls = value + + +@dataclass +class AIMessageChunk: + """A partial chunk of an AI message during streaming. + + Represents a portion of an AI response that is delivered incrementally + during streaming generation. Contains partial content and metadata. + + Attributes: + id: Unique identifier for the message chunk. + type: Message type identifier, always "ai_chunk". + name: Optional human-readable name for the message. + content: List of content blocks containing partial message data. + tool_call_chunks: Optional list of partial tool call data. + usage_metadata: Optional metadata about token usage and costs. + """ + + type: Literal["ai_chunk"] = "ai_chunk" + + name: Optional[str] = None + """An optional name for the message. + + This can be used to provide a human-readable name for the message. + + Usage of this field is optional, and whether it's used or not is up to the + model implementation. + """ + + id: Optional[str] = None + """Unique identifier for the message chunk. + + If the provider assigns a meaningful ID, it should be used here. + """ + + lc_version: str = "v1" + """Encoding version for the message.""" + + content: list[types.ContentBlock] = field(init=False) + + usage_metadata: Optional[UsageMetadata] = None + """If provided, usage metadata for a message chunk, such as token counts. + + These data represent incremental usage statistics, as opposed to a running total. + """ + + response_metadata: dict = field(init=False) + """Metadata about the response chunk. + + This field should include non-standard data returned by the provider, such as + response headers, service tiers, or log probabilities. + """ + + tool_call_chunks: list[types.ToolCallChunk] = field(init=False) + + def __init__( + self, + content: Union[str, list[types.ContentBlock]], + id: Optional[str] = None, + name: Optional[str] = None, + lc_version: str = "v1", + response_metadata: Optional[dict] = None, + usage_metadata: Optional[UsageMetadata] = None, + tool_call_chunks: Optional[list[types.ToolCallChunk]] = None, + ): + """Initialize an AI message. + + Args: + content: Message content as string or list of content blocks. + id: Optional unique identifier for the message. + name: Optional human-readable name for the message. + lc_version: Encoding version for the message. 
+ response_metadata: Optional metadata about the response. + usage_metadata: Optional metadata about token usage. + tool_call_chunks: Optional list of partial tool call data. + """ + if isinstance(content, str): + self.content = [{"type": "text", "text": content, "index": 0}] + else: + self.content = content + + self.id = id + self.name = name + self.lc_version = lc_version + self.usage_metadata = usage_metadata + if response_metadata is None: + self.response_metadata = {} + else: + self.response_metadata = response_metadata + if tool_call_chunks is None: + self.tool_call_chunks: list[types.ToolCallChunk] = [] + else: + self.tool_call_chunks = tool_call_chunks + + self._tool_calls: list[types.ToolCall] = [] + self._invalid_tool_calls: list[types.InvalidToolCall] = [] + self._init_tool_calls() + + def _init_tool_calls(self) -> None: + """Initialize tool calls from tool call chunks. + + Args: + values: The values to validate. + + Raises: + ValueError: If the tool call chunks are malformed. + """ + self._tool_calls = [] + self._invalid_tool_calls = [] + if not self.tool_call_chunks: + if self._tool_calls: + self.tool_call_chunks = [ + create_tool_call_chunk( + name=tc["name"], + args=json.dumps(tc["args"]), + id=tc["id"], + index=None, + ) + for tc in self._tool_calls + ] + if self._invalid_tool_calls: + tool_call_chunks = self.tool_call_chunks + tool_call_chunks.extend( + [ + create_tool_call_chunk( + name=tc["name"], args=tc["args"], id=tc["id"], index=None + ) + for tc in self._invalid_tool_calls + ] + ) + self.tool_call_chunks = tool_call_chunks + + tool_calls = [] + invalid_tool_calls = [] + + def add_chunk_to_invalid_tool_calls(chunk: ToolCallChunk) -> None: + invalid_tool_calls.append( + create_invalid_tool_call( + name=chunk["name"], + args=chunk["args"], + id=chunk["id"], + error=None, + ) + ) + + for chunk in self.tool_call_chunks: + try: + args_ = parse_partial_json(chunk["args"]) if chunk["args"] != "" else {} # type: ignore[arg-type] + if isinstance(args_, dict): + tool_calls.append( + create_tool_call( + name=chunk["name"] or "", + args=args_, + id=chunk["id"], + ) + ) + else: + add_chunk_to_invalid_tool_calls(chunk) + except Exception: + add_chunk_to_invalid_tool_calls(chunk) + self._tool_calls = tool_calls + self._invalid_tool_calls = invalid_tool_calls + + @property + def text(self) -> Optional[str]: + """Extract all text content from the AI message as a string.""" + text_blocks = [block for block in self.content if block["type"] == "text"] + if text_blocks: + return "".join(block["text"] for block in text_blocks) + return None + + @property + def reasoning(self) -> Optional[str]: + """Extract all reasoning text from the AI message as a string.""" + text_blocks = [block for block in self.content if block["type"] == "reasoning"] + if text_blocks: + return "".join(block["reasoning"] for block in text_blocks) + return None + + @property + def tool_calls(self) -> list[types.ToolCall]: + """Get the tool calls made by the AI.""" + if self._tool_calls: + return self._tool_calls + tool_calls = [block for block in self.content if block["type"] == "tool_call"] + if tool_calls: + self._tool_calls = tool_calls + return self._tool_calls + + @tool_calls.setter + def tool_calls(self, value: list[types.ToolCall]) -> None: + """Set the tool calls for the AI message.""" + self._tool_calls = value + + def __add__(self, other: Any) -> "AIMessageChunk": + """Add AIMessageChunk to this one.""" + if isinstance(other, AIMessageChunk): + return add_ai_message_chunks(self, other) + if 
isinstance(other, (list, tuple)) and all( + isinstance(o, AIMessageChunk) for o in other + ): + return add_ai_message_chunks(self, *other) + error_msg = "Can only add AIMessageChunk or sequence of AIMessageChunk." + raise NotImplementedError(error_msg) + + +def add_ai_message_chunks( + left: AIMessageChunk, *others: AIMessageChunk +) -> AIMessageChunk: + """Add multiple AIMessageChunks together.""" + if not others: + return left + content = merge_content( + cast("list[str | dict[Any, Any]]", left.content), + *(cast("list[str | dict[Any, Any]]", o.content) for o in others), + ) + response_metadata = merge_dicts( + left.response_metadata, *(o.response_metadata for o in others) + ) + + # Merge tool call chunks + if raw_tool_calls := merge_lists( + left.tool_call_chunks, *(o.tool_call_chunks for o in others) + ): + tool_call_chunks = [ + create_tool_call_chunk( + name=rtc.get("name"), + args=rtc.get("args"), + index=rtc.get("index"), + id=rtc.get("id"), + ) + for rtc in raw_tool_calls + ] + else: + tool_call_chunks = [] + + # Token usage + if left.usage_metadata or any(o.usage_metadata is not None for o in others): + usage_metadata: Optional[UsageMetadata] = left.usage_metadata + for other in others: + usage_metadata = add_usage(usage_metadata, other.usage_metadata) + else: + usage_metadata = None + + chunk_id = None + candidates = [left.id] + [o.id for o in others] + # first pass: pick the first non-run-* id + for id_ in candidates: + if id_ and not id_.startswith(_LC_ID_PREFIX): + chunk_id = id_ + break + else: + # second pass: no provider-assigned id found, just take the first non-null + for id_ in candidates: + if id_: + chunk_id = id_ + break + + return left.__class__( + content=cast("list[types.ContentBlock]", content), + tool_call_chunks=tool_call_chunks, + response_metadata=response_metadata, + usage_metadata=usage_metadata, + id=chunk_id, + ) + + +@dataclass +class HumanMessage: + """A message from a human user. + + Represents input from a human user in a conversation, containing text + or other content types like images. + + Attributes: + id: Unique identifier for the message. + content: List of content blocks containing the user's input. + name: Optional human-readable name for the message. + type: Message type identifier, always "human". + """ + + id: str + content: list[types.ContentBlock] + name: Optional[str] = None + """An optional name for the message. + + This can be used to provide a human-readable name for the message. + + Usage of this field is optional, and whether it's used or not is up to the + model implementation. + """ + type: Literal["human"] = "human" + """The type of the message. Must be a string that is unique to the message type. + + The purpose of this field is to allow for easy identification of the message type + when deserializing messages. + """ + + def __init__( + self, content: Union[str, list[types.ContentBlock]], id: Optional[str] = None + ): + """Initialize a human message. + + Args: + content: Message content as string or list of content blocks. + id: Optional unique identifier for the message. + """ + self.id = _ensure_id(id) + if isinstance(content, str): + self.content = [{"type": "text", "text": content}] + else: + self.content = content + + def text(self) -> str: + """Extract all text content from the message. + + Returns: + Concatenated string of all text blocks in the message. 
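To illustrate the v1 ``HumanMessage`` behavior described above (string content wrapped into a text block, an id generated when none is supplied), a minimal sketch; note that ``text()`` is a method here, unlike the ``text`` property on ``AIMessage``:

.. code-block:: python

    from langchain_core.messages.v1 import HumanMessage

    msg = HumanMessage("Describe this image.")
    assert msg.text() == "Describe this image."
    assert len(msg.id) > 0  # a UUID4 string is generated when no id is supplied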
+ """ + return "".join( + block["text"] for block in self.content if block["type"] == "text" + ) + + +@dataclass +class SystemMessage: + """A system message containing instructions or context. + + Represents system-level instructions or context that guides the AI's + behavior and understanding of the conversation. + + Attributes: + id: Unique identifier for the message. + content: List of content blocks containing system instructions. + type: Message type identifier, always "system". + """ + + id: str + content: list[types.ContentBlock] + type: Literal["system"] = "system" + + def __init__( + self, content: Union[str, list[types.ContentBlock]], *, id: Optional[str] = None + ): + """Initialize a system message. + + Args: + content: System instructions as string or list of content blocks. + id: Optional unique identifier for the message. + """ + self.id = _ensure_id(id) + if isinstance(content, str): + self.content = [{"type": "text", "text": content}] + else: + self.content = content + + def text(self) -> str: + """Extract all text content from the system message.""" + return "".join( + block["text"] for block in self.content if block["type"] == "text" + ) + + +@dataclass +class ToolMessage: + """A message containing the result of a tool execution. + + Represents the output from executing a tool or function call, + including the result data and execution status. + + Attributes: + id: Unique identifier for the message. + tool_call_id: ID of the tool call this message responds to. + content: The result content from tool execution. + artifact: Optional app-side payload not intended for the model. + status: Execution status ("success" or "error"). + type: Message type identifier, always "tool". + """ + + id: str + tool_call_id: str + content: list[dict[str, Any]] + artifact: Optional[Any] = None # App-side payload not for the model + status: Literal["success", "error"] = "success" + type: Literal["tool"] = "tool" + + @property + def text(self) -> str: + """Extract all text content from the tool message.""" + return "".join( + block["text"] for block in self.content if block["type"] == "text" + ) + + def __post_init__(self) -> None: + """Initialize computed fields after dataclass creation. + + Ensures the tool message has a valid ID. 
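A brief sketch of constructing the v1 ``ToolMessage`` and checking it against the ``MessageV1Types`` alias defined just below; the ids are illustrative:

.. code-block:: python

    from langchain_core.messages.v1 import MessageV1Types, ToolMessage

    result = ToolMessage(
        id="msg_1",
        tool_call_id="call_1",
        content=[{"type": "text", "text": "4"}],
    )
    assert result.status == "success"
    assert result.text == "4"  # here .text is a property, not a method
    assert isinstance(result, MessageV1Types)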
+ """ + self.id = _ensure_id(self.id) + + +# Alias for a message type that can be any of the defined message types +MessageV1 = Union[ + AIMessage, + AIMessageChunk, + HumanMessage, + SystemMessage, + ToolMessage, +] +MessageV1Types = get_args(MessageV1) diff --git a/libs/core/langchain_core/output_parsers/transform.py b/libs/core/langchain_core/output_parsers/transform.py index 876e66b5556..0c864805b93 100644 --- a/libs/core/langchain_core/output_parsers/transform.py +++ b/libs/core/langchain_core/output_parsers/transform.py @@ -32,7 +32,7 @@ class BaseTransformOutputParser(BaseOutputParser[T]): def _transform( self, - input: Iterator[Union[str, BaseMessage]], # noqa: A002 + input: Iterator[Union[str, BaseMessage]], ) -> Iterator[T]: for chunk in input: if isinstance(chunk, BaseMessage): @@ -42,7 +42,7 @@ class BaseTransformOutputParser(BaseOutputParser[T]): async def _atransform( self, - input: AsyncIterator[Union[str, BaseMessage]], # noqa: A002 + input: AsyncIterator[Union[str, BaseMessage]], ) -> AsyncIterator[T]: async for chunk in input: if isinstance(chunk, BaseMessage): diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index 5061cd3c5d0..e205afd4d8c 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -731,7 +731,7 @@ class Runnable(ABC, Generic[Input, Output]): @abstractmethod def invoke( self, - input: Input, # noqa: A002 + input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any, ) -> Output: @@ -751,7 +751,7 @@ class Runnable(ABC, Generic[Input, Output]): async def ainvoke( self, - input: Input, # noqa: A002 + input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any, ) -> Output: @@ -999,7 +999,7 @@ class Runnable(ABC, Generic[Input, Output]): def stream( self, - input: Input, # noqa: A002 + input: Input, config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], ) -> Iterator[Output]: @@ -1019,7 +1019,7 @@ class Runnable(ABC, Generic[Input, Output]): async def astream( self, - input: Input, # noqa: A002 + input: Input, config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], ) -> AsyncIterator[Output]: @@ -1073,7 +1073,7 @@ class Runnable(ABC, Generic[Input, Output]): async def astream_log( self, - input: Any, # noqa: A002 + input: Any, config: Optional[RunnableConfig] = None, *, diff: bool = True, @@ -1144,7 +1144,7 @@ class Runnable(ABC, Generic[Input, Output]): async def astream_events( self, - input: Any, # noqa: A002 + input: Any, config: Optional[RunnableConfig] = None, *, version: Literal["v1", "v2"] = "v2", @@ -1410,7 +1410,7 @@ class Runnable(ABC, Generic[Input, Output]): def transform( self, - input: Iterator[Input], # noqa: A002 + input: Iterator[Input], config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], ) -> Iterator[Output]: @@ -1452,7 +1452,7 @@ class Runnable(ABC, Generic[Input, Output]): async def atransform( self, - input: AsyncIterator[Input], # noqa: A002 + input: AsyncIterator[Input], config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], ) -> AsyncIterator[Output]: diff --git a/libs/core/langchain_core/runnables/config.py b/libs/core/langchain_core/runnables/config.py index 4ac7bda7b46..cc36622b914 100644 --- a/libs/core/langchain_core/runnables/config.py +++ b/libs/core/langchain_core/runnables/config.py @@ -402,7 +402,7 @@ def call_func_with_variable_args( Callable[[Input, CallbackManagerForChainRun], Output], Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output], ], - 
input: Input, # noqa: A002 + input: Input, config: RunnableConfig, run_manager: Optional[CallbackManagerForChainRun] = None, **kwargs: Any, @@ -439,7 +439,7 @@ def acall_func_with_variable_args( Awaitable[Output], ], ], - input: Input, # noqa: A002 + input: Input, config: RunnableConfig, run_manager: Optional[AsyncCallbackManagerForChainRun] = None, **kwargs: Any, diff --git a/libs/core/langchain_core/runnables/graph.py b/libs/core/langchain_core/runnables/graph.py index 3e22494bad7..20a841d51a8 100644 --- a/libs/core/langchain_core/runnables/graph.py +++ b/libs/core/langchain_core/runnables/graph.py @@ -114,7 +114,7 @@ class Node(NamedTuple): def copy( self, *, - id: Optional[str] = None, # noqa: A002 + id: Optional[str] = None, name: Optional[str] = None, ) -> Node: """Return a copy of the node with optional new id and name. @@ -187,7 +187,7 @@ class MermaidDrawMethod(Enum): def node_data_str( - id: str, # noqa: A002 + id: str, data: Union[type[BaseModel], RunnableType, None], ) -> str: """Convert the data of a node to a string. @@ -328,7 +328,7 @@ class Graph: def add_node( self, data: Union[type[BaseModel], RunnableType, None], - id: Optional[str] = None, # noqa: A002 + id: Optional[str] = None, *, metadata: Optional[dict[str, Any]] = None, ) -> Node: diff --git a/libs/core/langchain_core/utils/function_calling.py b/libs/core/langchain_core/utils/function_calling.py index d7059fded47..b69e2c331fa 100644 --- a/libs/core/langchain_core/utils/function_calling.py +++ b/libs/core/langchain_core/utils/function_calling.py @@ -616,7 +616,7 @@ def convert_to_json_schema( @beta() def tool_example_to_messages( - input: str, # noqa: A002 + input: str, tool_calls: list[BaseModel], tool_outputs: Optional[list[str]] = None, *, diff --git a/libs/core/pyproject.toml b/libs/core/pyproject.toml index 9b1686659ef..66a4c5bc3b4 100644 --- a/libs/core/pyproject.toml +++ b/libs/core/pyproject.toml @@ -86,6 +86,7 @@ ignore = [ "FIX002", # Line contains TODO "ISC001", # Messes with the formatter "PERF203", # Rarely useful + "PLC0414", # Enable re-export "PLR09", # Too many something (arg, statements, etc) "RUF012", # Doesn't play well with Pydantic "TC001", # Doesn't play well with Pydantic @@ -105,6 +106,7 @@ unfixable = ["PLW1510",] flake8-annotations.allow-star-arg-any = true flake8-annotations.mypy-init-return = true +flake8-builtins.ignorelist = ["id", "input", "type"] flake8-type-checking.runtime-evaluated-base-classes = ["pydantic.BaseModel","langchain_core.load.serializable.Serializable","langchain_core.runnables.base.RunnableSerializable"] pep8-naming.classmethod-decorators = [ "classmethod", "langchain_core.utils.pydantic.pre_init", "pydantic.field_validator", "pydantic.v1.root_validator",] pydocstyle.convention = "google" diff --git a/libs/core/tests/unit_tests/fake/test_fake_chat_model.py b/libs/core/tests/unit_tests/fake/test_fake_chat_model.py index ce262797535..4d17f3b8cae 100644 --- a/libs/core/tests/unit_tests/fake/test_fake_chat_model.py +++ b/libs/core/tests/unit_tests/fake/test_fake_chat_model.py @@ -15,6 +15,8 @@ from langchain_core.language_models import ( ParrotFakeChatModel, ) from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage, HumanMessage +from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1 +from langchain_core.messages.v1 import MessageV1 from langchain_core.outputs import ChatGenerationChunk, GenerationChunk from tests.unit_tests.stubs import ( _any_id_ai_message, @@ -157,13 +159,13 @@ async def test_callback_handlers() -> None: """Verify 
that model is implemented correctly with handlers working.""" class MyCustomAsyncHandler(AsyncCallbackHandler): - def __init__(self, store: list[str]) -> None: + def __init__(self, store: list[Union[str, AIMessageChunkV1]]) -> None: self.store = store async def on_chat_model_start( self, serialized: dict[str, Any], - messages: list[list[BaseMessage]], + messages: Union[list[list[BaseMessage]], list[list[MessageV1]]], *, run_id: UUID, parent_run_id: Optional[UUID] = None, @@ -178,9 +180,11 @@ async def test_callback_handlers() -> None: @override async def on_llm_new_token( self, - token: str, + token: Union[str, AIMessageChunkV1], *, - chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, + chunk: Optional[ + Union[GenerationChunk, ChatGenerationChunk, AIMessageChunkV1] + ] = None, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[list[str]] = None, @@ -194,7 +198,7 @@ async def test_callback_handlers() -> None: ] ) model = GenericFakeChatModel(messages=infinite_cycle) - tokens: list[str] = [] + tokens: list[Union[str, AIMessageChunkV1]] = [] # New model results = [ chunk diff --git a/libs/core/tests/unit_tests/language_models/chat_models/test_cache.py b/libs/core/tests/unit_tests/language_models/chat_models/test_cache.py index 1cceb0a146b..d611fc36c4e 100644 --- a/libs/core/tests/unit_tests/language_models/chat_models/test_cache.py +++ b/libs/core/tests/unit_tests/language_models/chat_models/test_cache.py @@ -300,8 +300,9 @@ def test_llm_representation_for_serializable() -> None: assert chat._get_llm_string() == ( '{"id": ["tests", "unit_tests", "language_models", "chat_models", ' '"test_cache", "CustomChat"], "kwargs": {"messages": {"id": ' - '["builtins", "list_iterator"], "lc": 1, "type": "not_implemented"}}, "lc": ' - '1, "name": "CustomChat", "type": "constructor"}---[(\'stop\', None)]' + '["builtins", "list_iterator"], "lc": 1, "type": "not_implemented"}, ' + '"output_version": "v0"}, "lc": 1, "name": "CustomChat", "type": ' + "\"constructor\"}---[('stop', None)]" ) diff --git a/libs/core/tests/unit_tests/messages/test_imports.py b/libs/core/tests/unit_tests/messages/test_imports.py index ff9fbf92fc7..9fda5493244 100644 --- a/libs/core/tests/unit_tests/messages/test_imports.py +++ b/libs/core/tests/unit_tests/messages/test_imports.py @@ -5,22 +5,40 @@ EXPECTED_ALL = [ "_message_from_dict", "AIMessage", "AIMessageChunk", + "Annotation", "AnyMessage", + "AudioContentBlock", "BaseMessage", "BaseMessageChunk", + "ContentBlock", "ChatMessage", "ChatMessageChunk", + "Citation", + "CodeInterpreterCall", + "CodeInterpreterOutput", + "CodeInterpreterResult", + "DataContentBlock", + "FileContentBlock", "FunctionMessage", "FunctionMessageChunk", "HumanMessage", "HumanMessageChunk", + "ImageContentBlock", "InvalidToolCall", + "NonStandardAnnotation", + "NonStandardContentBlock", + "PlainTextContentBlock", "SystemMessage", "SystemMessageChunk", + "TextContentBlock", "ToolCall", "ToolCallChunk", "ToolMessage", "ToolMessageChunk", + "VideoContentBlock", + "WebSearchCall", + "WebSearchResult", + "ReasoningContentBlock", "RemoveMessage", "convert_to_messages", "get_buffer_string", diff --git a/libs/core/tests/unit_tests/messages/test_utils.py b/libs/core/tests/unit_tests/messages/test_utils.py index bedd518589e..f9f1c9c9ff0 100644 --- a/libs/core/tests/unit_tests/messages/test_utils.py +++ b/libs/core/tests/unit_tests/messages/test_utils.py @@ -1221,15 +1221,30 @@ def test_convert_to_openai_messages_multimodal() -> None: {"type": "text", "text": "Text message"}, { 
"type": "image", - "source_type": "url", "url": "https://example.com/test.png", }, + { + "type": "image", + "source_type": "url", # backward compatibility + "url": "https://example.com/test.png", + }, + { + "type": "image", + "base64": "", + "mime_type": "image/png", + }, { "type": "image", "source_type": "base64", "data": "", "mime_type": "image/png", }, + { + "type": "file", + "base64": "", + "mime_type": "application/pdf", + "filename": "test.pdf", + }, { "type": "file", "source_type": "base64", @@ -1244,11 +1259,20 @@ def test_convert_to_openai_messages_multimodal() -> None: "file_data": "data:application/pdf;base64,", }, }, + { + "type": "file", + "file_id": "file-abc123", + }, { "type": "file", "source_type": "id", "id": "file-abc123", }, + { + "type": "audio", + "base64": "", + "mime_type": "audio/wav", + }, { "type": "audio", "source_type": "base64", @@ -1268,7 +1292,7 @@ def test_convert_to_openai_messages_multimodal() -> None: result = convert_to_openai_messages(messages, text_format="block") assert len(result) == 1 message = result[0] - assert len(message["content"]) == 8 + assert len(message["content"]) == 13 # Test adding filename messages = [ @@ -1276,8 +1300,7 @@ def test_convert_to_openai_messages_multimodal() -> None: content=[ { "type": "file", - "source_type": "base64", - "data": "", + "base64": "", "mime_type": "application/pdf", }, ] diff --git a/libs/core/tests/unit_tests/prompts/__snapshots__/test_chat.ambr b/libs/core/tests/unit_tests/prompts/__snapshots__/test_chat.ambr index 7c07416fe5d..7d844b6beca 100644 --- a/libs/core/tests/unit_tests/prompts/__snapshots__/test_chat.ambr +++ b/libs/core/tests/unit_tests/prompts/__snapshots__/test_chat.ambr @@ -785,6 +785,7 @@ 'args', 'id', 'error', + 'type', ]), 'title': 'InvalidToolCall', 'type': 'object', @@ -1015,6 +1016,10 @@ ]), 'title': 'Id', }), + 'index': dict({ + 'title': 'Index', + 'type': 'integer', + }), 'name': dict({ 'title': 'Name', 'type': 'string', @@ -1029,6 +1034,7 @@ 'name', 'args', 'id', + 'type', ]), 'title': 'ToolCall', 'type': 'object', @@ -2217,6 +2223,7 @@ 'args', 'id', 'error', + 'type', ]), 'title': 'InvalidToolCall', 'type': 'object', @@ -2447,6 +2454,10 @@ ]), 'title': 'Id', }), + 'index': dict({ + 'title': 'Index', + 'type': 'integer', + }), 'name': dict({ 'title': 'Name', 'type': 'string', @@ -2461,6 +2472,7 @@ 'name', 'args', 'id', + 'type', ]), 'title': 'ToolCall', 'type': 'object', diff --git a/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr b/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr index a788c425fce..7b8f1b570a5 100644 --- a/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr +++ b/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr @@ -1188,6 +1188,7 @@ 'args', 'id', 'error', + 'type', ]), 'title': 'InvalidToolCall', 'type': 'object', @@ -1418,6 +1419,10 @@ ]), 'title': 'Id', }), + 'index': dict({ + 'title': 'Index', + 'type': 'integer', + }), 'name': dict({ 'title': 'Name', 'type': 'string', @@ -1432,6 +1437,7 @@ 'name', 'args', 'id', + 'type', ]), 'title': 'ToolCall', 'type': 'object', diff --git a/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr b/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr index 079e4909061..cc9f7aab15d 100644 --- a/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr +++ b/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr @@ -2732,6 +2732,7 @@ 'args', 'id', 'error', + 'type', ]), 'title': 'InvalidToolCall', 
'type': 'object', @@ -2960,6 +2961,10 @@ ]), 'title': 'Id', }), + 'index': dict({ + 'title': 'Index', + 'type': 'integer', + }), 'name': dict({ 'title': 'Name', 'type': 'string', @@ -2973,6 +2978,7 @@ 'name', 'args', 'id', + 'type', ]), 'title': 'ToolCall', 'type': 'object', @@ -4208,6 +4214,7 @@ 'args', 'id', 'error', + 'type', ]), 'title': 'InvalidToolCall', 'type': 'object', @@ -4455,6 +4462,10 @@ ]), 'title': 'Id', }), + 'index': dict({ + 'title': 'Index', + 'type': 'integer', + }), 'name': dict({ 'title': 'Name', 'type': 'string', @@ -4468,6 +4479,7 @@ 'name', 'args', 'id', + 'type', ]), 'title': 'ToolCall', 'type': 'object', @@ -5715,6 +5727,7 @@ 'args', 'id', 'error', + 'type', ]), 'title': 'InvalidToolCall', 'type': 'object', @@ -5962,6 +5975,10 @@ ]), 'title': 'Id', }), + 'index': dict({ + 'title': 'Index', + 'type': 'integer', + }), 'name': dict({ 'title': 'Name', 'type': 'string', @@ -5975,6 +5992,7 @@ 'name', 'args', 'id', + 'type', ]), 'title': 'ToolCall', 'type': 'object', @@ -7097,6 +7115,7 @@ 'args', 'id', 'error', + 'type', ]), 'title': 'InvalidToolCall', 'type': 'object', @@ -7325,6 +7344,10 @@ ]), 'title': 'Id', }), + 'index': dict({ + 'title': 'Index', + 'type': 'integer', + }), 'name': dict({ 'title': 'Name', 'type': 'string', @@ -7338,6 +7361,7 @@ 'name', 'args', 'id', + 'type', ]), 'title': 'ToolCall', 'type': 'object', @@ -8615,6 +8639,7 @@ 'args', 'id', 'error', + 'type', ]), 'title': 'InvalidToolCall', 'type': 'object', @@ -8862,6 +8887,10 @@ ]), 'title': 'Id', }), + 'index': dict({ + 'title': 'Index', + 'type': 'integer', + }), 'name': dict({ 'title': 'Name', 'type': 'string', @@ -8875,6 +8904,7 @@ 'name', 'args', 'id', + 'type', ]), 'title': 'ToolCall', 'type': 'object', @@ -10042,6 +10072,7 @@ 'args', 'id', 'error', + 'type', ]), 'title': 'InvalidToolCall', 'type': 'object', @@ -10270,6 +10301,10 @@ ]), 'title': 'Id', }), + 'index': dict({ + 'title': 'Index', + 'type': 'integer', + }), 'name': dict({ 'title': 'Name', 'type': 'string', @@ -10283,6 +10318,7 @@ 'name', 'args', 'id', + 'type', ]), 'title': 'ToolCall', 'type': 'object', @@ -11468,6 +11504,7 @@ 'args', 'id', 'error', + 'type', ]), 'title': 'InvalidToolCall', 'type': 'object', @@ -11726,6 +11763,10 @@ ]), 'title': 'Id', }), + 'index': dict({ + 'title': 'Index', + 'type': 'integer', + }), 'name': dict({ 'title': 'Name', 'type': 'string', @@ -11739,6 +11780,7 @@ 'name', 'args', 'id', + 'type', ]), 'title': 'ToolCall', 'type': 'object', @@ -12936,6 +12978,7 @@ 'args', 'id', 'error', + 'type', ]), 'title': 'InvalidToolCall', 'type': 'object', @@ -13183,6 +13226,10 @@ ]), 'title': 'Id', }), + 'index': dict({ + 'title': 'Index', + 'type': 'integer', + }), 'name': dict({ 'title': 'Name', 'type': 'string', @@ -13196,6 +13243,7 @@ 'name', 'args', 'id', + 'type', ]), 'title': 'ToolCall', 'type': 'object', diff --git a/libs/core/tests/unit_tests/test_messages.py b/libs/core/tests/unit_tests/test_messages.py index 58822433d7a..75bbb804c2e 100644 --- a/libs/core/tests/unit_tests/test_messages.py +++ b/libs/core/tests/unit_tests/test_messages.py @@ -30,9 +30,11 @@ from langchain_core.messages import ( messages_from_dict, messages_to_dict, ) +from langchain_core.messages.content_blocks import KNOWN_BLOCK_TYPES from langchain_core.messages.tool import invalid_tool_call as create_invalid_tool_call from langchain_core.messages.tool import tool_call as create_tool_call from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk +from langchain_core.messages.v1 import AIMessageChunk as 
AIMessageChunkV1 from langchain_core.utils._merge import merge_lists @@ -195,6 +197,116 @@ def test_message_chunks() -> None: assert (meaningful_id + default_id).id == "msg_def456" +def test_message_chunks_v2() -> None: + left = AIMessageChunkV1("foo ", id="abc") + right = AIMessageChunkV1("bar") + expected = AIMessageChunkV1("foo bar", id="abc") + assert left + right == expected + + # Test tool calls + one = AIMessageChunkV1( + [], + tool_call_chunks=[ + create_tool_call_chunk(name="tool1", args="", id="1", index=0) + ], + ) + two = AIMessageChunkV1( + [], + tool_call_chunks=[ + create_tool_call_chunk(name=None, args='{"arg1": "val', id=None, index=0) + ], + ) + three = AIMessageChunkV1( + [], + tool_call_chunks=[ + create_tool_call_chunk(name=None, args='ue}"', id=None, index=0) + ], + ) + expected = AIMessageChunkV1( + [], + tool_call_chunks=[ + create_tool_call_chunk( + name="tool1", args='{"arg1": "value}"', id="1", index=0 + ) + ], + ) + assert one + two + three == expected + + assert ( + AIMessageChunkV1( + [], + tool_call_chunks=[ + create_tool_call_chunk(name="tool1", args="", id="1", index=0) + ], + ) + + AIMessageChunkV1( + [], + tool_call_chunks=[ + create_tool_call_chunk(name="tool1", args="a", id=None, index=1) + ], + ) + # Don't merge if `index` field does not match. + ) == AIMessageChunkV1( + [], + tool_call_chunks=[ + create_tool_call_chunk(name="tool1", args="", id="1", index=0), + create_tool_call_chunk(name="tool1", args="a", id=None, index=1), + ], + ) + + ai_msg_chunk = AIMessageChunkV1([]) + tool_calls_msg_chunk = AIMessageChunkV1( + [], + tool_call_chunks=[ + create_tool_call_chunk(name="tool1", args="a", id=None, index=1) + ], + ) + assert ai_msg_chunk + tool_calls_msg_chunk == tool_calls_msg_chunk + assert tool_calls_msg_chunk + ai_msg_chunk == tool_calls_msg_chunk + + ai_msg_chunk = AIMessageChunkV1( + [], + tool_call_chunks=[ + create_tool_call_chunk(name="tool1", args="", id="1", index=0) + ], + ) + assert ai_msg_chunk.tool_calls == [create_tool_call(name="tool1", args={}, id="1")] + + # Test token usage + left = AIMessageChunkV1( + [], + usage_metadata={"input_tokens": 1, "output_tokens": 2, "total_tokens": 3}, + ) + right = AIMessageChunkV1( + [], + usage_metadata={"input_tokens": 4, "output_tokens": 5, "total_tokens": 9}, + ) + assert left + right == AIMessageChunkV1( + content=[], + usage_metadata={"input_tokens": 5, "output_tokens": 7, "total_tokens": 12}, + ) + assert AIMessageChunkV1(content=[]) + left == left + assert right + AIMessageChunkV1(content=[]) == right + + # Test ID order of precedence + null_id = AIMessageChunkV1(content=[], id=None) + default_id = AIMessageChunkV1( + content=[], id="run-abc123" + ) # LangChain-assigned run ID + meaningful_id = AIMessageChunkV1( + content=[], id="msg_def456" + ) # provider-assigned ID + + assert (null_id + default_id).id == "run-abc123" + assert (default_id + null_id).id == "run-abc123" + + assert (null_id + meaningful_id).id == "msg_def456" + assert (meaningful_id + null_id).id == "msg_def456" + + assert (default_id + meaningful_id).id == "msg_def456" + assert (meaningful_id + default_id).id == "msg_def456" + + def test_chat_message_chunks() -> None: assert ChatMessageChunk(role="User", content="I am", id="ai4") + ChatMessageChunk( role="User", content=" indeed." 
@@ -1111,23 +1223,20 @@ def test_is_data_content_block() -> None: assert is_data_content_block( { "type": "image", - "source_type": "url", "url": "https://...", } ) assert is_data_content_block( { "type": "image", - "source_type": "base64", - "data": "", + "base64": "", "mime_type": "image/jpeg", } ) assert is_data_content_block( { "type": "image", - "source_type": "base64", - "data": "", + "base64": "", "mime_type": "image/jpeg", "cache_control": {"type": "ephemeral"}, } @@ -1135,13 +1244,17 @@ def test_is_data_content_block() -> None: assert is_data_content_block( { "type": "image", - "source_type": "base64", - "data": "", + "base64": "", "mime_type": "image/jpeg", "metadata": {"cache_control": {"type": "ephemeral"}}, } ) - + assert is_data_content_block( + { + "type": "image", + "source_type": "base64", # backward compatibility + } + ) assert not is_data_content_block( { "type": "text", @@ -1154,12 +1267,6 @@ def test_is_data_content_block() -> None: "image_url": {"url": "https://..."}, } ) - assert not is_data_content_block( - { - "type": "image", - "source_type": "base64", - } - ) assert not is_data_content_block( { "type": "image", @@ -1169,31 +1276,65 @@ def test_is_data_content_block() -> None: def test_convert_to_openai_image_block() -> None: - input_block = { - "type": "image", - "source_type": "url", - "url": "https://...", - "cache_control": {"type": "ephemeral"}, - } - expected = { - "type": "image_url", - "image_url": {"url": "https://..."}, - } - result = convert_to_openai_image_block(input_block) - assert result == expected - - input_block = { - "type": "image", - "source_type": "base64", - "data": "", - "mime_type": "image/jpeg", - "cache_control": {"type": "ephemeral"}, - } - expected = { - "type": "image_url", - "image_url": { - "url": "data:image/jpeg;base64,", + for input_block in [ + { + "type": "image", + "url": "https://...", + "cache_control": {"type": "ephemeral"}, }, - } - result = convert_to_openai_image_block(input_block) - assert result == expected + { + "type": "image", + "source_type": "url", + "url": "https://...", + "cache_control": {"type": "ephemeral"}, + }, + ]: + expected = { + "type": "image_url", + "image_url": {"url": "https://..."}, + } + result = convert_to_openai_image_block(input_block) + assert result == expected + + for input_block in [ + { + "type": "image", + "base64": "", + "mime_type": "image/jpeg", + "cache_control": {"type": "ephemeral"}, + }, + { + "type": "image", + "source_type": "base64", + "data": "", + "mime_type": "image/jpeg", + "cache_control": {"type": "ephemeral"}, + }, + ]: + expected = { + "type": "image_url", + "image_url": { + "url": "data:image/jpeg;base64,", + }, + } + result = convert_to_openai_image_block(input_block) + assert result == expected + + +def test_known_block_types() -> None: + assert { + "text", + "text-plain", + "tool_call", + "reasoning", + "non_standard", + "image", + "audio", + "file", + "video", + "code_interpreter_call", + "code_interpreter_output", + "code_interpreter_result", + "web_search_call", + "web_search_result", + } == KNOWN_BLOCK_TYPES diff --git a/libs/langchain/langchain/agents/output_parsers/tools.py b/libs/langchain/langchain/agents/output_parsers/tools.py index 10ebd3c3dfc..62759707b38 100644 --- a/libs/langchain/langchain/agents/output_parsers/tools.py +++ b/libs/langchain/langchain/agents/output_parsers/tools.py @@ -47,7 +47,12 @@ def parse_ai_message_to_tool_action( try: args = json.loads(function["arguments"] or "{}") tool_calls.append( - ToolCall(name=function_name, args=args, 
id=tool_call["id"]), + ToolCall( + name=function_name, + args=args, + id=tool_call["id"], + type="tool_call", + ) ) except JSONDecodeError as e: msg = ( diff --git a/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_openai_tools.py b/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_openai_tools.py index 04d33d12f4e..089f31f47f7 100644 --- a/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_openai_tools.py +++ b/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_openai_tools.py @@ -53,7 +53,12 @@ def test_calls_convert_agent_action_to_messages() -> None: message4 = AIMessage( content="", tool_calls=[ - ToolCall(name="exponentiate", args={"a": 3, "b": 5}, id="call_abc02468"), + ToolCall( + name="exponentiate", + args={"a": 3, "b": 5}, + id="call_abc02468", + type="tool_call", + ), ], ) actions4 = parse_ai_message_to_openai_tool_action(message4) diff --git a/libs/langchain/tests/unit_tests/agents/test_agent.py b/libs/langchain/tests/unit_tests/agents/test_agent.py index ea6c455db69..d1bc7e6c2a0 100644 --- a/libs/langchain/tests/unit_tests/agents/test_agent.py +++ b/libs/langchain/tests/unit_tests/agents/test_agent.py @@ -1008,7 +1008,7 @@ def _make_tools_invocation(name_to_arguments: dict[str, dict[str, Any]]) -> AIMe for idx, (name, arguments) in enumerate(name_to_arguments.items()) ] tool_calls = [ - ToolCall(name=name, args=args, id=str(idx)) + ToolCall(name=name, args=args, id=str(idx), type="tool_call") for idx, (name, args) in enumerate(name_to_arguments.items()) ] return AIMessage( diff --git a/libs/langchain/tests/unit_tests/chat_models/test_base.py b/libs/langchain/tests/unit_tests/chat_models/test_base.py index 18279c95e67..9b6a6b44321 100644 --- a/libs/langchain/tests/unit_tests/chat_models/test_base.py +++ b/libs/langchain/tests/unit_tests/chat_models/test_base.py @@ -205,6 +205,7 @@ def test_configurable_with_default() -> None: "name": None, "bound": { "name": None, + "output_version": "v0", "disable_streaming": False, "model": "claude-3-sonnet-20240229", "mcp_servers": None, diff --git a/libs/partners/openai/langchain_openai/chat_models/_compat.py b/libs/partners/openai/langchain_openai/chat_models/_compat.py index 25ff3eb607c..a35a88884dd 100644 --- a/libs/partners/openai/langchain_openai/chat_models/_compat.py +++ b/libs/partners/openai/langchain_openai/chat_models/_compat.py @@ -1,7 +1,10 @@ """ -This module converts between AIMessage output formats for the Responses API. +This module converts between AIMessage output formats, which are governed by the +``output_version`` attribute on ChatOpenAI. Supported values are ``"v0"``, +``"responses/v1"``, and ``"v1"``. -ChatOpenAI v0.3 stores reasoning and tool outputs in AIMessage.additional_kwargs: +``"v0"`` corresponds to the format as of ChatOpenAI v0.3. For the Responses API, it +stores reasoning and tool outputs in AIMessage.additional_kwargs: .. code-block:: python @@ -28,8 +31,9 @@ ChatOpenAI v0.3 stores reasoning and tool outputs in AIMessage.additional_kwargs id="msg_123", ) -To retain information about response item sequencing (and to accommodate multiple -reasoning items), ChatOpenAI now stores these items in the content sequence: +``"responses/v1"`` is only applicable to the Responses API. It retains information +about response item sequencing and accommodates multiple reasoning items by +representing these items in the content sequence: .. 
code-block:: python @@ -56,19 +60,23 @@ reasoning items), ChatOpenAI now stores these items in the content sequence: There are other, small improvements as well-- e.g., we store message IDs on text content blocks, rather than on the AIMessage.id, which now stores the response ID. +``"v1"`` represents LangChain's cross-provider standard format. + For backwards compatibility, this module provides functions to convert between the old and new formats. The functions are used internally by ChatOpenAI. """ # noqa: E501 import json -from typing import Union +from collections.abc import Iterable, Iterator +from typing import Any, Literal, Union, cast -from langchain_core.messages import AIMessage +from langchain_core.messages import AIMessage, AIMessageChunk, is_data_content_block _FUNCTION_CALL_IDS_MAP_KEY = "__openai_function_call_ids__" +# v0.3 / Responses def _convert_to_v03_ai_message( message: AIMessage, has_reasoning: bool = False ) -> AIMessage: @@ -253,3 +261,423 @@ def _convert_from_v03_ai_message(message: AIMessage) -> AIMessage: }, deep=False, ) + + +# v1 / Chat Completions +def _convert_to_v1_from_chat_completions(message: AIMessage) -> AIMessage: + """Mutate a Chat Completions message to v1 format.""" + if isinstance(message.content, str): + if message.content: + message.content = [{"type": "text", "text": message.content}] + else: + message.content = [] + + for tool_call in message.tool_calls: + if id_ := tool_call.get("id"): + message.content.append({"type": "tool_call", "id": id_}) + + if "tool_calls" in message.additional_kwargs: + _ = message.additional_kwargs.pop("tool_calls") + + if "token_usage" in message.response_metadata: + _ = message.response_metadata.pop("token_usage") + + return message + + +def _convert_to_v1_from_chat_completions_chunk(chunk: AIMessageChunk) -> AIMessageChunk: + result = _convert_to_v1_from_chat_completions(cast(AIMessage, chunk)) + return cast(AIMessageChunk, result) + + +def _convert_from_v1_to_chat_completions(message: AIMessage) -> AIMessage: + """Convert a v1 message to the Chat Completions format.""" + if isinstance(message.content, list): + new_content: list = [] + for block in message.content: + if isinstance(block, dict): + block_type = block.get("type") + if block_type == "text": + # Strip annotations + new_content.append({"type": "text", "text": block["text"]}) + elif block_type in ("reasoning", "tool_call"): + pass + else: + new_content.append(block) + else: + new_content.append(block) + return message.model_copy(update={"content": new_content}) + + return message + + +# v1 / Responses +def _convert_annotation_to_v1(annotation: dict[str, Any]) -> dict[str, Any]: + annotation_type = annotation.get("type") + + if annotation_type == "url_citation": + url_citation = {} + for field in ("end_index", "start_index", "title"): + if field in annotation: + url_citation[field] = annotation[field] + url_citation["type"] = "url_citation" + url_citation["url"] = annotation["url"] + return url_citation + + elif annotation_type == "file_citation": + document_citation = {"type": "document_citation"} + if "filename" in annotation: + document_citation["title"] = annotation["filename"] + for field in ("file_id", "index"): # OpenAI-specific + if field in annotation: + document_citation[field] = annotation[field] + return document_citation + + # TODO: standardise container_file_citation? 
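+        # Anything else (e.g. container_file_citation, per the TODO above) falls
+        # through to the non-standard wrapper below.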
+ else: + non_standard_annotation = { + "type": "non_standard_annotation", + "value": annotation, + } + return non_standard_annotation + + +def _explode_reasoning(block: dict[str, Any]) -> Iterable[dict[str, Any]]: + if block.get("type") != "reasoning" or "summary" not in block: + yield block + return + + if not block["summary"]: + _ = block.pop("summary", None) + yield block + return + + # Common part for every exploded line, except 'summary' + common = {k: v for k, v in block.items() if k != "summary"} + + # Optional keys that must appear only in the first exploded item + first_only = { + k: common.pop(k) for k in ("encrypted_content", "status") if k in common + } + + for idx, part in enumerate(block["summary"]): + new_block = dict(common) + new_block["reasoning"] = part.get("text", "") + if idx == 0: + new_block.update(first_only) + yield new_block + + +def _convert_to_v1_from_responses(message: AIMessage) -> AIMessage: + """Mutate a Responses message to v1 format.""" + if not isinstance(message.content, list): + return message + + def _iter_blocks() -> Iterable[dict[str, Any]]: + for block in message.content: + if not isinstance(block, dict): + continue + block_type = block.get("type") + + if block_type == "text": + if "annotations" in block: + block["annotations"] = [ + _convert_annotation_to_v1(a) for a in block["annotations"] + ] + yield block + + elif block_type == "reasoning": + yield from _explode_reasoning(block) + + elif block_type == "image_generation_call" and ( + result := block.get("result") + ): + new_block = {"type": "image", "base64": result} + if output_format := block.get("output_format"): + new_block["mime_type"] = f"image/{output_format}" + for extra_key in ( + "id", + "index", + "status", + "background", + "output_format", + "quality", + "revised_prompt", + "size", + ): + if extra_key in block: + new_block[extra_key] = block[extra_key] + yield new_block + + elif block_type == "function_call": + new_block = {"type": "tool_call", "id": block.get("call_id", "")} + if "id" in block: + new_block["item_id"] = block["id"] + for extra_key in ("arguments", "name", "index"): + if extra_key in block: + new_block[extra_key] = block[extra_key] + yield new_block + + elif block_type == "web_search_call": + web_search_call = {"type": "web_search_call", "id": block["id"]} + if "index" in block: + web_search_call["index"] = block["index"] + if ( + "action" in block + and isinstance(block["action"], dict) + and block["action"].get("type") == "search" + and "query" in block["action"] + ): + web_search_call["query"] = block["action"]["query"] + for key in block: + if key not in ("type", "id"): + web_search_call[key] = block[key] + + web_search_result = {"type": "web_search_result", "id": block["id"]} + if "index" in block: + web_search_result["index"] = block["index"] + 1 + yield web_search_call + yield web_search_result + + elif block_type == "code_interpreter_call": + code_interpreter_call = { + "type": "code_interpreter_call", + "id": block["id"], + } + if "code" in block: + code_interpreter_call["code"] = block["code"] + if "container_id" in block: + code_interpreter_call["container_id"] = block["container_id"] + if "index" in block: + code_interpreter_call["index"] = block["index"] + + code_interpreter_result = { + "type": "code_interpreter_result", + "id": block["id"], + } + if "outputs" in block: + code_interpreter_result["outputs"] = block["outputs"] + for output in block["outputs"]: + if ( + isinstance(output, dict) + and (output_type := output.get("type")) + and output_type == 
"logs" + ): + if "output" not in code_interpreter_result: + code_interpreter_result["output"] = [] + code_interpreter_result["output"].append( + { + "type": "code_interpreter_output", + "stdout": output.get("logs", ""), + } + ) + + if "status" in block: + code_interpreter_result["status"] = block["status"] + if "index" in block: + code_interpreter_result["index"] = block["index"] + 1 + + yield code_interpreter_call + yield code_interpreter_result + + else: + new_block = {"type": "non_standard", "value": block} + if "index" in new_block["value"]: + new_block["index"] = new_block["value"].pop("index") + yield new_block + + # Replace the list with the fully converted one + message.content = list(_iter_blocks()) + + return message + + +def _convert_annotation_from_v1(annotation: dict[str, Any]) -> dict[str, Any]: + annotation_type = annotation.get("type") + + if annotation_type == "document_citation": + new_ann: dict[str, Any] = {"type": "file_citation"} + + if "title" in annotation: + new_ann["filename"] = annotation["title"] + + for fld in ("file_id", "index"): + if fld in annotation: + new_ann[fld] = annotation[fld] + + return new_ann + + elif annotation_type == "non_standard_annotation": + return annotation["value"] + + else: + return dict(annotation) + + +def _implode_reasoning_blocks(blocks: list[dict[str, Any]]) -> Iterable[dict[str, Any]]: + i = 0 + n = len(blocks) + + while i < n: + block = blocks[i] + + # Skip non-reasoning blocks or blocks already in Responses format + if block.get("type") != "reasoning" or "summary" in block: + yield dict(block) + i += 1 + continue + elif "reasoning" not in block and "summary" not in block: + # {"type": "reasoning", "id": "rs_..."} + oai_format = {**block, "summary": []} + oai_format["type"] = oai_format.pop("type", "reasoning") + yield oai_format + i += 1 + continue + else: + pass + + summary: list[dict[str, str]] = [ + {"type": "summary_text", "text": block.get("reasoning", "")} + ] + # 'common' is every field except the exploded 'reasoning' + common = {k: v for k, v in block.items() if k != "reasoning"} + + i += 1 + while i < n: + next_ = blocks[i] + if next_.get("type") == "reasoning" and "reasoning" in next_: + summary.append( + {"type": "summary_text", "text": next_.get("reasoning", "")} + ) + i += 1 + else: + break + + merged = dict(common) + merged["summary"] = summary + merged["type"] = merged.pop("type", "reasoning") + yield merged + + +def _consolidate_calls( + items: Iterable[dict[str, Any]], + call_name: Literal["web_search_call", "code_interpreter_call"], + result_name: Literal["web_search_result", "code_interpreter_result"], +) -> Iterator[dict[str, Any]]: + """ + Generator that walks through *items* and, whenever it meets the pair + + {"type": "web_search_call", "id": X, ...} + {"type": "web_search_result", "id": X} + + merges them into + + {"id": X, + "action": …, + "status": …, + "type": "web_search_call"} + + keeping every other element untouched. 
+ """ + items = iter(items) # make sure we have a true iterator + for current in items: + # Only a call can start a pair worth collapsing + if current.get("type") != call_name: + yield current + continue + + try: + nxt = next(items) # look-ahead one element + except StopIteration: # no “result” – just yield the call back + yield current + break + + # If this really is the matching “result” – collapse + if nxt.get("type") == result_name and nxt.get("id") == current.get("id"): + if call_name == "web_search_call": + collapsed = { + "id": current["id"], + "status": current["status"], + "type": "web_search_call", + } + if "action" in current: + collapsed["action"] = current["action"] + + if call_name == "code_interpreter_call": + collapsed = {"id": current["id"]} + for key in ("code", "container_id"): + if key in current: + collapsed[key] = current[key] + + for key in ("outputs", "status"): + if key in nxt: + collapsed[key] = nxt[key] + collapsed["type"] = "code_interpreter_call" + + yield collapsed + + else: + # Not a matching pair – emit both, in original order + yield current + yield nxt + + +def _convert_from_v1_to_responses(message: AIMessage) -> AIMessage: + if not isinstance(message.content, list): + return message + + new_content: list = [] + for block in message.content: + if isinstance(block, dict): + block_type = block.get("type") + if block_type == "text" and "annotations" in block: + # Need a copy because we’re changing the annotations list + new_block = dict(block) + new_block["annotations"] = [ + _convert_annotation_from_v1(a) for a in block["annotations"] + ] + new_content.append(new_block) + elif block_type == "tool_call": + new_block = {"type": "function_call", "call_id": block["id"]} + if "item_id" in block: + new_block["id"] = block["item_id"] + if "name" in block and "arguments" in block: + new_block["name"] = block["name"] + new_block["arguments"] = block["arguments"] + else: + tool_call = next( + call for call in message.tool_calls if call["id"] == block["id"] + ) + if "name" not in block: + new_block["name"] = tool_call["name"] + if "arguments" not in block: + new_block["arguments"] = json.dumps(tool_call["args"]) + new_content.append(new_block) + elif ( + is_data_content_block(block) + and block["type"] == "image" + and "base64" in block + ): + new_block = {"type": "image_generation_call", "result": block["base64"]} + for extra_key in ("id", "status"): + if extra_key in block: + new_block[extra_key] = block[extra_key] + new_content.append(new_block) + elif block_type == "non_standard" and "value" in block: + new_content.append(block["value"]) + else: + new_content.append(block) + else: + new_content.append(block) + + new_content = list(_implode_reasoning_blocks(new_content)) + new_content = list( + _consolidate_calls(new_content, "web_search_call", "web_search_result") + ) + new_content = list( + _consolidate_calls( + new_content, "code_interpreter_call", "code_interpreter_result" + ) + ) + + return message.model_copy(update={"content": new_content}) diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index 1bc9b66d880..ca4118fb6a3 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -108,7 +108,12 @@ from langchain_openai.chat_models._client_utils import ( ) from langchain_openai.chat_models._compat import ( _convert_from_v03_ai_message, + _convert_from_v1_to_chat_completions, + 
_convert_from_v1_to_responses, _convert_to_v03_ai_message, + _convert_to_v1_from_chat_completions, + _convert_to_v1_from_chat_completions_chunk, + _convert_to_v1_from_responses, ) if TYPE_CHECKING: @@ -669,7 +674,7 @@ class BaseChatOpenAI(BaseChatModel): .. versionadded:: 0.3.9 """ - output_version: Literal["v0", "responses/v1"] = "v0" + output_version: str = "v0" """Version of AIMessage output format to use. This field is used to roll-out new output formats for chat model AIMessages @@ -680,6 +685,7 @@ class BaseChatOpenAI(BaseChatModel): - ``'v0'``: AIMessage format as of langchain-openai 0.3.x. - ``'responses/v1'``: Formats Responses API output items into AIMessage content blocks. + - ``"v1"``: v1 of LangChain cross-provider standard. Currently only impacts the Responses API. ``output_version='responses/v1'`` is recommended. @@ -869,6 +875,10 @@ class BaseChatOpenAI(BaseChatModel): message=default_chunk_class(content="", usage_metadata=usage_metadata), generation_info=base_generation_info, ) + if self.output_version == "v1": + generation_chunk.message = _convert_to_v1_from_chat_completions_chunk( + cast(AIMessageChunk, generation_chunk.message) + ) return generation_chunk choice = choices[0] @@ -896,6 +906,20 @@ class BaseChatOpenAI(BaseChatModel): if usage_metadata and isinstance(message_chunk, AIMessageChunk): message_chunk.usage_metadata = usage_metadata + if self.output_version == "v1": + message_chunk = cast(AIMessageChunk, message_chunk) + # Convert to v1 format + if isinstance(message_chunk.content, str): + message_chunk = _convert_to_v1_from_chat_completions_chunk( + message_chunk + ) + if message_chunk.content: + message_chunk.content[0]["index"] = 0 # type: ignore[index] + else: + message_chunk = _convert_to_v1_from_chat_completions_chunk( + message_chunk + ) + generation_chunk = ChatGenerationChunk( message=message_chunk, generation_info=generation_info or None ) @@ -1188,7 +1212,12 @@ class BaseChatOpenAI(BaseChatModel): else: payload = _construct_responses_api_payload(messages, payload) else: - payload["messages"] = [_convert_message_to_dict(m) for m in messages] + payload["messages"] = [ + _convert_message_to_dict(_convert_from_v1_to_chat_completions(m)) + if isinstance(m, AIMessage) + else _convert_message_to_dict(m) + for m in messages + ] return payload def _create_chat_result( @@ -1254,6 +1283,11 @@ class BaseChatOpenAI(BaseChatModel): if hasattr(message, "refusal"): generations[0].message.additional_kwargs["refusal"] = message.refusal + if self.output_version == "v1": + _ = llm_output.pop("token_usage", None) + generations[0].message = _convert_to_v1_from_chat_completions( + cast(AIMessage, generations[0].message) + ) return ChatResult(generations=generations, llm_output=llm_output) async def _astream( @@ -3577,6 +3611,7 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list: for lc_msg in messages: if isinstance(lc_msg, AIMessage): lc_msg = _convert_from_v03_ai_message(lc_msg) + lc_msg = _convert_from_v1_to_responses(lc_msg) msg = _convert_message_to_dict(lc_msg) # "name" parameter unsupported if "name" in msg: @@ -3720,7 +3755,7 @@ def _construct_lc_result_from_responses_api( response: Response, schema: Optional[type[_BM]] = None, metadata: Optional[dict] = None, - output_version: Literal["v0", "responses/v1"] = "v0", + output_version: str = "v0", ) -> ChatResult: """Construct ChatResponse from OpenAI Response API response.""" if response.error: @@ -3859,6 +3894,27 @@ def _construct_lc_result_from_responses_api( ) if output_version == "v0": 
message = _convert_to_v03_ai_message(message) + elif output_version == "v1": + message = _convert_to_v1_from_responses(message) + if response.tools and any( + tool.type == "image_generation" for tool in response.tools + ): + # Get mime_type from tool definition and add to image generations + # if missing (primarily for tracing purposes). + image_generation_call = next( + tool for tool in response.tools if tool.type == "image_generation" + ) + if image_generation_call.output_format: + mime_type = f"image/{image_generation_call.output_format}" + for content_block in message.content: + # OK to mutate output message + if ( + isinstance(content_block, dict) + and content_block.get("type") == "image" + and "base64" in content_block + and "mime_type" not in content_block + ): + content_block["mime_type"] = mime_type else: pass return ChatResult(generations=[ChatGeneration(message=message)]) @@ -3872,7 +3928,7 @@ def _convert_responses_chunk_to_generation_chunk( schema: Optional[type[_BM]] = None, metadata: Optional[dict] = None, has_reasoning: bool = False, - output_version: Literal["v0", "responses/v1"] = "v0", + output_version: str = "v0", ) -> tuple[int, int, int, Optional[ChatGenerationChunk]]: def _advance(output_idx: int, sub_idx: Optional[int] = None) -> None: """Advance indexes tracked during streaming. @@ -3938,9 +3994,29 @@ def _convert_responses_chunk_to_generation_chunk( annotation = chunk.annotation else: annotation = chunk.annotation.model_dump(exclude_none=True, mode="json") - content.append({"annotations": [annotation], "index": current_index}) + if output_version == "v1": + content.append( + { + "type": "text", + "text": "", + "annotations": [annotation], + "index": current_index, + } + ) + else: + content.append({"annotations": [annotation], "index": current_index}) elif chunk.type == "response.output_text.done": - content.append({"id": chunk.item_id, "index": current_index}) + if output_version == "v1": + content.append( + { + "type": "text", + "text": "", + "id": chunk.item_id, + "index": current_index, + } + ) + else: + content.append({"id": chunk.item_id, "index": current_index}) elif chunk.type == "response.created": id = chunk.response.id response_metadata["id"] = chunk.response.id # Backwards compatibility @@ -4016,21 +4092,34 @@ def _convert_responses_chunk_to_generation_chunk( content.append({"type": "refusal", "refusal": chunk.refusal}) elif chunk.type == "response.output_item.added" and chunk.item.type == "reasoning": _advance(chunk.output_index) + current_sub_index = 0 reasoning = chunk.item.model_dump(exclude_none=True, mode="json") reasoning["index"] = current_index content.append(reasoning) elif chunk.type == "response.reasoning_summary_part.added": - _advance(chunk.output_index) - content.append( - { - # langchain-core uses the `index` key to aggregate text blocks. + if output_version in ("v0", "responses/v1"): + _advance(chunk.output_index) + content.append( + { + # langchain-core uses the `index` key to aggregate text blocks. 
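+                        # The empty summary_text entry below acts as a placeholder;
+                        # later summary deltas are merged into it by index.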
+ "summary": [ + { + "index": chunk.summary_index, + "type": "summary_text", + "text": "", + } + ], + "index": current_index, + "type": "reasoning", + } + ) + else: + block: dict = {"type": "reasoning", "reasoning": ""} + if chunk.summary_index > 0: + _advance(chunk.output_index, chunk.summary_index) + block["id"] = chunk.item_id + block["index"] = current_index + content.append(block) elif chunk.type == "response.image_generation_call.partial_image": # Partial images are not supported yet. pass @@ -4065,6 +4154,15 @@ def _convert_responses_chunk_to_generation_chunk( AIMessageChunk, _convert_to_v03_ai_message(message, has_reasoning=has_reasoning), ) + elif output_version == "v1": + message = cast(AIMessageChunk, _convert_to_v1_from_responses(message)) + for content_block in message.content: + if ( + isinstance(content_block, dict) + and content_block.get("index", -1) > current_index + ): + # blocks were added for v1 + current_index = content_block["index"] else: pass return ( diff --git a/libs/partners/openai/tests/cassettes/test_function_calling.yaml.gz b/libs/partners/openai/tests/cassettes/test_function_calling.yaml.gz new file mode 100644 index 0000000000000000000000000000000000000000..197a8402cf6ebaf57e31e7d5600363e77502e1c7 GIT binary patch literal 7912 zcmV%O&o{ibqo`^HHrJ|_H%KY-n%&O|Q+0X54y;=Y99}mCSOZWK4 zFaPS$}$p*&v(R;3~I$HQG;7gu()2>Q)l;&ywXsSCj% zvV6O%hWJddrP|s#rZ0B8;huFIS)v`O!c?zq(+=@gJ8EyzFGbPr4p&a+P^HiD*2Jvz zRt`S1rN>|p^p2J>6rJeNn>Fh{(^ z4B~Y#!?Yn1VP-T!n#okJ3YPAMPSC5NqiXuxuoAVO^>H)Ny93P5S%8FaP-U*WpvE-W)%*e*N_yr3MeP*jc(k?A>a)-T2G5 zAenB0WqhNZvy9`J^KqliEZ#hXv)J4CcQo7_SF;<<{N?)AEVt{~+r!jDM>iTB9@d*_ z5Qnq%`eFL!`1dwlA7kgvw&C=qh|5iv>pL5+rp^LWd)Q&oZwtNr)Z1+Hc6@M_+q*^b zhNeL@`&{b}M-|PWRqh|Z{5cqn@#pp}_~X3MzYRXc!=tND*0A!OEgc9c>@p%)v9>iU zDcrCQ@$lBv#nh?|ZS8i#y`AFD!ZOsk6NP(xwzC5o6%4tn9qpj;C^whhY$W4F4c{Sr zGlpSm3~`D@u=<_T(T;dHB$S*Hl$_yD0X>*zykO52NE)UH3VFp$76vuyg^r=EoJd>H z)a2L;7_f!(UEcA|j8>spd123#r=NBxX@OBuL%`~M+OF-Lefj7ATYPes z%SG@vfc*_%_X1c*0QTa!Qah15PoC4MB~xR#a-)f7M_R=>HWIf}JI+Rp2*9>U0jxU& znzgHc9tk^oJ3F;l(McHyv5bu^4MUq=5V?D|DG2sa+n!RQ;|1)xE2 zjow!7dcf!gV{J|cv<8R?F&r~s1BBG}AZB<@%gC9>z%0d^5~Af`HX>5>K3k*j@$!SK zhnVgQX2sZD<^%wh22DBdxE4$!ZSFSd!9#ia6-)2ELbkwYb!k)@V0RYiq65z&*8T>% zzk%-m9niha)~jW_o)tkQ>ukMwBaJ@CKOo?K>B-IBiozk8g-7b%NX$nydXdz~Ow2?P z1}*Pkj(a}n(xaJxIL#eFork`fcn77mNR;tQqPAoy>rxqOfuhS51sK>2>Rw^L#Vc&7 z3RvUctAu#~hv6##7j|7gsLAB0wCa1_<@_b-*V#ycxov=!Ld^)THwZTgtQG*TXSs>b zqV{|$xriRMEt^#*i)qv6!>mgO=d>r_R%v(H8nZcAOhf|m-W9bvXzcTR{0b z8+_qAn!VWRzHHRakBzN|+v!MvykrpPdZj;pdgQ`@5Co^I3G~XWOP#HF-J4}Uso7L# zVh^lHeO2e4WVKAmn!kuxRjhv7SEAlxFs@7LSPPZ0C#cd(MIGnMGZ-rewWSEk+&rkd zd@$!ul8qpJMUL3@5=F13lWk^Nz}ohbtjiBpIc*-bk5)z7K&AqdeY5ww(d<1M5Rf`L zIC301n7EL2`8qmz(;TjKZAjsQWCI=a?atv3$hpPA z^&^77Yq4xum*nKxYId~9q)GEtj8(zLNNEHJ+C(dxIF@ch;gZSnKTKGM{=Ckf;iE*f zH0zQcHh$+LqY-0d0-q4q*PI$OdxB;7g6yT=y}saQS6F9VrX%8p1ww{xAP;& zCYYzX=!r!v40D1^Tbb;$-`BB$ULno}QZAta>~xbYdvk13@-_t?5HH5W^a*oIW<(@j1H6{&Oi-zjmYIh z#fbvIv6A1OO4^oL8@a$-xzfmz8)Ul%l-Yzp?^9M$UKWf7C>2_H+jJJ^%9X=LO?116 zT~C`I2Dwao7tF%|eSl|Ou}HGO*+SKO#KaQ#=9L|}Cxk`G-d@-~$#v-4r>{@c@q#&z z)8$D6;0xYEf*4{6p4ez!Use}ETeJNztebyaW6;eaxO(Sn7A=h_TDsPtHQikQVl;aL z6MyHc%|kPdudaPPew)VYSA-MSY8Fp}hvl0$jd?_SquuIl=^$po+#aI!t%f|9xz>cj5eI=RiEFid z!yIQ0qM6O{kL#h*8k4J>15;Pu@b}KutK7p-dXH!J zz^>taL8qgXmEQ_NGxVg_5C;*>f{{~1jc!x#41Qo5{wt6E%1(-RaV_ZNL&!Z$TAI01 zmDxZ5N3|*J)uRB9!{nHHBg&+T0xz7bZ9Gq}$wygwTVBk$73JhJDtRWS$UYL8vLb{W zLaAxu>LLWdD#GkDmT!aya-(o=X*Eh@k~azvx(GH))uv%fu{?~{QBBy=Maq5B3Ll9u zDWk`IpNAck3yrLugFxRsD^6qS9w}p)2cKCE7ZKmIQ(Ku5OJMr&&O6ez`|`V`JnqMt 
z-1qA8*D}00Kf9jj{Ys9vH1TVx-dkDTewku3EzwvVo|7s5$h(zzwxYZOHiy!%ekd!teh26M(`FL`W)m6qCpohC~t81OA=Qj9XIy>h(F z2TuyoUI6HPn)s_B)|%c5@&vJOExj_5pMhMC{5+-jjsv)YjS20>0K+7sgrSN7%PY*E z)*OmCY!D(dV`GVrwA=O%Y0$D3>fDOdjx|@uK>BeQnsaxqCcT9@Hy;OMYcXia2Q{fn z$~+9r@q%Y52d$xRj@7xDu#x1EANMglLwS*%gfmgcGSD9Y*m_~m!5sx9H;$>5P*{f* zq^>9xc9E5|fQ=9ex`z=)wPOjr0B~wPaig)O+Vcrw2x~~D#wy(rqHBw(p=~mpyhxf{ z?5!XhZxAX!PIb6^i~_HiPmN)cp2B zjd`6rhdsJx`BRyB$LbvVS!{ZtEM;#|&SDf9jC_e-s?}F4sS}gr`T<^ z#Tg!-_&L#w-}m1aqeo!lL@vTF(p!a>vxxd*B|i-8KI$&t-9>o$bi?f{Y$ZXsU+-By zzC0hSV%f)ES01>yE|H$x#%JO_0I)qb;g+bjGzc0cN>D90d05MeNaZ>dx(3dgFDu5S ziaK5yxIW+AD5bO%EpvlB3z{hu1rmj}#udq$mjY!`)q8Ualqo{C#_k+Tg|5L%^(QH} zo?&z&P|oe$hRkR)E_uRCOK4V)WuX-l|FZ1Pg?3I2z974Ep=DjFqp!>ETxh(@&&iGo zl>vv`bFAg=thD<&qfxJfDjZJwnJRw_!IoGk(TCmLfMkj2D#16!>Ur?IJgXy5L^GmY z1W#3xnF^tp`Vcm(h341_t&X`@=W^AIo~n38F!iUzl$4~Dfywh2Qb=e*C9FN;LFwxL zCeJ7D%I;ifyvuLM?p$a&HTab5DrMQ73yrJ#H)MA%v|@QqWfvSA8R7a@oJH=(Q;`2F zoW-f{x#29lyPXAJo;zIrz0RVf>a$Vr%Z|1zw47~x&(W5JmZSQ)>=vJLv}K`jaQmjC zHCVPabr!`ag}=&K6fF4oS2+tFq)>KxT>rBC%7V*R{PXfto7a_&kMfhccT4H`@_bVF zt|WC19PqL(ui@TT0goHx72LZ3yjZx)S>&U<>KQ3*_i3=>zUM-AwXJo7q$p$gU4`QY z`GxrAJ%!6@7T3G8T~_hNkz=0-i`Z}PgPFERyJ77HW9GSy{4GSaM%4vHz4ye6r-34q z{Yi*_UEh*h@#3{WG0zNsEnd79D86Ss7xALH) zC&j|kNxrM@FYA+1T$cDPIoD&^4{ubd2?s6v9l=($xip&R+`d;KEtf`1j^aBN(sF5(`|;0KNEscdo^RDp)Yen6BjRk= zpO_+a2F<@;-NH9-)cRNU;uY?Uw(JELdo|9xj%gPPJSc`>i5K0oy?6z`EPbtr16G7d z%f7eky~}_8oA1B-x7%R%e`5FC|I!Yz56%*a=G+SDEYV;rS5;k|tGZG(lc&l+HxcZ; zByjy;Z1tm6oGh`%YG8sf{IzF^#J|rH{p-&Xxs`HQ4<<2R>?y2QAX8N`1*PTf)lP4r zj{Q)b=gUKz_LNpHR9cNtkwBay<=wMHXGh&R3A{L9QKslrpXK}-BQ5WnH?t+UJj#&! z?^2iPTTf0e(dZD08reShxUqAjP?#E4h@}kOisFqW&?t%Goccj&)=AFE#c7E2%+P!1 z3AlW1Lyb$;^+=XvCd|4Xi*p9x){3C&+8m)$K!^0oMs{IcXOkr-;Z#!Zc;%cyiEDf3 z5p4OO6=N50`{(C2@)J)Nhg3=0J{eq`Jn`+Qx+6nAx1RQSxAkiWR8C(PYe%+2 z&lqyLaP>4qzGEM-P^#(3tx2L>oQ)}j<;0g6mY|*t#>)C?_c0h4;De29;SMH21C(C& zMzHtUsQ%bb6-boZ2QddX;9axf$crb>j4-ke*ut>^j07!hHg;Cf0Num{iSA2a82IT> zxih#Nb#eX$6sfmPm%|sZ9LXT~v2!bKr5`bZQk_b*goXE`T6QM>2OC*Gadv|yeovG= zg5^ASaTL4bP~OnSo&ID6?!DLynnL=>4( zUFd5p(7{1vRg>j^Q zg(VlN{e+EdryhSi;`c;J>Vn#u@UXK#AvG6C%{Qr<3yS$W0FBn4JtNC<>87xkVsK{pqr zngy2F;iFW7yJ6!%X|ij}LOKf_U6*sc&7LvEsN}9ol@QN~!g$(!@!XETa*mCPrE6~> zjZ3@Y;IkhdUbl0a1RDp;CY(?8_7?4UmouGEW$%IUf@(E3EM3$Tyh-Welrg(F(eyNjJA)pk}>Og}f6Leb4} z8;h=C8HvT&OYIHJ_K_(^sysVVh_7hEiH%<>#8)(-if;5d_*d-9o??7F|l-goszHArJD40 z7DH~P9m=!U$lh0`bO;?RS3b%g2W1Y&i9m1Aa^PIWWfDR^O3N=6`2cJ8Uwx_ZN zV)x&SPSxRlJXro5({65Eg$x@_Iwml!43|`Lvm_rYzx+wf?LhFJgVhUY)f*k6v!%xv ziuN=$D4thCC{~e*ZwstK-&+2B4jLoN$U=Tr6lnM$&cd`a@ie`BHQ;Jb`s&ydv0$Qlsud1n_V8_ zvqGu+-1%NeN>NAF7ORpTj%i1O`~(Q;?7d;8EG&P-R^P>H4>5zr z#L{!hI!*O)Ox_Nbu>#;jG01mAW}}8?M|7uvN~}(8gYH7h7%t@wyO+WzE0Vj~c1oKtMvdbB9!#u5 z6FXpZDpryTE)!^gNj4l?g!Bv#5ovUKRO4A=^~u6$CYY5<2XH*IbZC~W#^_P48?qTL zr-qUnSSm9n<0P{#pKV*vb3a8#FaUh)EDt|=!S3=)VFNlj*~(4dJ`YZek$=uUa$QW7 zP}n)e30Egx?AYPh(T_?$@$fq|t@R?{?+ry`eVs-lGCb^$z4RdMNNfP<5wwtJbv6J( z*nEDx3NqU`h6aAKfkC=EtK-s{o&R!Sy9Rg!bF>rtz8-Z^| z2=JiR{h-kU;`nGPxH|rZuF|TeK-TbKhzkVPjH)VgqO;mOM|D^SMrn>p;FWftS`E># z!fO81o@yPrLyANCc?eB?$VOgf!BQBbi-`Fikmu5fFEEh_gKWF2u4aUzQdO9Cm0a=i zqiUbmzxNa`-{Ssx{d-UGitSvt^zY~M0+WtHz(2sgmH0k)7u6Psjn$1QF-EOTcWlwG8rzb`n=>{zbDEE+zuF{>IwMzQ?CdpQJ z4vQcTzJ}6}-&UYpOBeoUsz0q^629Ao%Ny}gYX!K~(k30tHzlT%e&&3MMIg$M$+;2& zp-r2xm3%yTO{P3fMN8Oe&ZwPlaR)ZiBXISM-(&~6JoxPFaA$+40%IKf7M*@jU+R^j z$k307AgcUkoD0twpEp(TH{2I|%WPDOZMZPbF_}U|iKvs?&9=82aowbL9_IrrvCfy!K_3Zq&M=L)mSpu2BS$E!0 zsxc?>+*Y546=jv$dD<-;NgnCK@aRY?J7SvJHbq>XGOmNbRQ+s=9F}Y-VcF(sesj2U zX$N?|bQq2>?B-n}AVFm1ri!$s{S|bal6($ix*(;ZmBuo{<`uH(IozDj!dAV#28*84 
zckoRd+~5)UK}`+|N9Gl!bn+=Da;r=uq$1i)_rWMz=2`VN3tMln;Lk?5JhUyC+qWZJ z9@-Xh&$lC79@^#^>yIPc=TXRXk+KH}nY_Xr8q3u7(=69Jew-_d#rFylG7)(pKh)gQ za!FI-qmjrN^PKHPFnUsPBS5y=mmZJjJuXuKjRo7A4m!(d171akm-CcpZ)GA};wns6 zQ`qfxO@+JA+S(6E&c67hHbzX3O#&5N8)$e)VR{+Ut%L0NgF1Y;a&6Y(S@?u}=S8ulDU0%81@~byno?bF~a>3C{(NQ%w3Qe$$&dDyciRodrbNwNxdR}byx9C#-d&4fRD|bi|Lv8(ch~-XedXWR SSAKVM^8WyoQmuSQZvX)1?X4^T literal 0 HcmV?d00001 diff --git a/libs/partners/openai/tests/cassettes/test_parsed_pydantic_schema.yaml.gz b/libs/partners/openai/tests/cassettes/test_parsed_pydantic_schema.yaml.gz new file mode 100644 index 0000000000000000000000000000000000000000..13c0b8896decc64a0df2d5181abd8873694b343d GIT binary patch literal 4616 zcmV+j68G&NiwFR`0B~pm|Lt5`kE1%aexF~_^EA?yN&}>__ok(J$fZI^7i4lNZf^mS zKmsH?kWfVPMX#%eTI@Y_psHwV%gH{QYlF-|UT>egF3N zzuSH^+hyPW^S5vCfAegxwWBugcI4mnM(mq{xOXLOkN%cJ@8D)_&k5g_bN$7C?F=$o zcn{>&#&-=haQuO2Oj6(0#kHlUJ|1xB4u63#MXoTY$C^Uq;gKEdG~v0!j{Z5`|kw;}C+hZUwyD5jP%;P8MmT|4G)8U#+GeXt`nDBiqvrF?mBp>}j6 zaSt;^%jV`r8~^S!#@S47Msmc9NE~=T%joelV-Q$KhS?yq3^lcd=FyrAT-14HjN|I; z=n@ZRxDCw42))F!r12IH=t{u>mK4AB2M0&eGWa`iG9o)oTo}9v4ZZoyo$o-mcvX%w z-zNEgKYjbhZ@-OSx~y`qA=Q%YEW)cK&ge_?zfTJ7*I`bLV!W zWF93?!8}?f-W?5-Y&*Zw%-h7*R=JPoyQkR_2D;Md{V7gneiY2(_-VFtynCC*S>)W= zHke%%G23(#-`Q|GbJno5rvoVcy3pBPyLqz9o}A78ZjtPusUOZi*81kG!Z}!#`~BO$ zhm$G%x!?KU>y5q&MiE$nmpr4Fzyw0OF-S}!$1ArVq}ZXu?GXvkc)N+>dC9NvU~#fT7G1|NK(chR zlskhhVhZptW(vm9d{&RGRFMXr8os1zclkclRnH8S)at%u*2;WVrsliabpxgAWav3K zURVsN2%4`4CzDw zA3l2AO2!T=#~PG95DD~lm@tnt=zUi^!i&2%hw-!{~U)xNoUYn0ko1M@g)TDrDU zkHK2)dMl!F#t2Tdk6%IgDiR7nz3N$r$F|BTfIL07uN{ja^QP*-Hrlaej1iM%hg*RG zfx$HK!t`<_OSvEap#NRGnwUDBfyldk2;M!BvkZ?bWW{#z?67%2hK96@)!Iv3X0}9@ z_|NeDU^O6G3InWJo#&XD=8Gd?8;1VFU|m2<{b7*$lZl&f@%|A0N=~b|2KLi@ZE(%XY&dvo!U=0d+@@{>`WU zVj?pOZi{rbJ$B}lwZIB{M@I;d=Mhtf#pFZuIrz6-$Yzk52FMP)#|GKjf*2cuivGre zDQE_^O9y7|h5<}VT@de~#6nx-_E<4}1QAD&8730&?2lrmzNK4;3URDgvFbsPq4U6$ z22I)o4cRXi3zWl&tgQ*J#aWGy)dbYc^5FHvI*3DpnI1s=Xa<}UX#`ck7I~DUw&9i$Y z;l^CsZ-=J-ezZm#>Zmom1c9g=gO}LBiso1@)Wk__@ea_1_%!cIfBM9Jk+Tw@D~6MouUjYlZBLdZf9 zaVBUv{B}nvU3~Gp7?L*QA)PYD-b~NYh7E5JRFR3Vt3$Sa)U1X0I0>57=JK(eK^Z`7 zAql{6s#Jk7f1)0+G1Gu>29vjV(k+)|{&q&Ype}&gxcYkYh}tbCEVvc<+aNG-zlBZ6d zxtx1E-OZwS0g1^A$L4e{Bgvmkkj56;|LIo=#m^kjd&e=Tb} z*JQJuM>GFvvs=z0k}F!6k0J$~8c?svrV6!QYW;lL089P|MO3c%=-3QkV8PEtP~lCC3S^@OHORZr2C<4}L& zhuvOibz9Fvec(gWzzXGt8V*858n1_v>MOn+SJJ+y66p!h(uwy(4>f~~k5^VN4W$bS zW|q>N04*4T{{lrItAU16P&o+Vr@Di`4=!+F8V9#|*qo$3@G%IfvOf%Sbi9_{6^seJ z29>LLvr*K&a48K{=W{aJK{YpTaV|uVA9gO#O;Kt#0+t;#esj?@`pjG$CcPMgnAW@Yl6 z;@?qDhDwm31yq~&ih3tC{aFNYG!RLq2&^Ws^ExXK0CnYrK}x6xRmn$oWBwoZV{Y%r zknRj)yhsn}kjSp`;s#bj!3c?r^J?n}^ho%9Eo}!#7Gt#DFi=;kt{e;jX#EwEI87l$YU7E zf+ML?K8TSlI2J=bijgchmUj6NMzY{Y!0j=Nq{U_s3R$>4X`-F8z-ryZ%1U?IuTx?s z8M4kR(H0}|cL8!-Nqa7Y0@k19*%_-98 z`67mvD@kA_xSl}RhJO^iB^ZTsf2 zy;7$mGNhuv9EVf+*#x*y!k%Y_FHN=EUJs!u#AshVTE4YvudS}Pf;LEnCCC9MvWz9D zuN9-sCR#nu{kZxVe(U3kT7?|KaIr18C_4Ix;OKnSHcvJ;9=FZ1@EP1^o90d1B&pN` zFj2S7k_@hH+GddNg5jZRe$V19LG=pkR`>E=3pI5@94%2PwmcgXo%SPkS?i7ihoHT? 
zI*Q-g0s?UYi$|=jK|)$i5PyC+D^~JgCykiL=yzfTpxwbh;C>0aTz+-C z^xJ;{72|#hyKH`SyOft0^>4(VGD%5)R{SBdWV>I9KgzWA_3e@#4Y+2TG@R9`>gR2f z$ns40E4KNrT#4!Q3`rTpM{wuqbRWP9?Zl?etrob zS36(b$BUy=7j2U)oJlJmwawEL1^3z}f5|qj#TdHo-HGNh9lie|F3eSb%i5Q=61d$_Zf~< zNA>&4qAW*h+Kp-)Vvdm4Q}xfrq5D((D{)9p**H{^!ISjzvvG)?h&W9z?~FrU-Oq`} zYiar;eqP(?`h=fn^&F@?3Ud5eeqQ=``YAtuK{wZa2_Ij{U*E@9rvChE`1tY5`#9u5 zzlL3ozP?><;J;HAGfCCpOlkA8n8_S_egrewy&$5VD_cSz#muticK9@A0?!1;q`aLL zBF>kJKO;}F@C&8l56QF0gNPB`*kb#LJc$wA*kZdSPjo-&7m;Q5)n&P0Fq7`-Uf%Y- z1bXm$@Od9U*m=d}QtWyDnz4u8*+To~%nFT?xfvb2)wY~P8l*;TPwu18xUb@nsOVdz~GqkK9Qs#E=^-hw5U&@YeabRRMgN%>XJfp`aIVh1do zwZb#2S7*;3l@So%;ec(3d+QoEI<|1pv07fPI9-0|3?DcA5SG$ykDEqQo*~p>CHd() zeEt#$A&)Ow6?cv)87z*!QBN2lhr4*P>DrMdVwYD=HqCkV=&|NNS4Up2BYRTpL`yHr z;sBdxPf~ipH_|xy1vlGR#vlM=*;}19IEowC>&HNe{L6y@xf+;{_@Iyad?Y}S$UtTJcz__0ffjWh4-h0W z6ho2#aqZYEeg>i61g*mm=o=OXyC<&D-)ME@+d(B)r}z3#F7y+_V(-Fc80|ro)&U1* z+D0XT!q$0++3VubyO?vyfq;%`8x3^QTSjtaFYNpDmOee?54H-BQwCkmML4VoSPC)Zx; yL!Y?{PAMyW3y09fuDsoR$FjZWZou-(U-$lU_ql(G%i51$Nd7;Su*H6VQvd*`4k)+) literal 0 HcmV?d00001 diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py index 4b02676e2a8..0417bbc32ed 100644 --- a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py +++ b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py @@ -473,6 +473,7 @@ def test_manual_tool_call_msg(use_responses_api: bool) -> None: name="GenerateUsername", args={"name": "Sally", "hair_color": "green"}, id="foo", + type="tool_call", ) ], ), @@ -494,6 +495,7 @@ def test_manual_tool_call_msg(use_responses_api: bool) -> None: name="GenerateUsername", args={"name": "Sally", "hair_color": "green"}, id="bar", + type="tool_call", ) ], ), diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_responses_api.py b/libs/partners/openai/tests/integration_tests/chat_models/test_responses_api.py index 0e23d0e3f06..527eece1241 100644 --- a/libs/partners/openai/tests/integration_tests/chat_models/test_responses_api.py +++ b/libs/partners/openai/tests/integration_tests/chat_models/test_responses_api.py @@ -52,9 +52,11 @@ def _check_response(response: Optional[BaseMessage]) -> None: assert response.response_metadata["service_tier"] +@pytest.mark.default_cassette("test_web_search.yaml.gz") @pytest.mark.vcr -def test_web_search() -> None: - llm = ChatOpenAI(model=MODEL_NAME, output_version="responses/v1") +@pytest.mark.parametrize("output_version", ["responses/v1", "v1"]) +def test_web_search(output_version: Literal["responses/v1", "v1"]) -> None: + llm = ChatOpenAI(model=MODEL_NAME, output_version=output_version) first_response = llm.invoke( "What was a positive news story from today?", tools=[{"type": "web_search_preview"}], @@ -110,7 +112,10 @@ def test_web_search() -> None: for msg in [first_response, full, response]: assert isinstance(msg, AIMessage) block_types = [block["type"] for block in msg.content] # type: ignore[index] - assert block_types == ["web_search_call", "text"] + if output_version == "responses/v1": + assert block_types == ["web_search_call", "text"] + else: + assert block_types == ["web_search_call", "web_search_result", "text"] @pytest.mark.flaky(retries=3, delay=1) @@ -141,13 +146,15 @@ async def test_web_search_async() -> None: assert tool_output["type"] == "web_search_call" -@pytest.mark.flaky(retries=3, delay=1) -def test_function_calling() -> None: +@pytest.mark.default_cassette("test_function_calling.yaml.gz") +@pytest.mark.vcr 
+@pytest.mark.parametrize("output_version", ["v0", "responses/v1", "v1"]) +def test_function_calling(output_version: Literal["v0", "responses/v1", "v1"]) -> None: def multiply(x: int, y: int) -> int: """return x * y""" return x * y - llm = ChatOpenAI(model=MODEL_NAME) + llm = ChatOpenAI(model=MODEL_NAME, output_version=output_version) bound_llm = llm.bind_tools([multiply, {"type": "web_search_preview"}]) ai_msg = cast(AIMessage, bound_llm.invoke("whats 5 * 4")) assert len(ai_msg.tool_calls) == 1 @@ -174,8 +181,15 @@ class FooDict(TypedDict): response: str -def test_parsed_pydantic_schema() -> None: - llm = ChatOpenAI(model=MODEL_NAME, use_responses_api=True) +@pytest.mark.default_cassette("test_parsed_pydantic_schema.yaml.gz") +@pytest.mark.vcr +@pytest.mark.parametrize("output_version", ["v0", "responses/v1", "v1"]) +def test_parsed_pydantic_schema( + output_version: Literal["v0", "responses/v1", "v1"], +) -> None: + llm = ChatOpenAI( + model=MODEL_NAME, use_responses_api=True, output_version=output_version + ) response = llm.invoke("how are ya", response_format=Foo) parsed = Foo(**json.loads(response.text())) assert parsed == response.additional_kwargs["parsed"] @@ -297,8 +311,8 @@ def test_function_calling_and_structured_output() -> None: @pytest.mark.default_cassette("test_reasoning.yaml.gz") @pytest.mark.vcr -@pytest.mark.parametrize("output_version", ["v0", "responses/v1"]) -def test_reasoning(output_version: Literal["v0", "responses/v1"]) -> None: +@pytest.mark.parametrize("output_version", ["v0", "responses/v1", "v1"]) +def test_reasoning(output_version: Literal["v0", "responses/v1", "v1"]) -> None: llm = ChatOpenAI( model="o4-mini", use_responses_api=True, output_version=output_version ) @@ -358,27 +372,32 @@ def test_computer_calls() -> None: def test_file_search() -> None: pytest.skip() # TODO: set up infra - llm = ChatOpenAI(model=MODEL_NAME) + llm = ChatOpenAI(model=MODEL_NAME, use_responses_api=True) tool = { "type": "file_search", "vector_store_ids": [os.environ["OPENAI_VECTOR_STORE_ID"]], } - response = llm.invoke("What is deep research by OpenAI?", tools=[tool]) + + input_message = {"role": "user", "content": "What is deep research by OpenAI?"} + response = llm.invoke([input_message], tools=[tool]) _check_response(response) full: Optional[BaseMessageChunk] = None - for chunk in llm.stream("What is deep research by OpenAI?", tools=[tool]): + for chunk in llm.stream([input_message], tools=[tool]): assert isinstance(chunk, AIMessageChunk) full = chunk if full is None else full + chunk assert isinstance(full, AIMessageChunk) _check_response(full) + next_message = {"role": "user", "content": "Thank you."} + _ = llm.invoke([input_message, full, next_message]) + @pytest.mark.default_cassette("test_stream_reasoning_summary.yaml.gz") @pytest.mark.vcr -@pytest.mark.parametrize("output_version", ["v0", "responses/v1"]) +@pytest.mark.parametrize("output_version", ["v0", "responses/v1", "v1"]) def test_stream_reasoning_summary( - output_version: Literal["v0", "responses/v1"], + output_version: Literal["v0", "responses/v1", "v1"], ) -> None: llm = ChatOpenAI( model="o4-mini", @@ -398,20 +417,39 @@ def test_stream_reasoning_summary( if output_version == "v0": reasoning = response_1.additional_kwargs["reasoning"] assert set(reasoning.keys()) == {"id", "type", "summary"} - else: + summary = reasoning["summary"] + assert isinstance(summary, list) + for block in summary: + assert isinstance(block, dict) + assert isinstance(block["type"], str) + assert isinstance(block["text"], str) + assert 
block["text"] + elif output_version == "responses/v1": reasoning = next( block for block in response_1.content if block["type"] == "reasoning" # type: ignore[index] ) assert set(reasoning.keys()) == {"id", "type", "summary", "index"} - summary = reasoning["summary"] - assert isinstance(summary, list) - for block in summary: - assert isinstance(block, dict) - assert isinstance(block["type"], str) - assert isinstance(block["text"], str) - assert block["text"] + summary = reasoning["summary"] + assert isinstance(summary, list) + for block in summary: + assert isinstance(block, dict) + assert isinstance(block["type"], str) + assert isinstance(block["text"], str) + assert block["text"] + else: + # v1 + total_reasoning_blocks = 0 + for block in response_1.content: + if block["type"] == "reasoning": + total_reasoning_blocks += 1 + assert isinstance(block["id"], str) and block["id"].startswith("rs_") + assert isinstance(block["reasoning"], str) + assert isinstance(block["index"], int) + assert ( + total_reasoning_blocks > 1 + ) # This query typically generates multiple reasoning blocks # Check we can pass back summaries message_2 = {"role": "user", "content": "Thank you."} @@ -419,9 +457,13 @@ def test_stream_reasoning_summary( assert isinstance(response_2, AIMessage) +@pytest.mark.default_cassette("test_code_interpreter.yaml.gz") @pytest.mark.vcr -def test_code_interpreter() -> None: - llm = ChatOpenAI(model="o4-mini", use_responses_api=True) +@pytest.mark.parametrize("output_version", ["v0", "responses/v1", "v1"]) +def test_code_interpreter(output_version: Literal["v0", "responses/v1", "v1"]) -> None: + llm = ChatOpenAI( + model="o4-mini", use_responses_api=True, output_version=output_version + ) llm_with_tools = llm.bind_tools( [{"type": "code_interpreter", "container": {"type": "auto"}}] ) @@ -430,15 +472,38 @@ def test_code_interpreter() -> None: "content": "Write and run code to answer the question: what is 3^3?", } response = llm_with_tools.invoke([input_message]) + assert isinstance(response, AIMessage) _check_response(response) - tool_outputs = response.additional_kwargs["tool_outputs"] - assert tool_outputs - assert any(output["type"] == "code_interpreter_call" for output in tool_outputs) + if output_version == "v0": + tool_outputs = [ + item + for item in response.additional_kwargs["tool_outputs"] + if item["type"] == "code_interpreter_call" + ] + elif output_version == "responses/v1": + tool_outputs = [ + item + for item in response.content + if isinstance(item, dict) and item["type"] == "code_interpreter_call" + ] + else: + # v1 + tool_outputs = [ + item + for item in response.content + if isinstance(item, dict) and item["type"] == "code_interpreter_call" + ] + code_interpreter_result = next( + item + for item in response.content + if isinstance(item, dict) and item["type"] == "code_interpreter_result" + ) + assert tool_outputs + assert code_interpreter_result + assert len(tool_outputs) == 1 # Test streaming # Use same container - tool_outputs = response.additional_kwargs["tool_outputs"] - assert len(tool_outputs) == 1 container_id = tool_outputs[0]["container_id"] llm_with_tools = llm.bind_tools( [{"type": "code_interpreter", "container": container_id}] @@ -449,9 +514,32 @@ def test_code_interpreter() -> None: assert isinstance(chunk, AIMessageChunk) full = chunk if full is None else full + chunk assert isinstance(full, AIMessageChunk) - tool_outputs = full.additional_kwargs["tool_outputs"] + if output_version == "v0": + tool_outputs = [ + item + for item in 
response.additional_kwargs["tool_outputs"] + if item["type"] == "code_interpreter_call" + ] + elif output_version == "responses/v1": + tool_outputs = [ + item + for item in response.content + if isinstance(item, dict) and item["type"] == "code_interpreter_call" + ] + else: + code_interpreter_call = next( + item + for item in response.content + if isinstance(item, dict) and item["type"] == "code_interpreter_call" + ) + code_interpreter_result = next( + item + for item in response.content + if isinstance(item, dict) and item["type"] == "code_interpreter_result" + ) + assert code_interpreter_call + assert code_interpreter_result assert tool_outputs - assert any(output["type"] == "code_interpreter_call" for output in tool_outputs) # Test we can pass back in next_message = {"role": "user", "content": "Please add more comments to the code."} @@ -546,10 +634,14 @@ def test_mcp_builtin_zdr() -> None: _ = llm_with_tools.invoke([input_message, full, approval_message]) -@pytest.mark.vcr() -def test_image_generation_streaming() -> None: +@pytest.mark.default_cassette("test_image_generation_streaming.yaml.gz") +@pytest.mark.vcr +@pytest.mark.parametrize("output_version", ["v0", "responses/v1", "v1"]) +def test_image_generation_streaming(output_version: str) -> None: """Test image generation streaming.""" - llm = ChatOpenAI(model="gpt-4.1", use_responses_api=True) + llm = ChatOpenAI( + model="gpt-4.1", use_responses_api=True, output_version=output_version + ) tool = { "type": "image_generation", # For testing purposes let's keep the quality low, so the test runs faster. @@ -596,15 +688,37 @@ def test_image_generation_streaming() -> None: # At the moment, the streaming API does not pick up annotations fully. # So the following check is commented out. # _check_response(complete_ai_message) - tool_output = complete_ai_message.additional_kwargs["tool_outputs"][0] - assert set(tool_output.keys()).issubset(expected_keys) + if output_version == "v0": + assert complete_ai_message.additional_kwargs["tool_outputs"] + tool_output = complete_ai_message.additional_kwargs["tool_outputs"][0] + assert set(tool_output.keys()).issubset(expected_keys) + elif output_version == "responses/v1": + tool_output = next( + block + for block in complete_ai_message.content + if isinstance(block, dict) and block["type"] == "image_generation_call" + ) + assert set(tool_output.keys()).issubset(expected_keys) + else: + # v1 + standard_keys = {"type", "base64", "id", "status", "index"} + tool_output = next( + block + for block in complete_ai_message.content + if isinstance(block, dict) and block["type"] == "image" + ) + assert set(standard_keys).issubset(tool_output.keys()) -@pytest.mark.vcr() -def test_image_generation_multi_turn() -> None: +@pytest.mark.default_cassette("test_image_generation_multi_turn.yaml.gz") +@pytest.mark.vcr +@pytest.mark.parametrize("output_version", ["v0", "responses/v1", "v1"]) +def test_image_generation_multi_turn(output_version: str) -> None: """Test multi-turn editing of image generation by passing in history.""" # Test multi-turn - llm = ChatOpenAI(model="gpt-4.1", use_responses_api=True) + llm = ChatOpenAI( + model="gpt-4.1", use_responses_api=True, output_version=output_version + ) # Test invocation tool = { "type": "image_generation", @@ -620,10 +734,41 @@ def test_image_generation_multi_turn() -> None: {"role": "user", "content": "Draw a random short word in green font."} ] ai_message = llm_with_tools.invoke(chat_history) + assert isinstance(ai_message, AIMessage) _check_response(ai_message) - 
tool_output = ai_message.additional_kwargs["tool_outputs"][0] - # Example tool output for an image + expected_keys = { + "id", + "background", + "output_format", + "quality", + "result", + "revised_prompt", + "size", + "status", + "type", + } + + if output_version == "v0": + tool_output = ai_message.additional_kwargs["tool_outputs"][0] + assert set(tool_output.keys()).issubset(expected_keys) + elif output_version == "responses/v1": + tool_output = next( + block + for block in ai_message.content + if isinstance(block, dict) and block["type"] == "image_generation_call" + ) + assert set(tool_output.keys()).issubset(expected_keys) + else: + standard_keys = {"type", "base64", "id", "status"} + tool_output = next( + block + for block in ai_message.content + if isinstance(block, dict) and block["type"] == "image" + ) + assert set(standard_keys).issubset(tool_output.keys()) + + # Example tool output for an image (v0) # { # "background": "opaque", # "id": "ig_683716a8ddf0819888572b20621c7ae4029ec8c11f8dacf8", @@ -639,20 +784,6 @@ def test_image_generation_multi_turn() -> None: # "result": # base64 encode image data # } - expected_keys = { - "id", - "background", - "output_format", - "quality", - "result", - "revised_prompt", - "size", - "status", - "type", - } - - assert set(tool_output.keys()).issubset(expected_keys) - chat_history.extend( [ # AI message with tool output @@ -669,6 +800,24 @@ def test_image_generation_multi_turn() -> None: ) ai_message2 = llm_with_tools.invoke(chat_history) + assert isinstance(ai_message2, AIMessage) _check_response(ai_message2) - tool_output2 = ai_message2.additional_kwargs["tool_outputs"][0] - assert set(tool_output2.keys()).issubset(expected_keys) + + if output_version == "v0": + tool_output = ai_message2.additional_kwargs["tool_outputs"][0] + assert set(tool_output.keys()).issubset(expected_keys) + elif output_version == "responses/v1": + tool_output = next( + block + for block in ai_message2.content + if isinstance(block, dict) and block["type"] == "image_generation_call" + ) + assert set(tool_output.keys()).issubset(expected_keys) + else: + standard_keys = {"type", "base64", "id", "status"} + tool_output = next( + block + for block in ai_message2.content + if isinstance(block, dict) and block["type"] == "image" + ) + assert set(standard_keys).issubset(tool_output.keys()) diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py index c4176711482..6a1f94c60db 100644 --- a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py +++ b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py @@ -51,7 +51,11 @@ from langchain_openai import ChatOpenAI from langchain_openai.chat_models._compat import ( _FUNCTION_CALL_IDS_MAP_KEY, _convert_from_v03_ai_message, + _convert_from_v1_to_chat_completions, + _convert_from_v1_to_responses, _convert_to_v03_ai_message, + _convert_to_v1_from_chat_completions, + _convert_to_v1_from_responses, ) from langchain_openai.chat_models.base import ( _construct_lc_result_from_responses_api, @@ -2297,7 +2301,7 @@ def test_mcp_tracing() -> None: assert payload["tools"][0]["headers"]["Authorization"] == "Bearer PLACEHOLDER" -def test_compat() -> None: +def test_compat_responses_v1() -> None: # Check compatibility with v0.3 message format message_v03 = AIMessage( content=[ @@ -2358,6 +2362,411 @@ def test_compat() -> None: assert message_v03_output is not message_v03 +@pytest.mark.parametrize( + "message_v1, expected", + [ + ( + AIMessage( + [ + 
{"type": "reasoning", "reasoning": "Reasoning text"}, + {"type": "tool_call", "id": "call_123"}, + { + "type": "text", + "text": "Hello, world!", + "annotations": [ + {"type": "url_citation", "url": "https://example.com"} + ], + }, + ], + tool_calls=[ + { + "type": "tool_call", + "id": "call_123", + "name": "get_weather", + "args": {"location": "San Francisco"}, + } + ], + id="chatcmpl-123", + response_metadata={"foo": "bar"}, + ), + AIMessage( + [{"type": "text", "text": "Hello, world!"}], + tool_calls=[ + { + "type": "tool_call", + "id": "call_123", + "name": "get_weather", + "args": {"location": "San Francisco"}, + } + ], + id="chatcmpl-123", + response_metadata={"foo": "bar"}, + ), + ) + ], +) +def test_convert_from_v1_to_chat_completions( + message_v1: AIMessage, expected: AIMessage +) -> None: + result = _convert_from_v1_to_chat_completions(message_v1) + assert result == expected + + # Check no mutation + assert message_v1 != result + + +@pytest.mark.parametrize( + "message_chat_completions, expected", + [ + ( + AIMessage( + "Hello, world!", id="chatcmpl-123", response_metadata={"foo": "bar"} + ), + AIMessage( + [{"type": "text", "text": "Hello, world!"}], + id="chatcmpl-123", + response_metadata={"foo": "bar"}, + ), + ), + ( + AIMessage( + [{"type": "text", "text": "Hello, world!"}], + tool_calls=[ + { + "type": "tool_call", + "id": "call_123", + "name": "get_weather", + "args": {"location": "San Francisco"}, + } + ], + id="chatcmpl-123", + response_metadata={"foo": "bar"}, + ), + AIMessage( + [ + {"type": "text", "text": "Hello, world!"}, + {"type": "tool_call", "id": "call_123"}, + ], + tool_calls=[ + { + "type": "tool_call", + "id": "call_123", + "name": "get_weather", + "args": {"location": "San Francisco"}, + } + ], + id="chatcmpl-123", + response_metadata={"foo": "bar"}, + ), + ), + ( + AIMessage( + "", + tool_calls=[ + { + "type": "tool_call", + "id": "call_123", + "name": "get_weather", + "args": {"location": "San Francisco"}, + } + ], + id="chatcmpl-123", + response_metadata={"foo": "bar"}, + additional_kwargs={"tool_calls": [{"foo": "bar"}]}, + ), + AIMessage( + [{"type": "tool_call", "id": "call_123"}], + tool_calls=[ + { + "type": "tool_call", + "id": "call_123", + "name": "get_weather", + "args": {"location": "San Francisco"}, + } + ], + id="chatcmpl-123", + response_metadata={"foo": "bar"}, + ), + ), + ], +) +def test_convert_to_v1_from_chat_completions( + message_chat_completions: AIMessage, expected: AIMessage +) -> None: + result = _convert_to_v1_from_chat_completions(message_chat_completions) + assert result == expected + + +@pytest.mark.parametrize( + "message_v1, expected", + [ + ( + AIMessage( + [ + {"type": "reasoning", "id": "abc123"}, + {"type": "reasoning", "id": "abc234", "reasoning": "foo "}, + {"type": "reasoning", "id": "abc234", "reasoning": "bar"}, + {"type": "tool_call", "id": "call_123"}, + { + "type": "tool_call", + "id": "call_234", + "name": "get_weather_2", + "arguments": '{"location": "New York"}', + "item_id": "fc_123", + }, + {"type": "text", "text": "Hello "}, + { + "type": "text", + "text": "world", + "annotations": [ + {"type": "url_citation", "url": "https://example.com"}, + { + "type": "document_citation", + "title": "my doc", + "index": 1, + "file_id": "file_123", + }, + { + "type": "non_standard_annotation", + "value": {"bar": "baz"}, + }, + ], + }, + {"type": "image", "base64": "...", "id": "img_123"}, + { + "type": "non_standard", + "value": {"type": "something_else", "foo": "bar"}, + }, + ], + tool_calls=[ + { + "type": "tool_call", 
+ "id": "call_123", + "name": "get_weather", + "args": {"location": "San Francisco"}, + }, + { + # Make values different to check we pull from content when + # available + "type": "tool_call", + "id": "call_234", + "name": "get_weather_3", + "args": {"location": "Boston"}, + }, + ], + id="resp123", + response_metadata={"foo": "bar"}, + ), + AIMessage( + [ + {"type": "reasoning", "id": "abc123", "summary": []}, + { + "type": "reasoning", + "id": "abc234", + "summary": [ + {"type": "summary_text", "text": "foo "}, + {"type": "summary_text", "text": "bar"}, + ], + }, + { + "type": "function_call", + "call_id": "call_123", + "name": "get_weather", + "arguments": '{"location": "San Francisco"}', + }, + { + "type": "function_call", + "call_id": "call_234", + "name": "get_weather_2", + "arguments": '{"location": "New York"}', + "id": "fc_123", + }, + {"type": "text", "text": "Hello "}, + { + "type": "text", + "text": "world", + "annotations": [ + {"type": "url_citation", "url": "https://example.com"}, + { + "type": "file_citation", + "filename": "my doc", + "index": 1, + "file_id": "file_123", + }, + {"bar": "baz"}, + ], + }, + {"type": "image_generation_call", "id": "img_123", "result": "..."}, + {"type": "something_else", "foo": "bar"}, + ], + tool_calls=[ + { + "type": "tool_call", + "id": "call_123", + "name": "get_weather", + "args": {"location": "San Francisco"}, + }, + { + # Make values different to check we pull from content when + # available + "type": "tool_call", + "id": "call_234", + "name": "get_weather_3", + "args": {"location": "Boston"}, + }, + ], + id="resp123", + response_metadata={"foo": "bar"}, + ), + ) + ], +) +def test_convert_from_v1_to_responses( + message_v1: AIMessage, expected: AIMessage +) -> None: + result = _convert_from_v1_to_responses(message_v1) + assert result == expected + + # Check no mutation + assert message_v1 != result + + +@pytest.mark.parametrize( + "message_responses, expected", + [ + ( + AIMessage( + [ + {"type": "reasoning", "id": "abc123"}, + { + "type": "reasoning", + "id": "abc234", + "summary": [ + {"type": "summary_text", "text": "foo "}, + {"type": "summary_text", "text": "bar"}, + ], + }, + { + "type": "function_call", + "call_id": "call_123", + "name": "get_weather", + "arguments": '{"location": "San Francisco"}', + }, + { + "type": "function_call", + "call_id": "call_234", + "name": "get_weather_2", + "arguments": '{"location": "New York"}', + "id": "fc_123", + }, + {"type": "text", "text": "Hello "}, + { + "type": "text", + "text": "world", + "annotations": [ + {"type": "url_citation", "url": "https://example.com"}, + { + "type": "file_citation", + "filename": "my doc", + "index": 1, + "file_id": "file_123", + }, + {"bar": "baz"}, + ], + }, + {"type": "image_generation_call", "id": "img_123", "result": "..."}, + {"type": "something_else", "foo": "bar"}, + ], + tool_calls=[ + { + "type": "tool_call", + "id": "call_123", + "name": "get_weather", + "args": {"location": "San Francisco"}, + }, + { + # Make values different to check we pull from content when + # available + "type": "tool_call", + "id": "call_234", + "name": "get_weather_3", + "args": {"location": "Boston"}, + }, + ], + id="resp123", + response_metadata={"foo": "bar"}, + ), + AIMessage( + [ + {"type": "reasoning", "id": "abc123"}, + {"type": "reasoning", "id": "abc234", "reasoning": "foo "}, + {"type": "reasoning", "id": "abc234", "reasoning": "bar"}, + { + "type": "tool_call", + "id": "call_123", + "name": "get_weather", + "arguments": '{"location": "San Francisco"}', + }, + { + 
"type": "tool_call", + "id": "call_234", + "name": "get_weather_2", + "arguments": '{"location": "New York"}', + "item_id": "fc_123", + }, + {"type": "text", "text": "Hello "}, + { + "type": "text", + "text": "world", + "annotations": [ + {"type": "url_citation", "url": "https://example.com"}, + { + "type": "document_citation", + "title": "my doc", + "index": 1, + "file_id": "file_123", + }, + { + "type": "non_standard_annotation", + "value": {"bar": "baz"}, + }, + ], + }, + {"type": "image", "base64": "...", "id": "img_123"}, + { + "type": "non_standard", + "value": {"type": "something_else", "foo": "bar"}, + }, + ], + tool_calls=[ + { + "type": "tool_call", + "id": "call_123", + "name": "get_weather", + "args": {"location": "San Francisco"}, + }, + { + # Make values different to check we pull from content when + # available + "type": "tool_call", + "id": "call_234", + "name": "get_weather_3", + "args": {"location": "Boston"}, + }, + ], + id="resp123", + response_metadata={"foo": "bar"}, + ), + ) + ], +) +def test_convert_to_v1_from_responses( + message_responses: AIMessage, expected: AIMessage +) -> None: + result = _convert_to_v1_from_responses(message_responses) + assert result == expected + + def test_get_last_messages() -> None: messages: list[BaseMessage] = [HumanMessage("Hello")] last_messages, previous_response_id = _get_last_messages(messages) diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_responses_stream.py b/libs/partners/openai/tests/unit_tests/chat_models/test_responses_stream.py index 370adcd1f1a..bee9e6fe095 100644 --- a/libs/partners/openai/tests/unit_tests/chat_models/test_responses_stream.py +++ b/libs/partners/openai/tests/unit_tests/chat_models/test_responses_stream.py @@ -1,6 +1,7 @@ from typing import Any, Optional from unittest.mock import MagicMock, patch +import pytest from langchain_core.messages import AIMessageChunk, BaseMessageChunk from openai.types.responses import ( ResponseCompletedEvent, @@ -610,8 +611,97 @@ def _strip_none(obj: Any) -> Any: return obj -def test_responses_stream() -> None: - llm = ChatOpenAI(model="o4-mini", output_version="responses/v1") +@pytest.mark.parametrize( + "output_version, expected_content", + [ + ( + "responses/v1", + [ + { + "id": "rs_123", + "summary": [ + { + "index": 0, + "type": "summary_text", + "text": "reasoning block one", + }, + { + "index": 1, + "type": "summary_text", + "text": "another reasoning block", + }, + ], + "type": "reasoning", + "index": 0, + }, + {"type": "text", "text": "text block one", "index": 1, "id": "msg_123"}, + { + "type": "text", + "text": "another text block", + "index": 2, + "id": "msg_123", + }, + { + "id": "rs_234", + "summary": [ + {"index": 0, "type": "summary_text", "text": "more reasoning"}, + { + "index": 1, + "type": "summary_text", + "text": "still more reasoning", + }, + ], + "type": "reasoning", + "index": 3, + }, + {"type": "text", "text": "more", "index": 4, "id": "msg_234"}, + {"type": "text", "text": "text", "index": 5, "id": "msg_234"}, + ], + ), + ( + "v1", + [ + { + "type": "reasoning", + "reasoning": "reasoning block one", + "id": "rs_123", + "index": 0, + }, + { + "type": "reasoning", + "reasoning": "another reasoning block", + "id": "rs_123", + "index": 1, + }, + {"type": "text", "text": "text block one", "index": 2, "id": "msg_123"}, + { + "type": "text", + "text": "another text block", + "index": 3, + "id": "msg_123", + }, + { + "type": "reasoning", + "reasoning": "more reasoning", + "id": "rs_234", + "index": 4, + }, + { + "type": "reasoning", + 
"reasoning": "still more reasoning", + "id": "rs_234", + "index": 5, + }, + {"type": "text", "text": "more", "index": 6, "id": "msg_234"}, + {"type": "text", "text": "text", "index": 7, "id": "msg_234"}, + ], + ), + ], +) +def test_responses_stream(output_version: str, expected_content: list[dict]) -> None: + llm = ChatOpenAI( + model="o4-mini", use_responses_api=True, output_version=output_version + ) mock_client = MagicMock() def mock_create(*args: Any, **kwargs: Any) -> MockSyncContextManager: @@ -620,36 +710,14 @@ def test_responses_stream() -> None: mock_client.responses.create = mock_create full: Optional[BaseMessageChunk] = None + chunks = [] with patch.object(llm, "root_client", mock_client): for chunk in llm.stream("test"): assert isinstance(chunk, AIMessageChunk) full = chunk if full is None else full + chunk - assert isinstance(full, AIMessageChunk) + chunks.append(chunk) - expected_content = [ - { - "id": "rs_123", - "summary": [ - {"index": 0, "type": "summary_text", "text": "reasoning block one"}, - {"index": 1, "type": "summary_text", "text": "another reasoning block"}, - ], - "type": "reasoning", - "index": 0, - }, - {"type": "text", "text": "text block one", "index": 1, "id": "msg_123"}, - {"type": "text", "text": "another text block", "index": 2, "id": "msg_123"}, - { - "id": "rs_234", - "summary": [ - {"index": 0, "type": "summary_text", "text": "more reasoning"}, - {"index": 1, "type": "summary_text", "text": "still more reasoning"}, - ], - "type": "reasoning", - "index": 3, - }, - {"type": "text", "text": "more", "index": 4, "id": "msg_234"}, - {"type": "text", "text": "text", "index": 5, "id": "msg_234"}, - ] + assert isinstance(full, AIMessageChunk) assert full.content == expected_content assert full.additional_kwargs == {} assert full.id == "resp_123"