Mirror of https://github.com/hwchase17/langchain.git (synced 2026-02-04 16:20:16 +00:00)

Compare commits: sr/fix-inj ... mdrxy/invo (56 commits)

Commits (SHA1):
e98fc34203, 43b9d3d904, 27d81cf3d9, 313ed7b401, f0f1e28473, 0e6c172893,
8ee0cbba3c, 4790c7265a, aeea0e3ff8, aca7c1fe6a, 2375c3a4d0, 0199b56bda,
00345c4de9, 7f9727ee08, 08cd5bb9b4, 987031f86c, 7a8c6398a4, f691dc348f,
86252d2ae6, 8bd2403518, 4dd9110424, 8fc1973bbf, a3b20b0ef5, 301a425151,
3db8c60112, 8d110599cb, c9e847fcb8, 174e685139, 601fa7d672, 7e39cd18c5,
9721684501, a4e135b508, d111965448, 624300cefa, 0aac20e655, 2c9cfa8817,
153db48c92, 527d62de3a, 80c595d7da, 803d19f31e, 2f604eb9a0, 3ae37b5987,
0c7294f608, 5c961ca4f6, c0e4361192, c1d65a7d7f, 3ae7535f42, 6eaa17205c,
98d5f469e3, 0ddab9ff20, 91b2bb3417, 8426db47f1, 1b9ec25755, f8244b9108,
54a3c5f85c, 7090060b68
@@ -36,16 +36,17 @@ from langchain_core.language_models.base import (
|
||||
from langchain_core.load import dumpd, dumps
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
AnyMessage,
|
||||
BaseMessage,
|
||||
BaseMessageChunk,
|
||||
HumanMessage,
|
||||
convert_to_messages,
|
||||
convert_to_openai_data_block,
|
||||
convert_to_openai_image_block,
|
||||
is_data_content_block,
|
||||
message_chunk_to_message,
|
||||
)
|
||||
from langchain_core.messages.ai import _LC_ID_PREFIX
|
||||
from langchain_core.outputs import (
|
||||
ChatGeneration,
|
||||
ChatGenerationChunk,
|
||||
@@ -65,6 +66,7 @@ from langchain_core.utils.function_calling import (
|
||||
convert_to_openai_tool,
|
||||
)
|
||||
from langchain_core.utils.pydantic import TypeBaseModel, is_basemodel_subclass
|
||||
from langchain_core.utils.utils import LC_ID_PREFIX
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import uuid
|
||||
@@ -130,6 +132,19 @@ def _format_for_tracing(messages: list[BaseMessage]) -> list[BaseMessage]:
|
||||
message_to_trace.content[idx] = ( # type: ignore[index] # mypy confused by .model_copy
|
||||
convert_to_openai_image_block(block)
|
||||
)
|
||||
elif (
|
||||
block.get("type") == "file"
|
||||
and is_data_content_block(block)
|
||||
and "base64" in block
|
||||
):
|
||||
if message_to_trace is message:
|
||||
# Shallow copy
|
||||
message_to_trace = message.model_copy()
|
||||
message_to_trace.content = list(message_to_trace.content)
|
||||
|
||||
message_to_trace.content[idx] = convert_to_openai_data_block( # type: ignore[index]
|
||||
block
|
||||
)
|
||||
elif len(block) == 1 and "type" not in block:
|
||||
# Tracing assumes all content blocks have a "type" key. Here
|
||||
# we add this key if it is missing, and there's an obvious
|
||||
@@ -320,6 +335,21 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
|
||||
"""
|
||||
|
||||
output_version: str = "v0"
|
||||
"""Version of ``AIMessage`` output format to use.
|
||||
|
||||
This field is used to roll out new output formats for chat model ``AIMessage``s
|
||||
in a backwards-compatible way.
|
||||
|
||||
``'v1'`` standardizes the output format as a list of typed ``ContentBlock`` dicts. We
|
||||
recommend this for new applications.
|
||||
|
||||
All chat models currently support the default of ``'v0'``.
|
||||
|
||||
.. versionadded:: 1.0
|
||||
|
||||
"""
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def raise_deprecation(cls, values: dict) -> Any:
|
||||
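To make the rollout concrete, a minimal usage sketch follows (``ChatFoo`` and the ``langchain_foo`` package are hypothetical stand-ins for any ``BaseChatModel`` subclass; the ``output_version`` field and the per-call override are the pieces added in this diff):

.. code-block:: python

    from langchain_foo import ChatFoo  # hypothetical provider package

    # Set the output format for every call at construction time...
    llm = ChatFoo(model="foo-1", output_version="v1")

    # ...or override it for a single invocation; the per-call value wins.
    v1_msg = llm.invoke("Hello")                       # typed content blocks
    v0_msg = llm.invoke("Hello", output_version="v0")  # legacy format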
@@ -380,9 +410,29 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
config: Optional[RunnableConfig] = None,
|
||||
*,
|
||||
stop: Optional[list[str]] = None,
|
||||
output_version: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseMessage:
|
||||
"""Invoke the chat model.
|
||||
|
||||
Args:
|
||||
input: The input to the chat model.
|
||||
config: The config to use for this run.
|
||||
stop: Stop words to use when generating.
|
||||
output_version: Override the model's ``output_version`` for this invocation.
|
||||
If None, uses the model's configured ``output_version``.
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
The model's response message.
|
||||
"""
|
||||
config = ensure_config(config)
|
||||
|
||||
effective_output_version = (
|
||||
output_version if output_version is not None else self.output_version
|
||||
)
|
||||
kwargs["_output_version"] = effective_output_version
|
||||
|
||||
return cast(
|
||||
"ChatGeneration",
|
||||
self.generate_prompt(
|
||||
@@ -404,9 +454,29 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
config: Optional[RunnableConfig] = None,
|
||||
*,
|
||||
stop: Optional[list[str]] = None,
|
||||
output_version: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseMessage:
|
||||
"""Asynchronously invoke the chat model.
|
||||
|
||||
Args:
|
||||
input: The input to the chat model.
|
||||
config: The config to use for this run.
|
||||
stop: Stop words to use when generating.
|
||||
output_version: Override the model's ``output_version`` for this invocation.
|
||||
If None, uses the model's configured ``output_version``.
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
The model's response message.
|
||||
"""
|
||||
config = ensure_config(config)
|
||||
|
||||
effective_output_version = (
|
||||
output_version if output_version is not None else self.output_version
|
||||
)
|
||||
kwargs["_output_version"] = effective_output_version
|
||||
|
||||
llm_result = await self.agenerate_prompt(
|
||||
[self._convert_input(input)],
|
||||
stop=stop,
|
||||
@@ -461,13 +531,38 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
config: Optional[RunnableConfig] = None,
|
||||
*,
|
||||
stop: Optional[list[str]] = None,
|
||||
output_version: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[BaseMessageChunk]:
|
||||
"""Stream responses from the chat model.
|
||||
|
||||
Args:
|
||||
input: The input to the chat model.
|
||||
config: The config to use for this run.
|
||||
stop: Stop words to use when generating.
|
||||
output_version: Override the model's ``output_version`` for this invocation.
|
||||
If None, uses the model's configured ``output_version``.
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
Iterator of message chunks.
|
||||
"""
|
||||
effective_output_version = (
|
||||
output_version if output_version is not None else self.output_version
|
||||
)
|
||||
kwargs["_output_version"] = effective_output_version
|
||||
|
||||
if not self._should_stream(async_api=False, **{**kwargs, "stream": True}):
|
||||
# model doesn't implement streaming, so use default implementation
|
||||
yield cast(
|
||||
"BaseMessageChunk",
|
||||
self.invoke(input, config=config, stop=stop, **kwargs),
|
||||
self.invoke(
|
||||
input,
|
||||
config=config,
|
||||
stop=stop,
|
||||
output_version=effective_output_version,
|
||||
**kwargs,
|
||||
),
|
||||
)
|
||||
else:
|
||||
config = ensure_config(config)
|
||||
@@ -511,11 +606,19 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
|
||||
try:
|
||||
input_messages = _normalize_messages(messages)
|
||||
run_id = "-".join((_LC_ID_PREFIX, str(run_manager.run_id)))
|
||||
for chunk in self._stream(input_messages, stop=stop, **kwargs):
|
||||
run_id = "-".join((LC_ID_PREFIX, str(run_manager.run_id)))
|
||||
for chunk in self._stream(
|
||||
input_messages,
|
||||
stop=stop,
|
||||
output_version=kwargs["_output_version"],
|
||||
**kwargs,
|
||||
):
|
||||
if chunk.message.id is None:
|
||||
chunk.message.id = run_id
|
||||
chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
|
||||
output_version = kwargs["_output_version"]
|
||||
if isinstance(chunk.message, (AIMessage, AIMessageChunk)):
|
||||
chunk.message.additional_kwargs["output_version"] = output_version
|
||||
run_manager.on_llm_new_token(
|
||||
cast("str", chunk.message.content), chunk=chunk
|
||||
)
|
||||
@@ -552,13 +655,38 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
config: Optional[RunnableConfig] = None,
|
||||
*,
|
||||
stop: Optional[list[str]] = None,
|
||||
output_version: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[BaseMessageChunk]:
|
||||
"""Asynchronously stream responses from the chat model.
|
||||
|
||||
Args:
|
||||
input: The input to the chat model.
|
||||
config: The config to use for this run.
|
||||
stop: Stop words to use when generating.
|
||||
output_version: Override the model's ``output_version`` for this invocation.
|
||||
If None, uses the model's configured ``output_version``.
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
Async iterator of message chunks.
|
||||
"""
|
||||
effective_output_version = (
|
||||
output_version if output_version is not None else self.output_version
|
||||
)
|
||||
kwargs["_output_version"] = effective_output_version
|
||||
|
||||
if not self._should_stream(async_api=True, **{**kwargs, "stream": True}):
|
||||
# No async or sync stream is implemented, so fall back to ainvoke
|
||||
yield cast(
|
||||
"BaseMessageChunk",
|
||||
await self.ainvoke(input, config=config, stop=stop, **kwargs),
|
||||
await self.ainvoke(
|
||||
input,
|
||||
config=config,
|
||||
stop=stop,
|
||||
output_version=effective_output_version,
|
||||
**kwargs,
|
||||
),
|
||||
)
|
||||
return
|
||||
|
||||
@@ -604,15 +732,19 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
|
||||
try:
|
||||
input_messages = _normalize_messages(messages)
|
||||
run_id = "-".join((_LC_ID_PREFIX, str(run_manager.run_id)))
|
||||
run_id = "-".join((LC_ID_PREFIX, str(run_manager.run_id)))
|
||||
async for chunk in self._astream(
|
||||
input_messages,
|
||||
stop=stop,
|
||||
output_version=kwargs["_output_version"],
|
||||
**kwargs,
|
||||
):
|
||||
if chunk.message.id is None:
|
||||
chunk.message.id = run_id
|
||||
chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
|
||||
output_version = kwargs["_output_version"]
|
||||
if isinstance(chunk.message, (AIMessage, AIMessageChunk)):
|
||||
chunk.message.additional_kwargs["output_version"] = output_version
|
||||
await run_manager.on_llm_new_token(
|
||||
cast("str", chunk.message.content), chunk=chunk
|
||||
)
|
||||
@@ -622,7 +754,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
generations_with_error_metadata = _generate_response_from_error(e)
|
||||
chat_generation_chunk = merge_chat_generation_chunks(chunks)
|
||||
if chat_generation_chunk:
|
||||
generations = [[chat_generation_chunk], generations_with_error_metadata]
|
||||
generations = [
|
||||
[chat_generation_chunk],
|
||||
generations_with_error_metadata,
|
||||
]
|
||||
else:
|
||||
generations = [generations_with_error_metadata]
|
||||
await run_manager.on_llm_error(
|
||||
@@ -1058,6 +1193,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
if self.rate_limiter:
|
||||
self.rate_limiter.acquire(blocking=True)
|
||||
|
||||
output_version = kwargs.pop("_output_version", self.output_version)
|
||||
|
||||
# If stream is not explicitly set, check if implicitly requested by
|
||||
# astream_events() or astream_log(). Bail out if _stream not implemented
|
||||
if self._should_stream(
|
||||
@@ -1066,11 +1203,15 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
**kwargs,
|
||||
):
|
||||
chunks: list[ChatGenerationChunk] = []
|
||||
for chunk in self._stream(messages, stop=stop, **kwargs):
|
||||
for chunk in self._stream(
|
||||
messages, stop=stop, output_version=output_version, **kwargs
|
||||
):
|
||||
chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
|
||||
if isinstance(chunk.message, (AIMessage, AIMessageChunk)):
|
||||
chunk.message.additional_kwargs["output_version"] = output_version
|
||||
if run_manager:
|
||||
if chunk.message.id is None:
|
||||
chunk.message.id = f"{_LC_ID_PREFIX}-{run_manager.run_id}"
|
||||
chunk.message.id = f"{LC_ID_PREFIX}-{run_manager.run_id}"
|
||||
run_manager.on_llm_new_token(
|
||||
cast("str", chunk.message.content), chunk=chunk
|
||||
)
|
||||
@@ -1078,18 +1219,26 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
result = generate_from_stream(iter(chunks))
|
||||
elif inspect.signature(self._generate).parameters.get("run_manager"):
|
||||
result = self._generate(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
messages,
|
||||
stop=stop,
|
||||
run_manager=run_manager,
|
||||
output_version=output_version,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
result = self._generate(messages, stop=stop, **kwargs)
|
||||
result = self._generate(
|
||||
messages, stop=stop, output_version=output_version, **kwargs
|
||||
)
|
||||
|
||||
# Add response metadata to each generation
|
||||
for idx, generation in enumerate(result.generations):
|
||||
if run_manager and generation.message.id is None:
|
||||
generation.message.id = f"{_LC_ID_PREFIX}-{run_manager.run_id}-{idx}"
|
||||
generation.message.id = f"{LC_ID_PREFIX}-{run_manager.run_id}-{idx}"
|
||||
generation.message.response_metadata = _gen_info_and_msg_metadata(
|
||||
generation
|
||||
)
|
||||
if isinstance(generation.message, (AIMessage, AIMessageChunk)):
|
||||
generation.message.additional_kwargs["output_version"] = output_version
|
||||
if len(result.generations) == 1 and result.llm_output is not None:
|
||||
result.generations[0].message.response_metadata = {
|
||||
**result.llm_output,
|
||||
@@ -1131,6 +1280,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
if self.rate_limiter:
|
||||
await self.rate_limiter.aacquire(blocking=True)
|
||||
|
||||
output_version = kwargs.pop("_output_version", self.output_version)
|
||||
|
||||
# If stream is not explicitly set, check if implicitly requested by
|
||||
# astream_events() or astream_log(). Bail out if _astream not implemented
|
||||
if self._should_stream(
|
||||
@@ -1139,11 +1290,15 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
**kwargs,
|
||||
):
|
||||
chunks: list[ChatGenerationChunk] = []
|
||||
async for chunk in self._astream(messages, stop=stop, **kwargs):
|
||||
async for chunk in self._astream(
|
||||
messages, stop=stop, output_version=output_version, **kwargs
|
||||
):
|
||||
chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
|
||||
if isinstance(chunk.message, (AIMessage, AIMessageChunk)):
|
||||
chunk.message.additional_kwargs["output_version"] = output_version
|
||||
if run_manager:
|
||||
if chunk.message.id is None:
|
||||
chunk.message.id = f"{_LC_ID_PREFIX}-{run_manager.run_id}"
|
||||
chunk.message.id = f"{LC_ID_PREFIX}-{run_manager.run_id}"
|
||||
await run_manager.on_llm_new_token(
|
||||
cast("str", chunk.message.content), chunk=chunk
|
||||
)
|
||||
@@ -1151,18 +1306,26 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
result = generate_from_stream(iter(chunks))
|
||||
elif inspect.signature(self._agenerate).parameters.get("run_manager"):
|
||||
result = await self._agenerate(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
messages,
|
||||
stop=stop,
|
||||
run_manager=run_manager,
|
||||
output_version=output_version,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
result = await self._agenerate(messages, stop=stop, **kwargs)
|
||||
result = await self._agenerate(
|
||||
messages, stop=stop, output_version=output_version, **kwargs
|
||||
)
|
||||
|
||||
# Add response metadata to each generation
|
||||
for idx, generation in enumerate(result.generations):
|
||||
if run_manager and generation.message.id is None:
|
||||
generation.message.id = f"{_LC_ID_PREFIX}-{run_manager.run_id}-{idx}"
|
||||
generation.message.id = f"{LC_ID_PREFIX}-{run_manager.run_id}-{idx}"
|
||||
generation.message.response_metadata = _gen_info_and_msg_metadata(
|
||||
generation
|
||||
)
|
||||
if isinstance(generation.message, (AIMessage, AIMessageChunk)):
|
||||
generation.message.additional_kwargs["output_version"] = output_version
|
||||
if len(result.generations) == 1 and result.llm_output is not None:
|
||||
result.generations[0].message.response_metadata = {
|
||||
**result.llm_output,
|
||||
@@ -1178,15 +1341,20 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
*,
|
||||
output_version: str = "v0",
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
"""Top Level call."""
|
||||
# Concrete implementations should override this method and use the same params
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
*,
|
||||
output_version: str = "v0",
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
"""Top Level call."""
|
||||
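For implementers, a hedged sketch of how a concrete subclass might adopt the new keyword-only ``output_version`` parameter (``ChatParrot`` is a hypothetical toy model, not part of this diff):

.. code-block:: python

    from typing import Any, Optional

    from langchain_core.callbacks import CallbackManagerForLLMRun
    from langchain_core.language_models.chat_models import BaseChatModel
    from langchain_core.messages import AIMessage, BaseMessage
    from langchain_core.outputs import ChatGeneration, ChatResult


    class ChatParrot(BaseChatModel):
        """Toy model that echoes the last message back."""

        @property
        def _llm_type(self) -> str:
            return "parrot"

        def _generate(
            self,
            messages: list[BaseMessage],
            stop: Optional[list[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            *,
            output_version: str = "v0",
            **kwargs: Any,
        ) -> ChatResult:
            text = str(messages[-1].content)
            # Emit typed content blocks for v1, a plain string otherwise.
            content = [{"type": "text", "text": text}] if output_version == "v1" else text
            return ChatResult(generations=[ChatGeneration(message=AIMessage(content))])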
@@ -1196,6 +1364,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
messages,
|
||||
stop,
|
||||
run_manager.get_sync() if run_manager else None,
|
||||
output_version=output_version,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -1204,6 +1373,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
*,
|
||||
output_version: str = "v0",
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
raise NotImplementedError
|
||||
@@ -1213,6 +1384,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
*,
|
||||
output_version: str = "v0",
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
iterator = await run_in_executor(
|
||||
@@ -1221,6 +1394,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
messages,
|
||||
stop,
|
||||
run_manager.get_sync() if run_manager else None,
|
||||
output_version=output_version,
|
||||
**kwargs,
|
||||
)
|
||||
done = object()
|
||||
@@ -1567,6 +1741,9 @@ class SimpleChatModel(BaseChatModel):
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
*,
|
||||
# For backward compatibility
|
||||
output_version: str = "v0", # noqa: ARG002
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs)
|
||||
@@ -1589,6 +1766,8 @@ class SimpleChatModel(BaseChatModel):
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
*,
|
||||
output_version: str = "v0",
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
return await run_in_executor(
|
||||
@@ -1597,6 +1776,7 @@ class SimpleChatModel(BaseChatModel):
|
||||
messages,
|
||||
stop=stop,
|
||||
run_manager=run_manager.get_sync() if run_manager else None,
|
||||
output_version=output_version,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from langchain_core._import_utils import import_attr
|
||||
from langchain_core.utils.utils import LC_AUTO_PREFIX, LC_ID_PREFIX, ensure_id
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.messages.ai import (
|
||||
@@ -32,10 +33,32 @@ if TYPE_CHECKING:
|
||||
messages_to_dict,
|
||||
)
|
||||
from langchain_core.messages.chat import ChatMessage, ChatMessageChunk
|
||||
from langchain_core.messages.content_blocks import (
|
||||
from langchain_core.messages.content import (
|
||||
Annotation,
|
||||
AudioContentBlock,
|
||||
Citation,
|
||||
CodeInterpreterCall,
|
||||
CodeInterpreterOutput,
|
||||
CodeInterpreterResult,
|
||||
ContentBlock,
|
||||
DataContentBlock,
|
||||
FileContentBlock,
|
||||
ImageContentBlock,
|
||||
NonStandardAnnotation,
|
||||
NonStandardContentBlock,
|
||||
PlainTextContentBlock,
|
||||
ReasoningContentBlock,
|
||||
TextContentBlock,
|
||||
VideoContentBlock,
|
||||
WebSearchCall,
|
||||
WebSearchResult,
|
||||
convert_to_openai_data_block,
|
||||
convert_to_openai_image_block,
|
||||
is_data_content_block,
|
||||
is_reasoning_block,
|
||||
is_text_block,
|
||||
is_tool_call_block,
|
||||
is_tool_call_chunk,
|
||||
)
|
||||
from langchain_core.messages.function import FunctionMessage, FunctionMessageChunk
|
||||
from langchain_core.messages.human import HumanMessage, HumanMessageChunk
|
||||
@@ -63,34 +86,59 @@ if TYPE_CHECKING:
|
||||
)
|
||||
|
||||
__all__ = (
|
||||
"LC_AUTO_PREFIX",
|
||||
"LC_ID_PREFIX",
|
||||
"AIMessage",
|
||||
"AIMessageChunk",
|
||||
"Annotation",
|
||||
"AnyMessage",
|
||||
"AudioContentBlock",
|
||||
"BaseMessage",
|
||||
"BaseMessageChunk",
|
||||
"ChatMessage",
|
||||
"ChatMessageChunk",
|
||||
"Citation",
|
||||
"CodeInterpreterCall",
|
||||
"CodeInterpreterOutput",
|
||||
"CodeInterpreterResult",
|
||||
"ContentBlock",
|
||||
"DataContentBlock",
|
||||
"FileContentBlock",
|
||||
"FunctionMessage",
|
||||
"FunctionMessageChunk",
|
||||
"HumanMessage",
|
||||
"HumanMessageChunk",
|
||||
"ImageContentBlock",
|
||||
"InvalidToolCall",
|
||||
"MessageLikeRepresentation",
|
||||
"NonStandardAnnotation",
|
||||
"NonStandardContentBlock",
|
||||
"PlainTextContentBlock",
|
||||
"ReasoningContentBlock",
|
||||
"RemoveMessage",
|
||||
"SystemMessage",
|
||||
"SystemMessageChunk",
|
||||
"TextContentBlock",
|
||||
"ToolCall",
|
||||
"ToolCallChunk",
|
||||
"ToolMessage",
|
||||
"ToolMessageChunk",
|
||||
"VideoContentBlock",
|
||||
"WebSearchCall",
|
||||
"WebSearchResult",
|
||||
"_message_from_dict",
|
||||
"convert_to_messages",
|
||||
"convert_to_openai_data_block",
|
||||
"convert_to_openai_image_block",
|
||||
"convert_to_openai_messages",
|
||||
"ensure_id",
|
||||
"filter_messages",
|
||||
"get_buffer_string",
|
||||
"is_data_content_block",
|
||||
"is_reasoning_block",
|
||||
"is_text_block",
|
||||
"is_tool_call_block",
|
||||
"is_tool_call_chunk",
|
||||
"merge_content",
|
||||
"merge_message_runs",
|
||||
"message_chunk_to_message",
|
||||
@@ -103,35 +151,57 @@ __all__ = (
|
||||
_dynamic_imports = {
|
||||
"AIMessage": "ai",
|
||||
"AIMessageChunk": "ai",
|
||||
"Annotation": "content",
|
||||
"AudioContentBlock": "content",
|
||||
"BaseMessage": "base",
|
||||
"BaseMessageChunk": "base",
|
||||
"merge_content": "base",
|
||||
"message_to_dict": "base",
|
||||
"messages_to_dict": "base",
|
||||
"Citation": "content",
|
||||
"ContentBlock": "content",
|
||||
"ChatMessage": "chat",
|
||||
"ChatMessageChunk": "chat",
|
||||
"CodeInterpreterCall": "content",
|
||||
"CodeInterpreterOutput": "content",
|
||||
"CodeInterpreterResult": "content",
|
||||
"DataContentBlock": "content",
|
||||
"FileContentBlock": "content",
|
||||
"FunctionMessage": "function",
|
||||
"FunctionMessageChunk": "function",
|
||||
"HumanMessage": "human",
|
||||
"HumanMessageChunk": "human",
|
||||
"NonStandardAnnotation": "content",
|
||||
"NonStandardContentBlock": "content",
|
||||
"PlainTextContentBlock": "content",
|
||||
"ReasoningContentBlock": "content",
|
||||
"RemoveMessage": "modifier",
|
||||
"SystemMessage": "system",
|
||||
"SystemMessageChunk": "system",
|
||||
"WebSearchCall": "content",
|
||||
"WebSearchResult": "content",
|
||||
"ImageContentBlock": "content",
|
||||
"InvalidToolCall": "tool",
|
||||
"TextContentBlock": "content",
|
||||
"ToolCall": "tool",
|
||||
"ToolCallChunk": "tool",
|
||||
"ToolMessage": "tool",
|
||||
"ToolMessageChunk": "tool",
|
||||
"VideoContentBlock": "content",
|
||||
"AnyMessage": "utils",
|
||||
"MessageLikeRepresentation": "utils",
|
||||
"_message_from_dict": "utils",
|
||||
"convert_to_messages": "utils",
|
||||
"convert_to_openai_data_block": "content_blocks",
|
||||
"convert_to_openai_image_block": "content_blocks",
|
||||
"convert_to_openai_data_block": "content",
|
||||
"convert_to_openai_image_block": "content",
|
||||
"convert_to_openai_messages": "utils",
|
||||
"filter_messages": "utils",
|
||||
"get_buffer_string": "utils",
|
||||
"is_data_content_block": "content_blocks",
|
||||
"is_data_content_block": "content",
|
||||
"is_reasoning_block": "content",
|
||||
"is_text_block": "content",
|
||||
"is_tool_call_block": "content",
|
||||
"is_tool_call_chunk": "content",
|
||||
"merge_message_runs": "utils",
|
||||
"message_chunk_to_message": "utils",
|
||||
"messages_from_dict": "utils",
|
||||
|
||||
@@ -3,11 +3,12 @@
|
||||
import json
|
||||
import logging
|
||||
import operator
|
||||
from typing import Any, Literal, Optional, Union, cast
|
||||
from typing import Any, Literal, Optional, Union, cast, overload
|
||||
|
||||
from pydantic import model_validator
|
||||
from typing_extensions import NotRequired, Self, TypedDict, override
|
||||
|
||||
from langchain_core.messages import content as types
|
||||
from langchain_core.messages.base import (
|
||||
BaseMessage,
|
||||
BaseMessageChunk,
|
||||
@@ -20,25 +21,17 @@ from langchain_core.messages.tool import (
|
||||
default_tool_chunk_parser,
|
||||
default_tool_parser,
|
||||
)
|
||||
from langchain_core.messages.tool import (
|
||||
invalid_tool_call as create_invalid_tool_call,
|
||||
)
|
||||
from langchain_core.messages.tool import (
|
||||
tool_call as create_tool_call,
|
||||
)
|
||||
from langchain_core.messages.tool import (
|
||||
tool_call_chunk as create_tool_call_chunk,
|
||||
)
|
||||
from langchain_core.messages.tool import invalid_tool_call as create_invalid_tool_call
|
||||
from langchain_core.messages.tool import tool_call as create_tool_call
|
||||
from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
|
||||
from langchain_core.utils._merge import merge_dicts, merge_lists
|
||||
from langchain_core.utils.json import parse_partial_json
|
||||
from langchain_core.utils.usage import _dict_int_op
|
||||
from langchain_core.utils.utils import LC_AUTO_PREFIX, LC_ID_PREFIX
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_LC_ID_PREFIX = "run-"
|
||||
|
||||
|
||||
class InputTokenDetails(TypedDict, total=False):
|
||||
"""Breakdown of input token counts.
|
||||
|
||||
@@ -180,16 +173,42 @@ class AIMessage(BaseMessage):
|
||||
type: Literal["ai"] = "ai"
|
||||
"""The type of the message (used for deserialization). Defaults to "ai"."""
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self, content: Union[str, list[Union[str, dict]]], **kwargs: Any
|
||||
) -> None:
|
||||
"""Pass in content as positional arg.
|
||||
self,
|
||||
content: Union[str, list[Union[str, dict]]],
|
||||
**kwargs: Any,
|
||||
) -> None: ...
|
||||
|
||||
Args:
|
||||
content: The content of the message.
|
||||
kwargs: Additional arguments to pass to the parent class.
|
||||
"""
|
||||
super().__init__(content=content, **kwargs)
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
content: Optional[Union[str, list[Union[str, dict]]]] = None,
|
||||
content_blocks: Optional[list[types.ContentBlock]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None: ...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
content: Optional[Union[str, list[Union[str, dict]]]] = None,
|
||||
content_blocks: Optional[list[types.ContentBlock]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Specify ``content`` as positional arg or ``content_blocks`` for typing."""
|
||||
if content_blocks is not None:
|
||||
# If there are tool calls in content_blocks, but not in tool_calls, add them
|
||||
content_tool_calls = [
|
||||
block for block in content_blocks if block.get("type") == "tool_call"
|
||||
]
|
||||
if content_tool_calls and "tool_calls" not in kwargs:
|
||||
kwargs["tool_calls"] = content_tool_calls
|
||||
|
||||
super().__init__(
|
||||
content=cast("Union[str, list[Union[str, dict]]]", content_blocks),
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
super().__init__(content=content, **kwargs)
|
||||
|
||||
@property
|
||||
def lc_attributes(self) -> dict:
|
||||
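A small sketch of the new ``content_blocks`` constructor path (values are illustrative):

.. code-block:: python

    from langchain_core.messages import AIMessage

    msg = AIMessage(
        content_blocks=[
            {"type": "text", "text": "Calling a tool."},
            {"type": "tool_call", "id": "call_1", "name": "add", "args": {"a": 1, "b": 2}},
        ]
    )
    # Tool calls found in content_blocks are mirrored into .tool_calls
    assert msg.tool_calls[0]["name"] == "add"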
@@ -199,6 +218,46 @@ class AIMessage(BaseMessage):
|
||||
"invalid_tool_calls": self.invalid_tool_calls,
|
||||
}
|
||||
|
||||
@property
|
||||
def content_blocks(self) -> list[types.ContentBlock]:
|
||||
"""Return content blocks of the message."""
|
||||
if self.response_metadata.get("output_version") == "v1":
|
||||
return cast("list[types.ContentBlock]", self.content)
|
||||
|
||||
model_provider = self.response_metadata.get("model_provider")
|
||||
if model_provider:
|
||||
from langchain_core.messages.block_translators import get_translator
|
||||
|
||||
translator = get_translator(model_provider)
|
||||
if translator:
|
||||
return translator["translate_content"](self)
|
||||
|
||||
# Otherwise, use best-effort parsing
|
||||
blocks = super().content_blocks
|
||||
|
||||
if self.tool_calls:
|
||||
# Add from tool_calls if missing from content
|
||||
content_tool_call_ids = {
|
||||
block.get("id")
|
||||
for block in self.content
|
||||
if isinstance(block, dict) and block.get("type") == "tool_call"
|
||||
}
|
||||
for tool_call in self.tool_calls:
|
||||
if (id_ := tool_call.get("id")) and id_ not in content_tool_call_ids:
|
||||
tool_call_block: types.ToolCall = {
|
||||
"type": "tool_call",
|
||||
"id": id_,
|
||||
"name": tool_call["name"],
|
||||
"args": tool_call["args"],
|
||||
}
|
||||
if "index" in tool_call:
|
||||
tool_call_block["index"] = tool_call["index"]
|
||||
if "extras" in tool_call:
|
||||
tool_call_block["extras"] = tool_call["extras"]
|
||||
blocks.append(tool_call_block)
|
||||
|
||||
return blocks
|
||||
|
||||
# TODO: remove this logic if possible, reducing breaking nature of changes
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
@@ -227,7 +286,9 @@ class AIMessage(BaseMessage):
|
||||
# Ensure "type" is properly set on all tool call-like dicts.
|
||||
if tool_calls := values.get("tool_calls"):
|
||||
values["tool_calls"] = [
|
||||
create_tool_call(**{k: v for k, v in tc.items() if k != "type"})
|
||||
create_tool_call(
|
||||
**{k: v for k, v in tc.items() if k not in ("type", "extras")}
|
||||
)
|
||||
for tc in tool_calls
|
||||
]
|
||||
if invalid_tool_calls := values.get("invalid_tool_calls"):
|
||||
@@ -306,6 +367,42 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
|
||||
"invalid_tool_calls": self.invalid_tool_calls,
|
||||
}
|
||||
|
||||
@property
|
||||
def content_blocks(self) -> list[types.ContentBlock]:
|
||||
"""Return content blocks of the message."""
|
||||
if self.response_metadata.get("output_version") == "v1":
|
||||
return cast("list[types.ContentBlock]", self.content)
|
||||
|
||||
model_provider = self.response_metadata.get("model_provider")
|
||||
if model_provider:
|
||||
from langchain_core.messages.block_translators import get_translator
|
||||
|
||||
translator = get_translator(model_provider)
|
||||
if translator:
|
||||
return translator["translate_content_chunk"](self)
|
||||
|
||||
# Otherwise, use best-effort parsing
|
||||
blocks = super().content_blocks
|
||||
|
||||
if self.tool_call_chunks and not self.content:
|
||||
blocks = [
|
||||
block
|
||||
for block in blocks
|
||||
if block["type"] not in ("tool_call", "invalid_tool_call")
|
||||
]
|
||||
for tool_call_chunk in self.tool_call_chunks:
|
||||
tc: types.ToolCallChunk = {
|
||||
"type": "tool_call_chunk",
|
||||
"id": tool_call_chunk.get("id"),
|
||||
"name": tool_call_chunk.get("name"),
|
||||
"args": tool_call_chunk.get("args"),
|
||||
}
|
||||
if (idx := tool_call_chunk.get("index")) is not None:
|
||||
tc["index"] = idx
|
||||
blocks.append(tc)
|
||||
|
||||
return blocks
|
||||
|
||||
@model_validator(mode="after")
|
||||
def init_tool_calls(self) -> Self:
|
||||
"""Initialize tool calls from tool call chunks.
|
||||
@@ -431,17 +528,27 @@ def add_ai_message_chunks(
|
||||
|
||||
chunk_id = None
|
||||
candidates = [left.id] + [o.id for o in others]
|
||||
# first pass: pick the first non-run-* id
|
||||
# first pass: pick the first provider-assigned id (non-run-* and non-lc_*)
|
||||
for id_ in candidates:
|
||||
if id_ and not id_.startswith(_LC_ID_PREFIX):
|
||||
if (
|
||||
id_
|
||||
and not id_.startswith(LC_ID_PREFIX)
|
||||
and not id_.startswith(LC_AUTO_PREFIX)
|
||||
):
|
||||
chunk_id = id_
|
||||
break
|
||||
else:
|
||||
# second pass: no provider-assigned id found, just take the first non-null
|
||||
# second pass: prefer lc_run-* ids over lc_* ids
|
||||
for id_ in candidates:
|
||||
if id_:
|
||||
if id_ and id_.startswith(LC_ID_PREFIX):
|
||||
chunk_id = id_
|
||||
break
|
||||
else:
|
||||
# third pass: take any remaining id (auto-generated lc_* ids)
|
||||
for id_ in candidates:
|
||||
if id_:
|
||||
chunk_id = id_
|
||||
break
|
||||
|
||||
return left.__class__(
|
||||
example=left.example,
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union, cast
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union, cast, overload
|
||||
|
||||
from pydantic import ConfigDict, Field
|
||||
|
||||
@@ -14,6 +14,7 @@ from langchain_core.utils.interactive_env import is_interactive_env
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
from langchain_core.messages import content as types
|
||||
from langchain_core.prompts.chat import ChatPromptTemplate
|
||||
|
||||
|
||||
@@ -61,15 +62,32 @@ class BaseMessage(Serializable):
|
||||
extra="allow",
|
||||
)
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self, content: Union[str, list[Union[str, dict]]], **kwargs: Any
|
||||
) -> None:
|
||||
"""Pass in content as positional arg.
|
||||
self,
|
||||
content: Union[str, list[Union[str, dict]]],
|
||||
**kwargs: Any,
|
||||
) -> None: ...
|
||||
|
||||
Args:
|
||||
content: The string contents of the message.
|
||||
"""
|
||||
super().__init__(content=content, **kwargs)
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
content: Optional[Union[str, list[Union[str, dict]]]] = None,
|
||||
content_blocks: Optional[list[types.ContentBlock]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None: ...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
content: Optional[Union[str, list[Union[str, dict]]]] = None,
|
||||
content_blocks: Optional[list[types.ContentBlock]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Specify ``content`` as positional arg or ``content_blocks`` for typing."""
|
||||
if content_blocks is not None:
|
||||
super().__init__(content=content_blocks, **kwargs)
|
||||
else:
|
||||
super().__init__(content=content, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
@@ -88,6 +106,47 @@ class BaseMessage(Serializable):
|
||||
"""
|
||||
return ["langchain", "schema", "messages"]
|
||||
|
||||
@property
|
||||
def content_blocks(self) -> list[types.ContentBlock]:
|
||||
"""Return the content as a list of standard ``ContentBlock``s.
|
||||
|
||||
To use this property, the corresponding chat model must support
|
||||
``output_version='v1'`` or higher:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.chat_models import init_chat_model
|
||||
llm = init_chat_model("...", output_version="v1")
|
||||
|
||||
Otherwise, does best-effort parsing to standard types.
|
||||
|
||||
"""
|
||||
from langchain_core.messages import content as types
|
||||
|
||||
blocks: list[types.ContentBlock] = []
|
||||
content = (
|
||||
[self.content]
|
||||
if isinstance(self.content, str) and self.content
|
||||
else self.content
|
||||
)
|
||||
for item in content:
|
||||
if isinstance(item, str):
|
||||
blocks.append({"type": "text", "text": item})
|
||||
elif isinstance(item, dict):
|
||||
item_type = item.get("type")
|
||||
if item_type not in types.KNOWN_BLOCK_TYPES:
|
||||
msg = (
|
||||
f"Non-standard content block type '{item_type}'. Ensure "
|
||||
"the model supports `output_version='v1'` or higher and "
|
||||
"that this attribute is set on initialization."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
blocks.append(cast("types.ContentBlock", item))
|
||||
else:
|
||||
pass
|
||||
|
||||
return blocks
|
||||
|
||||
def text(self) -> str:
|
||||
"""Get the text content of the message.
|
||||
|
||||
|
||||
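A minimal example of the best-effort path (no provider translator involved; plain string content is wrapped in a single text block):

.. code-block:: python

    from langchain_core.messages import HumanMessage

    msg = HumanMessage("What is 2 + 2?")
    assert msg.content_blocks == [{"type": "text", "text": "What is 2 + 2?"}]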
@@ -0,0 +1,81 @@
|
||||
"""Derivations of standard content blocks from provider content."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk
|
||||
from langchain_core.messages import content as types
|
||||
|
||||
# Provider to translator mapping
|
||||
PROVIDER_TRANSLATORS: dict[str, dict[str, Callable[..., list[types.ContentBlock]]]] = {}
|
||||
|
||||
|
||||
def register_translator(
|
||||
provider: str,
|
||||
translate_content: Callable[[AIMessage], list[types.ContentBlock]],
|
||||
translate_content_chunk: Callable[[AIMessageChunk], list[types.ContentBlock]],
|
||||
) -> None:
|
||||
"""Register content translators for a provider.
|
||||
|
||||
Args:
|
||||
provider: The model provider name (e.g. ``'openai'``, ``'anthropic'``).
|
||||
translate_content: Function to translate ``AIMessage`` content.
|
||||
translate_content_chunk: Function to translate ``AIMessageChunk`` content.
|
||||
"""
|
||||
PROVIDER_TRANSLATORS[provider] = {
|
||||
"translate_content": translate_content,
|
||||
"translate_content_chunk": translate_content_chunk,
|
||||
}
|
||||
|
||||
|
||||
def get_translator(
|
||||
provider: str,
|
||||
) -> dict[str, Callable[..., list[types.ContentBlock]]] | None:
|
||||
"""Get the translator functions for a provider.
|
||||
|
||||
Args:
|
||||
provider: The model provider name.
|
||||
|
||||
Returns:
|
||||
Dictionary with ``'translate_content'`` and ``'translate_content_chunk'``
|
||||
functions, or None if no translator is registered for the provider.
|
||||
"""
|
||||
return PROVIDER_TRANSLATORS.get(provider)
|
||||
|
||||
|
||||
def _auto_register_translators() -> None:
|
||||
"""Automatically register all available block translators."""
|
||||
import contextlib
|
||||
import importlib
|
||||
import pkgutil
|
||||
from pathlib import Path
|
||||
|
||||
package_path = Path(__file__).parent
|
||||
|
||||
# Discover all sub-modules
|
||||
for module_info in pkgutil.iter_modules([str(package_path)]):
|
||||
module_name = module_info.name
|
||||
|
||||
# Skip the __init__ module and any private modules
|
||||
if module_name.startswith("_"):
|
||||
continue
|
||||
|
||||
if module_info.ispkg:
|
||||
# For subpackages, discover their submodules
|
||||
subpackage_path = package_path / module_name
|
||||
for submodule_info in pkgutil.iter_modules([str(subpackage_path)]):
|
||||
submodule_name = submodule_info.name
|
||||
if not submodule_name.startswith("_"):
|
||||
with contextlib.suppress(ImportError, AttributeError):
|
||||
importlib.import_module(
|
||||
f".{module_name}.{submodule_name}", package=__name__
|
||||
)
|
||||
else:
|
||||
# Import top-level translator modules
|
||||
with contextlib.suppress(ImportError, AttributeError):
|
||||
importlib.import_module(f".{module_name}", package=__name__)
|
||||
|
||||
|
||||
_auto_register_translators()
|
||||
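A hedged sketch of how a provider integration could hook into this registry (the provider name ``my_provider`` and the toy translators are illustrative only):

.. code-block:: python

    from langchain_core.messages import AIMessage, AIMessageChunk
    from langchain_core.messages import content as types
    from langchain_core.messages.block_translators import (
        get_translator,
        register_translator,
    )


    def translate_content(message: AIMessage) -> list[types.ContentBlock]:
        # Toy translation: wrap the string content in a single text block.
        return [{"type": "text", "text": str(message.content)}]


    def translate_content_chunk(chunk: AIMessageChunk) -> list[types.ContentBlock]:
        return [{"type": "text", "text": str(chunk.content)}]


    register_translator("my_provider", translate_content, translate_content_chunk)

    translator = get_translator("my_provider")
    assert translator is not None
    msg = AIMessage("hi", response_metadata={"model_provider": "my_provider"})
    assert translator["translate_content"](msg) == [{"type": "text", "text": "hi"}]

Once ``response_metadata["model_provider"]`` names a registered provider, ``AIMessage.content_blocks`` dispatches to that translator automatically.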
@@ -0,0 +1 @@
|
||||
"""Derivations of standard content blocks from Amazon content."""
|
||||
@@ -0,0 +1,29 @@
|
||||
"""Derivations of standard content blocks from Amazon (Bedrock) content."""
|
||||
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk
|
||||
from langchain_core.messages import content as types
|
||||
|
||||
|
||||
def translate_content(message: AIMessage) -> list[types.ContentBlock]:
|
||||
"""Derive standard content blocks from a message with Bedrock content."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def translate_content_chunk(message: AIMessageChunk) -> list[types.ContentBlock]:
|
||||
"""Derive standard content blocks from a chunk with Bedrock content."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def _register_bedrock_translator() -> None:
|
||||
"""Register the Bedrock translator with the central registry.
|
||||
|
||||
Run automatically when the module is imported.
|
||||
"""
|
||||
from langchain_core.messages.block_translators import register_translator
|
||||
|
||||
register_translator(
|
||||
"amazon_bedrock_chat", translate_content, translate_content_chunk
|
||||
)
|
||||
|
||||
|
||||
_register_bedrock_translator()
|
||||
@@ -0,0 +1,29 @@
|
||||
"""Derivations of standard content blocks from Amazon (Bedrock Converse) content."""
|
||||
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk
|
||||
from langchain_core.messages import content as types
|
||||
|
||||
|
||||
def translate_content(message: AIMessage) -> list[types.ContentBlock]:
|
||||
"""Derive standard content blocks from a message with Bedrock Converse content."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def translate_content_chunk(message: AIMessageChunk) -> list[types.ContentBlock]:
|
||||
"""Derive standard content blocks from a chunk with Bedrock Converse content."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def _register_bedrock_converse_translator() -> None:
|
||||
"""Register the Bedrock Converse translator with the central registry.
|
||||
|
||||
Run automatically when the module is imported.
|
||||
"""
|
||||
from langchain_core.messages.block_translators import register_translator
|
||||
|
||||
register_translator(
|
||||
"amazon_bedrock_converse_chat", translate_content, translate_content_chunk
|
||||
)
|
||||
|
||||
|
||||
_register_bedrock_converse_translator()
|
||||
@@ -0,0 +1,27 @@
|
||||
"""Derivations of standard content blocks from Anthropic content."""
|
||||
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk
|
||||
from langchain_core.messages import content as types
|
||||
|
||||
|
||||
def translate_content(message: AIMessage) -> list[types.ContentBlock]:
|
||||
"""Derive standard content blocks from a message with Anthropic content."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def translate_content_chunk(message: AIMessageChunk) -> list[types.ContentBlock]:
|
||||
"""Derive standard content blocks from a message chunk with Anthropic content."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def _register_anthropic_translator() -> None:
|
||||
"""Register the Anthropic translator with the central registry.
|
||||
|
||||
Run automatically when the module is imported.
|
||||
"""
|
||||
from langchain_core.messages.block_translators import register_translator
|
||||
|
||||
register_translator("anthropic", translate_content, translate_content_chunk)
|
||||
|
||||
|
||||
_register_anthropic_translator()
|
||||
@@ -0,0 +1,27 @@
|
||||
"""Derivations of standard content blocks from Chroma content."""
|
||||
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk
|
||||
from langchain_core.messages import content as types
|
||||
|
||||
|
||||
def translate_content(message: AIMessage) -> list[types.ContentBlock]:
|
||||
"""Derive standard content blocks from a message with Chroma content."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def translate_content_chunk(message: AIMessageChunk) -> list[types.ContentBlock]:
|
||||
"""Derive standard content blocks from a message chunk with Chroma content."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def _register_chroma_translator() -> None:
|
||||
"""Register the Chroma translator with the central registry.
|
||||
|
||||
Run automatically when the module is imported.
|
||||
"""
|
||||
from langchain_core.messages.block_translators import register_translator
|
||||
|
||||
register_translator("chroma", translate_content, translate_content_chunk)
|
||||
|
||||
|
||||
_register_chroma_translator()
|
||||
@@ -0,0 +1 @@
|
||||
"""Derivations of standard content blocks from Google content."""
|
||||
@@ -0,0 +1,27 @@
|
||||
"""Derivations of standard content blocks from Google (GenAI) content."""
|
||||
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk
|
||||
from langchain_core.messages import content as types
|
||||
|
||||
|
||||
def translate_content(message: AIMessage) -> list[types.ContentBlock]:
|
||||
"""Derive standard content blocks from a message with Google (GenAI) content."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def translate_content_chunk(message: AIMessageChunk) -> list[types.ContentBlock]:
|
||||
"""Derive standard content blocks from a chunk with Google (GenAI) content."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def _register_google_genai_translator() -> None:
|
||||
"""Register the Google (GenAI) translator with the central registry.
|
||||
|
||||
Run automatically when the module is imported.
|
||||
"""
|
||||
from langchain_core.messages.block_translators import register_translator
|
||||
|
||||
register_translator("google_genai", translate_content, translate_content_chunk)
|
||||
|
||||
|
||||
_register_google_genai_translator()
|
||||
@@ -0,0 +1,27 @@
|
||||
"""Derivations of standard content blocks from Google (VertexAI) content."""
|
||||
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk
|
||||
from langchain_core.messages import content as types
|
||||
|
||||
|
||||
def translate_content(message: AIMessage) -> list[types.ContentBlock]:
|
||||
"""Derive standard content blocks from a message with Google (VertexAI) content."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def translate_content_chunk(message: AIMessageChunk) -> list[types.ContentBlock]:
|
||||
"""Derive standard content blocks from a chunk with Google (VertexAI) content."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def _register_google_vertexai_translator() -> None:
|
||||
"""Register the Google (VertexAI) translator with the central registry.
|
||||
|
||||
Run automatically when the module is imported.
|
||||
"""
|
||||
from langchain_core.messages.block_translators import register_translator
|
||||
|
||||
register_translator("google_vertexai", translate_content, translate_content_chunk)
|
||||
|
||||
|
||||
_register_google_vertexai_translator()
|
||||
libs/core/langchain_core/messages/block_translators/groq.py (new file, 27 lines)
@@ -0,0 +1,27 @@
|
||||
"""Derivations of standard content blocks from Groq content."""
|
||||
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk
|
||||
from langchain_core.messages import content as types
|
||||
|
||||
|
||||
def translate_content(message: AIMessage) -> list[types.ContentBlock]:
|
||||
"""Derive standard content blocks from a message with Groq content."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def translate_content_chunk(message: AIMessageChunk) -> list[types.ContentBlock]:
|
||||
"""Derive standard content blocks from a message chunk with Groq content."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def _register_groq_translator() -> None:
|
||||
"""Register the Groq translator with the central registry.
|
||||
|
||||
Run automatically when the module is imported.
|
||||
"""
|
||||
from langchain_core.messages.block_translators import register_translator
|
||||
|
||||
register_translator("groq", translate_content, translate_content_chunk)
|
||||
|
||||
|
||||
_register_groq_translator()
|
||||
@@ -0,0 +1,27 @@
|
||||
"""Derivations of standard content blocks from Ollama content."""
|
||||
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk
|
||||
from langchain_core.messages import content as types
|
||||
|
||||
|
||||
def translate_content(message: AIMessage) -> list[types.ContentBlock]:
|
||||
"""Derive standard content blocks from a message with Ollama content."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def translate_content_chunk(message: AIMessageChunk) -> list[types.ContentBlock]:
|
||||
"""Derive standard content blocks from a message chunk with Ollama content."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def _register_ollama_translator() -> None:
|
||||
"""Register the Ollama translator with the central registry.
|
||||
|
||||
Run automatically when the module is imported.
|
||||
"""
|
||||
from langchain_core.messages.block_translators import register_translator
|
||||
|
||||
register_translator("ollama", translate_content, translate_content_chunk)
|
||||
|
||||
|
||||
_register_ollama_translator()
|
||||
libs/core/langchain_core/messages/block_translators/openai.py (new file, 358 lines)
@@ -0,0 +1,358 @@
|
||||
"""Derivations of standard content blocks from OpenAI content."""
|
||||
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk
|
||||
from langchain_core.messages import content as types
|
||||
|
||||
|
||||
# v1 / Chat Completions
|
||||
def _convert_to_v1_from_chat_completions(
|
||||
message: AIMessage,
|
||||
) -> list[types.ContentBlock]:
|
||||
"""Mutate a Chat Completions message to v1 format."""
|
||||
content_blocks: list[types.ContentBlock] = []
|
||||
if isinstance(message.content, str):
|
||||
if message.content:
|
||||
content_blocks = [{"type": "text", "text": message.content}]
|
||||
else:
|
||||
content_blocks = []
|
||||
|
||||
for tool_call in message.tool_calls:
|
||||
content_blocks.append(tool_call)
|
||||
|
||||
return content_blocks
|
||||
|
||||
|
||||
def _convert_to_v1_from_chat_completions_chunk(
|
||||
chunk: AIMessageChunk,
|
||||
) -> list[types.ContentBlock]:
|
||||
"""Mutate a Chat Completions chunk to v1 format."""
|
||||
content_blocks: list[types.ContentBlock] = []
|
||||
if isinstance(chunk.content, str):
|
||||
if chunk.content:
|
||||
content_blocks = [{"type": "text", "text": chunk.content}]
|
||||
else:
|
||||
content_blocks = []
|
||||
|
||||
for tool_call_chunk in chunk.tool_call_chunks:
|
||||
tc: types.ToolCallChunk = {
|
||||
"type": "tool_call_chunk",
|
||||
"id": tool_call_chunk.get("id"),
|
||||
"name": tool_call_chunk.get("name"),
|
||||
"args": tool_call_chunk.get("args"),
|
||||
}
|
||||
if (idx := tool_call_chunk.get("index")) is not None:
|
||||
tc["index"] = idx
|
||||
content_blocks.append(tc)
|
||||
|
||||
return content_blocks
|
||||
|
||||
|
||||
def _convert_from_v1_to_chat_completions(message: AIMessage) -> AIMessage:
|
||||
"""Convert a v1 message to the Chat Completions format."""
|
||||
if isinstance(message.content, list):
|
||||
new_content: list = []
|
||||
for block in message.content:
|
||||
if isinstance(block, dict):
|
||||
block_type = block.get("type")
|
||||
if block_type == "text":
|
||||
# Strip annotations
|
||||
new_content.append({"type": "text", "text": block["text"]})
|
||||
elif block_type in ("reasoning", "tool_call"):
|
||||
pass
|
||||
else:
|
||||
new_content.append(block)
|
||||
else:
|
||||
new_content.append(block)
|
||||
return message.model_copy(update={"content": new_content})
|
||||
|
||||
return message
|
||||
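For reference, a sketch of the expected behaviour of the two Chat Completions helpers above (input values are illustrative; the imports pull private helpers from this new module purely for demonstration):

.. code-block:: python

    from langchain_core.messages import AIMessage
    from langchain_core.messages.block_translators.openai import (
        _convert_from_v1_to_chat_completions,
        _convert_to_v1_from_chat_completions,
    )

    msg = AIMessage(
        "The answer is 4.",
        tool_calls=[{"name": "add", "args": {"a": 2, "b": 2}, "id": "call_1"}],
    )
    blocks = _convert_to_v1_from_chat_completions(msg)
    # -> [{"type": "text", "text": "The answer is 4."},
    #     {"type": "tool_call", "name": "add", "args": {"a": 2, "b": 2}, "id": "call_1"}]

    # Converting back strips reasoning and tool_call blocks out of content.
    v1_msg = msg.model_copy(update={"content": blocks})
    restored = _convert_from_v1_to_chat_completions(v1_msg)
    assert restored.content == [{"type": "text", "text": "The answer is 4."}]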
|
||||
|
||||
# v1 / Responses
|
||||
def _convert_annotation_to_v1(annotation: dict[str, Any]) -> types.Annotation:
|
||||
annotation_type = annotation.get("type")
|
||||
|
||||
if annotation_type == "url_citation":
|
||||
known_fields = {
|
||||
"type",
|
||||
"url",
|
||||
"title",
|
||||
"cited_text",
|
||||
"start_index",
|
||||
"end_index",
|
||||
}
|
||||
url_citation = cast("types.Citation", {})
|
||||
for field in ("end_index", "start_index", "title"):
|
||||
if field in annotation:
|
||||
url_citation[field] = annotation[field]
|
||||
url_citation["type"] = "citation"
|
||||
url_citation["url"] = annotation["url"]
|
||||
for field, value in annotation.items():
|
||||
if field not in known_fields:
|
||||
if "extras" not in url_citation:
|
||||
url_citation["extras"] = {}
|
||||
url_citation["extras"][field] = value
|
||||
return url_citation
|
||||
|
||||
if annotation_type == "file_citation":
|
||||
known_fields = {
|
||||
"type",
|
||||
"title",
|
||||
"cited_text",
|
||||
"start_index",
|
||||
"end_index",
|
||||
"filename",
|
||||
}
|
||||
document_citation: types.Citation = {"type": "citation"}
|
||||
if "filename" in annotation:
|
||||
document_citation["title"] = annotation["filename"]
|
||||
for field, value in annotation.items():
|
||||
if field not in known_fields:
|
||||
if "extras" not in document_citation:
|
||||
document_citation["extras"] = {}
|
||||
document_citation["extras"][field] = value
|
||||
|
||||
return document_citation
|
||||
|
||||
# TODO: standardise container_file_citation?
|
||||
non_standard_annotation: types.NonStandardAnnotation = {
|
||||
"type": "non_standard_annotation",
|
||||
"value": annotation,
|
||||
}
|
||||
return non_standard_annotation
|
||||
|
||||
|
||||
def _explode_reasoning(block: dict[str, Any]) -> Iterable[types.ReasoningContentBlock]:
|
||||
if "summary" not in block:
|
||||
yield cast("types.ReasoningContentBlock", block)
|
||||
return
|
||||
|
||||
known_fields = {"type", "reasoning", "id", "index"}
|
||||
unknown_fields = [
|
||||
field for field in block if field != "summary" and field not in known_fields
|
||||
]
|
||||
if unknown_fields:
|
||||
block["extras"] = {}
|
||||
for field in unknown_fields:
|
||||
block["extras"][field] = block.pop(field)
|
||||
|
||||
if not block["summary"]:
|
||||
# [{'id': 'rs_...', 'summary': [], 'type': 'reasoning', 'index': 0}]
|
||||
block = {k: v for k, v in block.items() if k != "summary"}
|
||||
if "index" in block:
|
||||
meaningful_idx = f"{block['index']}_0"
|
||||
block["index"] = f"lc_rs_{meaningful_idx.encode().hex()}"
|
||||
yield cast("types.ReasoningContentBlock", block)
|
||||
return
|
||||
|
||||
# Common part for every exploded line, except 'summary'
|
||||
common = {k: v for k, v in block.items() if k in known_fields}
|
||||
|
||||
# Optional keys that must appear only in the first exploded item
|
||||
first_only = block.pop("extras", None)
|
||||
|
||||
for idx, part in enumerate(block["summary"]):
|
||||
new_block = dict(common)
|
||||
new_block["reasoning"] = part.get("text", "")
|
||||
if idx == 0 and first_only:
|
||||
new_block.update(first_only)
|
||||
if "index" in new_block:
|
||||
summary_index = part.get("index", 0)
|
||||
meaningful_idx = f"{new_block['index']}_{summary_index}"
|
||||
new_block["index"] = f"lc_rs_{meaningful_idx.encode().hex()}"
|
||||
|
||||
yield cast("types.ReasoningContentBlock", new_block)
|
||||
|
||||
|
||||
def _convert_to_v1_from_responses(message: AIMessage) -> list[types.ContentBlock]:
|
||||
"""Convert a Responses message to v1 format."""
|
||||
|
||||
def _iter_blocks() -> Iterable[types.ContentBlock]:
|
||||
for raw_block in message.content:
|
||||
if not isinstance(raw_block, dict):
|
||||
continue
|
||||
block = raw_block.copy()
|
||||
block_type = block.get("type")
|
||||
|
||||
if block_type == "text":
|
||||
if "text" not in block:
|
||||
block["text"] = ""
|
||||
if "annotations" in block:
|
||||
block["annotations"] = [
|
||||
_convert_annotation_to_v1(a) for a in block["annotations"]
|
||||
]
|
||||
if "index" in block:
|
||||
block["index"] = f"lc_txt_{block['index']}"
|
||||
yield cast("types.TextContentBlock", block)
|
||||
|
||||
elif block_type == "reasoning":
|
||||
yield from _explode_reasoning(block)
|
||||
|
||||
elif block_type == "image_generation_call" and (
|
||||
result := block.get("result")
|
||||
):
|
||||
new_block = {"type": "image", "base64": result}
|
||||
if output_format := block.get("output_format"):
|
||||
new_block["mime_type"] = f"image/{output_format}"
|
||||
if "id" in block:
|
||||
new_block["id"] = block["id"]
|
||||
if "index" in block:
|
||||
new_block["index"] = f"lc_img_{block['index']}"
|
||||
for extra_key in (
|
||||
"status",
|
||||
"background",
|
||||
"output_format",
|
||||
"quality",
|
||||
"revised_prompt",
|
||||
"size",
|
||||
):
|
||||
if extra_key in block:
|
||||
if "extras" not in new_block:
|
||||
new_block["extras"] = {}
|
||||
new_block["extras"][extra_key] = block[extra_key]
|
||||
yield cast("types.ImageContentBlock", new_block)
|
||||
|
||||
elif block_type == "function_call":
|
||||
tool_call_block: Optional[
|
||||
Union[types.ToolCall, types.InvalidToolCall, types.ToolCallChunk]
|
||||
] = None
|
||||
call_id = block.get("call_id", "")
|
||||
if (
|
||||
isinstance(message, AIMessageChunk)
|
||||
and len(message.tool_call_chunks) == 1
|
||||
):
|
||||
tool_call_block = message.tool_call_chunks[0].copy() # type: ignore[assignment]
|
||||
elif call_id:
|
||||
for tool_call in message.tool_calls or []:
|
||||
if tool_call.get("id") == call_id:
|
||||
tool_call_block = tool_call.copy()
|
||||
break
|
||||
else:
|
||||
for invalid_tool_call in message.invalid_tool_calls or []:
|
||||
if invalid_tool_call.get("id") == call_id:
|
||||
tool_call_block = invalid_tool_call.copy()
|
||||
break
|
||||
else:
|
||||
pass
|
||||
if tool_call_block:
|
||||
if "id" in block:
|
||||
if "extras" not in tool_call_block:
|
||||
tool_call_block["extras"] = {}
|
||||
tool_call_block["extras"]["item_id"] = block["id"]
|
||||
if "index" in block:
|
||||
tool_call_block["index"] = f"lc_tc_{block['index']}"
|
||||
yield tool_call_block
|
||||
|
||||
elif block_type == "web_search_call":
|
||||
web_search_call = {"type": "web_search_call", "id": block["id"]}
|
||||
if "index" in block:
|
||||
web_search_call["index"] = f"lc_wsc_{block['index']}"
|
||||
if (
|
||||
"action" in block
|
||||
and isinstance(block["action"], dict)
|
||||
and block["action"].get("type") == "search"
|
||||
and "query" in block["action"]
|
||||
):
|
||||
web_search_call["query"] = block["action"]["query"]
|
||||
for key in block:
|
||||
if key not in ("type", "id", "index"):
|
||||
web_search_call[key] = block[key]
|
||||
|
||||
yield cast("types.WebSearchCall", web_search_call)
|
||||
|
||||
# If .content already has web_search_result, don't add
|
||||
if not any(
|
||||
isinstance(other_block, dict)
|
||||
and other_block.get("type") == "web_search_result"
|
||||
and other_block.get("id") == block["id"]
|
||||
for other_block in message.content
|
||||
):
|
||||
web_search_result = {"type": "web_search_result", "id": block["id"]}
|
||||
if "index" in block and isinstance(block["index"], int):
|
||||
web_search_result["index"] = f"lc_wsr_{block['index'] + 1}"
|
||||
yield cast("types.WebSearchResult", web_search_result)
|
||||
|
||||
elif block_type == "code_interpreter_call":
|
||||
code_interpreter_call = {
|
||||
"type": "code_interpreter_call",
|
||||
"id": block["id"],
|
||||
}
|
||||
if "code" in block:
|
||||
code_interpreter_call["code"] = block["code"]
|
||||
if "index" in block:
|
||||
code_interpreter_call["index"] = f"lc_cic_{block['index']}"
|
||||
known_fields = {"type", "id", "language", "code", "extras", "index"}
|
||||
for key in block:
|
||||
if key not in known_fields:
|
||||
if "extras" not in code_interpreter_call:
|
||||
code_interpreter_call["extras"] = {}
|
||||
code_interpreter_call["extras"][key] = block[key]
|
||||
|
||||
code_interpreter_result = {
|
||||
"type": "code_interpreter_result",
|
||||
"id": block["id"],
|
||||
}
|
||||
if "outputs" in block:
|
||||
code_interpreter_result["outputs"] = block["outputs"]
|
||||
for output in block["outputs"]:
|
||||
if (
|
||||
isinstance(output, dict)
|
||||
and (output_type := output.get("type"))
|
||||
and output_type == "logs"
|
||||
):
|
||||
if "output" not in code_interpreter_result:
|
||||
code_interpreter_result["output"] = []
|
||||
code_interpreter_result["output"].append(
|
||||
{
|
||||
"type": "code_interpreter_output",
|
||||
"stdout": output.get("logs", ""),
|
||||
}
|
||||
)
|
||||
|
||||
if "status" in block:
|
||||
code_interpreter_result["status"] = block["status"]
|
||||
if "index" in block and isinstance(block["index"], int):
|
||||
code_interpreter_result["index"] = f"lc_cir_{block['index'] + 1}"
|
||||
|
||||
yield cast("types.CodeInterpreterCall", code_interpreter_call)
|
||||
yield cast("types.CodeInterpreterResult", code_interpreter_result)
|
||||
|
||||
elif block_type in types.KNOWN_BLOCK_TYPES:
|
||||
yield cast("types.ContentBlock", block)
|
||||
else:
|
||||
new_block = {"type": "non_standard", "value": block}
|
||||
if "index" in new_block["value"]:
|
||||
new_block["index"] = f"lc_ns_{new_block['value'].pop('index')}"
|
||||
yield cast("types.NonStandardContentBlock", new_block)
|
||||
|
||||
return list(_iter_blocks())
|
||||
|
||||
|
||||
def translate_content(message: AIMessage) -> list[types.ContentBlock]:
|
||||
"""Derive standard content blocks from a message with OpenAI content."""
|
||||
if isinstance(message.content, str):
|
||||
return _convert_to_v1_from_chat_completions(message)
|
||||
return _convert_to_v1_from_responses(message)
|
||||
|
||||
|
||||
def translate_content_chunk(message: AIMessageChunk) -> list[types.ContentBlock]:
|
||||
"""Derive standard content blocks from a message chunk with OpenAI content."""
|
||||
if isinstance(message.content, str):
|
||||
return _convert_to_v1_from_chat_completions_chunk(message)
|
||||
return _convert_to_v1_from_responses(message)
|
||||
|
||||
|
||||
def _register_openai_translator() -> None:
|
||||
"""Register the OpenAI translator with the central registry.
|
||||
|
||||
Runs automatically when the module is imported.
|
||||
"""
|
||||
from langchain_core.messages.block_translators import register_translator
|
||||
|
||||
register_translator("openai", translate_content, translate_content_chunk)
|
||||
|
||||
|
||||
_register_openai_translator()
|
||||
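# A minimal sketch of how the translator registered above is exercised,
# mirroring the conversion tests later in this diff: accessing
# `.content_blocks` on an AIMessage whose response_metadata names "openai"
# as the provider routes through `translate_content`.
from langchain_core.messages import AIMessage

msg = AIMessage(
    [
        {"type": "reasoning", "id": "abc123", "summary": []},
        {"type": "text", "text": "Hello"},
    ],
    response_metadata={"model_provider": "openai"},
)
# Expected v1 blocks, per the conversion rules above:
# [{"type": "reasoning", "id": "abc123"}, {"type": "text", "text": "Hello"}]
blocks = msg.content_blocks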
libs/core/langchain_core/messages/content.py (1564 lines, new file; diff suppressed because it is too large)
@@ -1,155 +0,0 @@
|
||||
"""Types for content blocks."""
|
||||
|
||||
import warnings
|
||||
from typing import Any, Literal, Union
|
||||
|
||||
from pydantic import TypeAdapter, ValidationError
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
|
||||
class BaseDataContentBlock(TypedDict, total=False):
|
||||
"""Base class for data content blocks."""
|
||||
|
||||
mime_type: NotRequired[str]
|
||||
"""MIME type of the content block (if needed)."""
|
||||
|
||||
|
||||
class URLContentBlock(BaseDataContentBlock):
|
||||
"""Content block for data from a URL."""
|
||||
|
||||
type: Literal["image", "audio", "file"]
|
||||
"""Type of the content block."""
|
||||
source_type: Literal["url"]
|
||||
"""Source type (url)."""
|
||||
url: str
|
||||
"""URL for data."""
|
||||
|
||||
|
||||
class Base64ContentBlock(BaseDataContentBlock):
|
||||
"""Content block for inline data from a base64 string."""
|
||||
|
||||
type: Literal["image", "audio", "file"]
|
||||
"""Type of the content block."""
|
||||
source_type: Literal["base64"]
|
||||
"""Source type (base64)."""
|
||||
data: str
|
||||
"""Data as a base64 string."""
|
||||
|
||||
|
||||
class PlainTextContentBlock(BaseDataContentBlock):
|
||||
"""Content block for plain text data (e.g., from a document)."""
|
||||
|
||||
type: Literal["file"]
|
||||
"""Type of the content block."""
|
||||
source_type: Literal["text"]
|
||||
"""Source type (text)."""
|
||||
text: str
|
||||
"""Text data."""
|
||||
|
||||
|
||||
class IDContentBlock(TypedDict):
|
||||
"""Content block for data specified by an identifier."""
|
||||
|
||||
type: Literal["image", "audio", "file"]
|
||||
"""Type of the content block."""
|
||||
source_type: Literal["id"]
|
||||
"""Source type (id)."""
|
||||
id: str
|
||||
"""Identifier for data source."""
|
||||
|
||||
|
||||
DataContentBlock = Union[
|
||||
URLContentBlock,
|
||||
Base64ContentBlock,
|
||||
PlainTextContentBlock,
|
||||
IDContentBlock,
|
||||
]
|
||||
|
||||
_DataContentBlockAdapter: TypeAdapter[DataContentBlock] = TypeAdapter(DataContentBlock)
|
||||
|
||||
|
||||
def is_data_content_block(
|
||||
content_block: dict,
|
||||
) -> bool:
|
||||
"""Check if the content block is a standard data content block.
|
||||
|
||||
Args:
|
||||
content_block: The content block to check.
|
||||
|
||||
Returns:
|
||||
True if the content block is a data content block, False otherwise.
|
||||
"""
|
||||
try:
|
||||
_ = _DataContentBlockAdapter.validate_python(content_block)
|
||||
except ValidationError:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
def convert_to_openai_image_block(content_block: dict[str, Any]) -> dict:
|
||||
"""Convert image content block to format expected by OpenAI Chat Completions API."""
|
||||
if content_block["source_type"] == "url":
|
||||
return {
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": content_block["url"],
|
||||
},
|
||||
}
|
||||
if content_block["source_type"] == "base64":
|
||||
if "mime_type" not in content_block:
|
||||
error_message = "mime_type key is required for base64 data."
|
||||
raise ValueError(error_message)
|
||||
mime_type = content_block["mime_type"]
|
||||
return {
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:{mime_type};base64,{content_block['data']}",
|
||||
},
|
||||
}
|
||||
error_message = "Unsupported source type. Only 'url' and 'base64' are supported."
|
||||
raise ValueError(error_message)
|
||||
|
||||
|
||||
def convert_to_openai_data_block(block: dict) -> dict:
|
||||
"""Format standard data content block to format expected by OpenAI."""
|
||||
if block["type"] == "image":
|
||||
formatted_block = convert_to_openai_image_block(block)
|
||||
|
||||
elif block["type"] == "file":
|
||||
if block["source_type"] == "base64":
|
||||
file = {"file_data": f"data:{block['mime_type']};base64,{block['data']}"}
|
||||
if filename := block.get("filename"):
|
||||
file["filename"] = filename
|
||||
elif (metadata := block.get("metadata")) and ("filename" in metadata):
|
||||
file["filename"] = metadata["filename"]
|
||||
else:
|
||||
warnings.warn(
|
||||
"OpenAI may require a filename for file inputs. Specify a filename "
|
||||
"in the content block: {'type': 'file', 'source_type': 'base64', "
|
||||
"'mime_type': 'application/pdf', 'data': '...', "
|
||||
"'filename': 'my-pdf'}",
|
||||
stacklevel=1,
|
||||
)
|
||||
formatted_block = {"type": "file", "file": file}
|
||||
elif block["source_type"] == "id":
|
||||
formatted_block = {"type": "file", "file": {"file_id": block["id"]}}
|
||||
else:
|
||||
error_msg = "source_type base64 or id is required for file blocks."
|
||||
raise ValueError(error_msg)
|
||||
|
||||
elif block["type"] == "audio":
|
||||
if block["source_type"] == "base64":
|
||||
audio_format = block["mime_type"].split("/")[-1]
|
||||
formatted_block = {
|
||||
"type": "input_audio",
|
||||
"input_audio": {"data": block["data"], "format": audio_format},
|
||||
}
|
||||
else:
|
||||
error_msg = "source_type base64 is required for audio blocks."
|
||||
raise ValueError(error_msg)
|
||||
else:
|
||||
error_msg = f"Block of type {block['type']} is not supported."
|
||||
raise ValueError(error_msg)
|
||||
|
||||
return formatted_block
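# For reference, a hedged sketch of what the legacy "source_type" helpers
# deleted in this hunk produced, shown as comments because the module is
# removed:
#
#   convert_to_openai_image_block(
#       {"type": "image", "source_type": "base64",
#        "mime_type": "image/png", "data": "<b64>"}
#   )
#   == {"type": "image_url",
#       "image_url": {"url": "data:image/png;base64,<b64>"}}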
|
||||
@@ -1,7 +1,8 @@
|
||||
"""Human message."""
|
||||
|
||||
from typing import Any, Literal, Union
|
||||
from typing import Any, Literal, Optional, Union, cast, overload
|
||||
|
||||
from langchain_core.messages import content as types
|
||||
from langchain_core.messages.base import BaseMessage, BaseMessageChunk
|
||||
|
||||
|
||||
@@ -41,16 +42,35 @@ class HumanMessage(BaseMessage):
|
||||
type: Literal["human"] = "human"
|
||||
"""The type of the message (used for serialization). Defaults to "human"."""
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self, content: Union[str, list[Union[str, dict]]], **kwargs: Any
|
||||
) -> None:
|
||||
"""Pass in content as positional arg.
|
||||
self,
|
||||
content: Union[str, list[Union[str, dict]]],
|
||||
**kwargs: Any,
|
||||
) -> None: ...
|
||||
|
||||
Args:
|
||||
content: The string contents of the message.
|
||||
kwargs: Additional fields to pass to the message.
|
||||
"""
|
||||
super().__init__(content=content, **kwargs)
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
content: Optional[Union[str, list[Union[str, dict]]]] = None,
|
||||
content_blocks: Optional[list[types.ContentBlock]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None: ...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
content: Optional[Union[str, list[Union[str, dict]]]] = None,
|
||||
content_blocks: Optional[list[types.ContentBlock]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Specify ``content`` as positional arg or ``content_blocks`` for typing."""
|
||||
if content_blocks is not None:
|
||||
super().__init__(
|
||||
content=cast("Union[str, list[Union[str, dict]]]", content_blocks),
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
super().__init__(content=content, **kwargs)
|
||||
|
||||
|
||||
class HumanMessageChunk(HumanMessage, BaseMessageChunk):
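# A hedged usage sketch for the keyword-only ``content_blocks`` path added
# above; the same pattern is repeated for SystemMessage and ToolMessage below.
from langchain_core.messages import HumanMessage

msg = HumanMessage(content_blocks=[{"type": "text", "text": "Hello"}])
# The blocks are stored as ``content``; ``content_blocks`` only narrows the
# type so type checkers can validate the block dicts.
assert msg.content == [{"type": "text", "text": "Hello"}]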
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
"""System message."""
|
||||
|
||||
from typing import Any, Literal, Union
|
||||
from typing import Any, Literal, Optional, Union, cast, overload
|
||||
|
||||
from langchain_core.messages import content as types
|
||||
from langchain_core.messages.base import BaseMessage, BaseMessageChunk
|
||||
|
||||
|
||||
@@ -34,16 +35,35 @@ class SystemMessage(BaseMessage):
|
||||
type: Literal["system"] = "system"
|
||||
"""The type of the message (used for serialization). Defaults to "system"."""
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self, content: Union[str, list[Union[str, dict]]], **kwargs: Any
|
||||
) -> None:
|
||||
"""Pass in content as positional arg.
|
||||
self,
|
||||
content: Union[str, list[Union[str, dict]]],
|
||||
**kwargs: Any,
|
||||
) -> None: ...
|
||||
|
||||
Args:
|
||||
content: The string contents of the message.
|
||||
kwargs: Additional fields to pass to the message.
|
||||
"""
|
||||
super().__init__(content=content, **kwargs)
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
content: Optional[Union[str, list[Union[str, dict]]]] = None,
|
||||
content_blocks: Optional[list[types.ContentBlock]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None: ...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
content: Optional[Union[str, list[Union[str, dict]]]] = None,
|
||||
content_blocks: Optional[list[types.ContentBlock]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Specify ``content`` as positional arg or ``content_blocks`` for typing."""
|
||||
if content_blocks is not None:
|
||||
super().__init__(
|
||||
content=cast("Union[str, list[Union[str, dict]]]", content_blocks),
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
super().__init__(content=content, **kwargs)
|
||||
|
||||
|
||||
class SystemMessageChunk(SystemMessage, BaseMessageChunk):
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
"""Messages for tools."""
|
||||
|
||||
import json
|
||||
from typing import Any, Literal, Optional, Union
|
||||
from typing import Any, Literal, Optional, Union, cast, overload
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import Field, model_validator
|
||||
from typing_extensions import NotRequired, TypedDict, override
|
||||
|
||||
from langchain_core.messages import content as types
|
||||
from langchain_core.messages.base import BaseMessage, BaseMessageChunk, merge_content
|
||||
from langchain_core.messages.content import InvalidToolCall as InvalidToolCall
|
||||
from langchain_core.messages.content import ToolCall as ToolCall
|
||||
from langchain_core.utils._merge import merge_dicts, merge_obj
|
||||
|
||||
|
||||
@@ -133,16 +136,35 @@ class ToolMessage(BaseMessage, ToolOutputMixin):
|
||||
values["tool_call_id"] = str(tool_call_id)
|
||||
return values
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self, content: Union[str, list[Union[str, dict]]], **kwargs: Any
|
||||
) -> None:
|
||||
"""Create a ToolMessage.
|
||||
self,
|
||||
content: Union[str, list[Union[str, dict]]],
|
||||
**kwargs: Any,
|
||||
) -> None: ...
|
||||
|
||||
Args:
|
||||
content: The string contents of the message.
|
||||
**kwargs: Additional fields.
|
||||
"""
|
||||
super().__init__(content=content, **kwargs)
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
content: Optional[Union[str, list[Union[str, dict]]]] = None,
|
||||
content_blocks: Optional[list[types.ContentBlock]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None: ...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
content: Optional[Union[str, list[Union[str, dict]]]] = None,
|
||||
content_blocks: Optional[list[types.ContentBlock]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Specify ``content`` as positional arg or ``content_blocks`` for typing."""
|
||||
if content_blocks is not None:
|
||||
super().__init__(
|
||||
content=cast("Union[str, list[Union[str, dict]]]", content_blocks),
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
super().__init__(content=content, **kwargs)
|
||||
|
||||
|
||||
class ToolMessageChunk(ToolMessage, BaseMessageChunk):
|
||||
@@ -177,37 +199,6 @@ class ToolMessageChunk(ToolMessage, BaseMessageChunk):
|
||||
return super().__add__(other)
|
||||
|
||||
|
||||
class ToolCall(TypedDict):
|
||||
"""Represents a request to call a tool.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
{
|
||||
"name": "foo",
|
||||
"args": {"a": 1},
|
||||
"id": "123"
|
||||
}
|
||||
|
||||
This represents a request to call the tool named "foo" with arguments {"a": 1}
|
||||
and an identifier of "123".
|
||||
|
||||
"""
|
||||
|
||||
name: str
|
||||
"""The name of the tool to be called."""
|
||||
args: dict[str, Any]
|
||||
"""The arguments to the tool call."""
|
||||
id: Optional[str]
|
||||
"""An identifier associated with the tool call.
|
||||
|
||||
An identifier is needed to associate a tool call request with a tool
|
||||
call result in events when multiple concurrent tool calls are made.
|
||||
"""
|
||||
type: NotRequired[Literal["tool_call"]]
|
||||
|
||||
|
||||
def tool_call(
|
||||
*,
|
||||
name: str,
|
||||
@@ -276,24 +267,6 @@ def tool_call_chunk(
|
||||
)
|
||||
|
||||
|
||||
class InvalidToolCall(TypedDict):
|
||||
"""Allowance for errors made by LLM.
|
||||
|
||||
Here we add an `error` key to surface errors made during generation
|
||||
(e.g., invalid JSON arguments.)
|
||||
"""
|
||||
|
||||
name: Optional[str]
|
||||
"""The name of the tool to be called."""
|
||||
args: Optional[str]
|
||||
"""The arguments to the tool call."""
|
||||
id: Optional[str]
|
||||
"""An identifier associated with the tool call."""
|
||||
error: Optional[str]
|
||||
"""An error message associated with the tool call."""
|
||||
type: NotRequired[Literal["invalid_tool_call"]]
|
||||
|
||||
|
||||
def invalid_tool_call(
|
||||
*,
|
||||
name: Optional[str] = None,
|
||||
|
||||
@@ -31,10 +31,13 @@ from typing import (
|
||||
from pydantic import Discriminator, Field, Tag
|
||||
|
||||
from langchain_core.exceptions import ErrorCode, create_message
|
||||
from langchain_core.messages import convert_to_openai_data_block, is_data_content_block
|
||||
from langchain_core.messages.ai import AIMessage, AIMessageChunk
|
||||
from langchain_core.messages.base import BaseMessage, BaseMessageChunk
|
||||
from langchain_core.messages.chat import ChatMessage, ChatMessageChunk
|
||||
from langchain_core.messages.content import (
|
||||
convert_to_openai_data_block,
|
||||
is_data_content_block,
|
||||
)
|
||||
from langchain_core.messages.function import FunctionMessage, FunctionMessageChunk
|
||||
from langchain_core.messages.human import HumanMessage, HumanMessageChunk
|
||||
from langchain_core.messages.modifier import RemoveMessage
|
||||
|
||||
@@ -123,7 +123,7 @@ class ImagePromptValue(PromptValue):
|
||||
|
||||
def to_string(self) -> str:
|
||||
"""Return prompt (image URL) as string."""
|
||||
return self.image_url["url"]
|
||||
return self.image_url.get("url", "")
|
||||
|
||||
def to_messages(self) -> list[BaseMessage]:
|
||||
"""Return prompt (image URL) as messages."""
|
||||
|
||||
@@ -2399,7 +2399,7 @@ class Runnable(ABC, Generic[Input, Output]):
|
||||
description: The description of the tool. Defaults to None.
|
||||
arg_types: A dictionary of argument names to types. Defaults to None.
|
||||
message_version: Version of ``ToolMessage`` to return given
|
||||
:class:`~langchain_core.messages.content_blocks.ToolCall` input.
|
||||
:class:`~langchain_core.messages.content.ToolCall` input.
|
||||
|
||||
Returns:
|
||||
A ``BaseTool`` instance.
|
||||
|
||||
@@ -44,6 +44,17 @@ def merge_dicts(left: dict[str, Any], *others: dict[str, Any]) -> dict[str, Any]
|
||||
)
|
||||
raise TypeError(msg)
|
||||
elif isinstance(merged[right_k], str):
|
||||
# Special handling for output_version - it should be consistent
|
||||
if right_k == "output_version":
|
||||
if merged[right_k] == right_v:
|
||||
continue
|
||||
msg = (
|
||||
"Unable to merge. Two different values seen for "
|
||||
f"'output_version': {merged[right_k]} and {right_v}. "
|
||||
"'output_version' should have the same value across "
|
||||
"all chunks in a generation."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
# TODO: Add below special handling for 'type' key in 0.3 and remove
|
||||
# merge_lists 'type' logic.
|
||||
#
|
||||
@@ -57,6 +68,11 @@ def merge_dicts(left: dict[str, Any], *others: dict[str, Any]) -> dict[str, Any]
|
||||
# "should either occur once or have the same value across "
|
||||
# "all dicts."
|
||||
# )
|
||||
if (right_k == "index" and merged[right_k].startswith("lc_")) or (
|
||||
right_k in ("id", "output_version", "model_provider")
|
||||
and merged[right_k] == right_v
|
||||
):
|
||||
continue
|
||||
merged[right_k] += right_v
|
||||
elif isinstance(merged[right_k], dict):
|
||||
merged[right_k] = merge_dicts(merged[right_k], right_v)
|
||||
@@ -93,7 +109,16 @@ def merge_lists(left: Optional[list], *others: Optional[list]) -> Optional[list]
|
||||
merged = other.copy()
|
||||
else:
|
||||
for e in other:
|
||||
if isinstance(e, dict) and "index" in e and isinstance(e["index"], int):
|
||||
if (
|
||||
isinstance(e, dict)
|
||||
and "index" in e
|
||||
and (
|
||||
isinstance(e["index"], int)
|
||||
or (
|
||||
isinstance(e["index"], str) and e["index"].startswith("lc_")
|
||||
)
|
||||
)
|
||||
):
|
||||
to_merge = [
|
||||
i
|
||||
for i, e_left in enumerate(merged)
|
||||
|
||||
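# A hedged sketch of the merge behaviour added above: a consistent
# 'output_version' (or an 'lc_'-prefixed string index) is kept as-is when
# chunks are merged, while a conflicting 'output_version' raises ValueError.
from langchain_core.utils._merge import merge_dicts

merged = merge_dicts(
    {"output_version": "v0", "text": "Hel"},
    {"output_version": "v0", "text": "lo"},
)
# merged == {"output_version": "v0", "text": "Hello"}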
@@ -9,6 +9,7 @@ import warnings
|
||||
from collections.abc import Iterator, Sequence
|
||||
from importlib.metadata import version
|
||||
from typing import Any, Callable, Optional, Union, overload
|
||||
from uuid import uuid4
|
||||
|
||||
from packaging.version import parse
|
||||
from pydantic import SecretStr
|
||||
@@ -466,3 +467,31 @@ def secret_from_env(
|
||||
raise ValueError(msg)
|
||||
|
||||
return get_secret_from_env
|
||||
|
||||
|
||||
LC_AUTO_PREFIX = "lc_"
|
||||
"""LangChain auto-generated ID prefix for messages and content blocks."""
|
||||
|
||||
LC_ID_PREFIX = "lc_run-"
|
||||
"""Internal tracing/callback system identifier.
|
||||
|
||||
Used for:
|
||||
- Tracing. Every LangChain operation (LLM call, chain execution, tool use, etc.)
|
||||
gets a unique run_id (UUID)
|
||||
- Enables tracking parent-child relationships between operations
|
||||
"""
|
||||
|
||||
|
||||
def ensure_id(id_val: Optional[str]) -> str:
|
||||
"""Ensure the ID is a valid string, generating a new UUID if not provided.
|
||||
|
||||
Auto-generated UUIDs are prefixed by ``'lc_'`` to indicate they are
|
||||
LangChain-generated IDs.
|
||||
|
||||
Args:
|
||||
id_val: Optional string ID value to validate.
|
||||
|
||||
Returns:
|
||||
A string ID, either the validated provided value or a newly generated UUID4.
|
||||
"""
|
||||
return id_val or str(f"{LC_AUTO_PREFIX}{uuid4()}")
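# A hedged usage sketch for the helpers added above, assuming they land in
# langchain_core.utils.utils alongside secret_from_env:
from langchain_core.utils.utils import LC_AUTO_PREFIX, ensure_id

assert ensure_id("existing-id") == "existing-id"  # provided IDs pass through
generated = ensure_id(None)  # e.g. "lc_0b8f..." (new UUID4 with the prefix)
assert generated.startswith(LC_AUTO_PREFIX)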
|
||||
|
||||
@@ -77,8 +77,12 @@ async def test_generic_fake_chat_model_stream() -> None:
|
||||
model = GenericFakeChatModel(messages=cycle([message]))
|
||||
chunks = [chunk async for chunk in model.astream("meow")]
|
||||
assert chunks == [
|
||||
_any_id_ai_message_chunk(content="", additional_kwargs={"foo": 42}),
|
||||
_any_id_ai_message_chunk(content="", additional_kwargs={"bar": 24}),
|
||||
_any_id_ai_message_chunk(
|
||||
content="", additional_kwargs={"foo": 42, "output_version": "v0"}
|
||||
),
|
||||
_any_id_ai_message_chunk(
|
||||
content="", additional_kwargs={"bar": 24, "output_version": "v0"}
|
||||
),
|
||||
]
|
||||
assert len({chunk.id for chunk in chunks}) == 1
|
||||
|
||||
@@ -97,21 +101,31 @@ async def test_generic_fake_chat_model_stream() -> None:
|
||||
|
||||
assert chunks == [
|
||||
_any_id_ai_message_chunk(
|
||||
content="", additional_kwargs={"function_call": {"name": "move_file"}}
|
||||
content="",
|
||||
additional_kwargs={
|
||||
"function_call": {"name": "move_file"},
|
||||
"output_version": "v0",
|
||||
},
|
||||
),
|
||||
_any_id_ai_message_chunk(
|
||||
content="",
|
||||
additional_kwargs={
|
||||
"function_call": {"arguments": '{\n "source_path": "foo"'},
|
||||
"output_version": "v0",
|
||||
},
|
||||
),
|
||||
_any_id_ai_message_chunk(
|
||||
content="", additional_kwargs={"function_call": {"arguments": ","}}
|
||||
content="",
|
||||
additional_kwargs={
|
||||
"function_call": {"arguments": ","},
|
||||
"output_version": "v0",
|
||||
},
|
||||
),
|
||||
_any_id_ai_message_chunk(
|
||||
content="",
|
||||
additional_kwargs={
|
||||
"function_call": {"arguments": '\n "destination_path": "bar"\n}'},
|
||||
"output_version": "v0",
|
||||
},
|
||||
),
|
||||
]
|
||||
@@ -131,7 +145,8 @@ async def test_generic_fake_chat_model_stream() -> None:
|
||||
"name": "move_file",
|
||||
"arguments": '{\n "source_path": "foo",\n "'
|
||||
'destination_path": "bar"\n}',
|
||||
}
|
||||
},
|
||||
"output_version": "v0",
|
||||
},
|
||||
id=chunks[0].id,
|
||||
)
|
||||
|
||||
@@ -8,6 +8,7 @@ import pytest
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.callbacks.manager import AsyncCallbackManagerForLLMRun
|
||||
from langchain_core.language_models import (
|
||||
BaseChatModel,
|
||||
FakeListChatModel,
|
||||
@@ -239,7 +240,9 @@ async def test_astream_implementation_uses_astream() -> None:
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None, # type: ignore[override]
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
*,
|
||||
output_version: Optional[str] = "v0",
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
"""Stream the output of the model."""
|
||||
|
||||
@@ -301,8 +301,9 @@ def test_llm_representation_for_serializable() -> None:
|
||||
assert chat._get_llm_string() == (
|
||||
'{"id": ["tests", "unit_tests", "language_models", "chat_models", '
|
||||
'"test_cache", "CustomChat"], "kwargs": {"messages": {"id": '
|
||||
'["builtins", "list_iterator"], "lc": 1, "type": "not_implemented"}}, "lc": '
|
||||
'1, "name": "CustomChat", "type": "constructor"}---[(\'stop\', None)]'
|
||||
'["builtins", "list_iterator"], "lc": 1, "type": "not_implemented"}, '
|
||||
'"output_version": "v0"}, "lc": 1, "name": "CustomChat", "type": '
|
||||
"\"constructor\"}---[('stop', None)]"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,570 @@
|
||||
"""Test output_version functionality in BaseChatModel."""
|
||||
|
||||
from collections.abc import AsyncIterator, Iterator
|
||||
from typing import Any, Optional, Union
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ConfigDict
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage, HumanMessage
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
|
||||
|
||||
class MockChatModel(BaseChatModel):
|
||||
"""Mock chat model to test output_version functionality."""
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
@override
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
*,
|
||||
output_version: str = "v0",
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
"""Store the output_version that was passed."""
|
||||
self.last_output_version = output_version
|
||||
message = AIMessage(content="test response")
|
||||
generation = ChatGeneration(message=message)
|
||||
return ChatResult(generations=[generation])
|
||||
|
||||
@override
|
||||
def _stream(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
*,
|
||||
output_version: str = "v0",
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
"""Store the output_version that was passed."""
|
||||
self.last_output_version = output_version
|
||||
yield ChatGenerationChunk(message=AIMessageChunk(content="test"))
|
||||
yield ChatGenerationChunk(message=AIMessageChunk(content=" stream"))
|
||||
|
||||
@override
|
||||
async def _astream(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None, # type: ignore[override]
|
||||
*,
|
||||
output_version: str = "v0",
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
"""Store the output_version that was passed."""
|
||||
self.last_output_version = output_version
|
||||
yield ChatGenerationChunk(message=AIMessageChunk(content="async"))
|
||||
yield ChatGenerationChunk(message=AIMessageChunk(content=" stream"))
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "mock-chat-model"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def messages() -> list[BaseMessage]:
|
||||
return [HumanMessage("Hello")]
|
||||
|
||||
|
||||
def test_invoke_uses_default_output_version(messages: list[BaseMessage]) -> None:
|
||||
"""Test `invoke()` uses the model's default `output_version` when not specified."""
|
||||
model = MockChatModel(output_version="v1")
|
||||
model.invoke(messages)
|
||||
assert model.last_output_version == "v1"
|
||||
|
||||
|
||||
def test_invoke_uses_provided_output_version(messages: list[BaseMessage]) -> None:
|
||||
"""Test that `invoke()` uses the provided `output_version` parameter."""
|
||||
model = MockChatModel(output_version="v0")
|
||||
model.invoke(messages, output_version="v1")
|
||||
assert model.last_output_version == "v1"
|
||||
|
||||
|
||||
def test_invoke_output_version_none_uses_default(messages: list[BaseMessage]) -> None:
|
||||
"""Test that passing `output_version=None` uses the model's default."""
|
||||
model = MockChatModel(output_version="v1")
|
||||
model.invoke(messages, output_version=None)
|
||||
assert model.last_output_version == "v1"
|
||||
|
||||
|
||||
async def test_ainvoke_uses_default_output_version(messages: list[BaseMessage]) -> None:
|
||||
"""Test `ainvoke()` uses the model's default `output_version` when not specified."""
|
||||
model = MockChatModel(output_version="v1")
|
||||
await model.ainvoke(messages)
|
||||
assert model.last_output_version == "v1"
|
||||
|
||||
|
||||
async def test_ainvoke_uses_provided_output_version(
|
||||
messages: list[BaseMessage],
|
||||
) -> None:
|
||||
"""Test that `ainvoke()` uses the provided `output_version` parameter."""
|
||||
model = MockChatModel(output_version="v0")
|
||||
await model.ainvoke(messages, output_version="v1")
|
||||
assert model.last_output_version == "v1"
|
||||
|
||||
|
||||
async def test_ainvoke_output_version_none_uses_default(
|
||||
messages: list[BaseMessage],
|
||||
) -> None:
|
||||
"""Test that passing `output_version=None` uses the model's default."""
|
||||
model = MockChatModel(output_version="v1")
|
||||
await model.ainvoke(messages, output_version=None)
|
||||
assert model.last_output_version == "v1"
|
||||
|
||||
|
||||
def test_stream_uses_default_output_version(messages: list[BaseMessage]) -> None:
|
||||
"""Test `stream()` uses the model's default `output_version` when not specified."""
|
||||
model = MockChatModel(output_version="v1")
|
||||
list(model.stream(messages))
|
||||
assert model.last_output_version == "v1"
|
||||
|
||||
|
||||
def test_stream_uses_provided_output_version(messages: list[BaseMessage]) -> None:
|
||||
"""Test that `stream()` uses the provided `output_version` parameter."""
|
||||
model = MockChatModel(output_version="v1")
|
||||
list(model.stream(messages, output_version="v2"))
|
||||
assert model.last_output_version == "v2"
|
||||
|
||||
|
||||
def test_stream_output_version_none_uses_default(messages: list[BaseMessage]) -> None:
|
||||
"""Test that passing `output_version=None` uses the model's default."""
|
||||
model = MockChatModel(output_version="v1")
|
||||
list(model.stream(messages, output_version=None))
|
||||
assert model.last_output_version == "v1"
|
||||
|
||||
|
||||
async def test_astream_uses_default_output_version(messages: list[BaseMessage]) -> None:
|
||||
"""Test `astream()` uses the model's default `output_version` when not specified."""
|
||||
model = MockChatModel(output_version="v1")
|
||||
async for _ in model.astream(messages):
|
||||
pass
|
||||
assert model.last_output_version == "v1"
|
||||
|
||||
|
||||
async def test_astream_uses_provided_output_version(
|
||||
messages: list[BaseMessage],
|
||||
) -> None:
|
||||
"""Test that `astream()` uses the provided `output_version` parameter."""
|
||||
model = MockChatModel(output_version="v1")
|
||||
async for _ in model.astream(messages, output_version="v0"):
|
||||
pass
|
||||
assert model.last_output_version == "v0"
|
||||
|
||||
|
||||
async def test_astream_output_version_none_uses_default(
|
||||
messages: list[BaseMessage],
|
||||
) -> None:
|
||||
"""Test that passing `output_version=None` uses the model's default."""
|
||||
model = MockChatModel(output_version="v1")
|
||||
async for _ in model.astream(messages, output_version=None):
|
||||
pass
|
||||
assert model.last_output_version == "v1"
|
||||
|
||||
|
||||
def test_stream_fallback_to_invoke_passes_output_version(
|
||||
messages: list[BaseMessage],
|
||||
) -> None:
|
||||
"""Test that `stream()` fallback to invoke passes the `output_version` correctly."""
|
||||
|
||||
class NoStreamModel(BaseChatModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
@override
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
*,
|
||||
output_version: str = "v0",
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
self.last_output_version = output_version
|
||||
message = AIMessage(content="test response")
|
||||
generation = ChatGeneration(message=message)
|
||||
return ChatResult(generations=[generation])
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "no-stream-model"
|
||||
|
||||
model = NoStreamModel(output_version="v1")
|
||||
# stream() should fall back to invoke and pass the output_version through
|
||||
list(model.stream(messages, output_version="v2"))
|
||||
assert model.last_output_version == "v2"
|
||||
|
||||
|
||||
async def test_astream_fallback_to_ainvoke_passes_output_version(
|
||||
messages: list[BaseMessage],
|
||||
) -> None:
|
||||
"""Test `astream()` fallback to ainvoke passes the `output_version` correctly."""
|
||||
|
||||
class NoStreamModel(BaseChatModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
@override
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
*,
|
||||
output_version: str = "v0",
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
self.last_output_version = output_version
|
||||
message = AIMessage(content="test response")
|
||||
generation = ChatGeneration(message=message)
|
||||
return ChatResult(generations=[generation])
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "no-stream-model"
|
||||
|
||||
model = NoStreamModel(output_version="v1")
|
||||
# astream() should fall back to ainvoke and pass the output_version through
|
||||
async for _ in model.astream(messages, output_version="v2"):
|
||||
pass
|
||||
assert model.last_output_version == "v2"
|
||||
|
||||
|
||||
def test_generate_prompt_passes_output_version_to_internal_methods(
|
||||
messages: list[BaseMessage],
|
||||
) -> None:
|
||||
"""Test `generate_prompt()` passes `output_version` to internal `_generate()`."""
|
||||
model = MockChatModel(output_version="v1")
|
||||
|
||||
# Mock the _generate method to verify it receives the output_version
|
||||
with patch.object(model, "_generate", wraps=model._generate) as mock_generate:
|
||||
model.invoke(messages, output_version="v2")
|
||||
mock_generate.assert_called_once()
|
||||
# Verify that _generate was called with output_version="v2"
|
||||
call_kwargs = mock_generate.call_args.kwargs
|
||||
assert call_kwargs.get("output_version") == "v2"
|
||||
|
||||
|
||||
async def test_agenerate_prompt_passes_output_version_to_internal_methods(
|
||||
messages: list[BaseMessage],
|
||||
) -> None:
|
||||
"""Test `agenerate_prompt()` passes output_version to internal `_agenerate()`."""
|
||||
model = MockChatModel(output_version="v1")
|
||||
|
||||
# Mock the _agenerate method to verify it receives the output_version
|
||||
with patch.object(model, "_agenerate", wraps=model._agenerate) as mock_agenerate:
|
||||
await model.ainvoke(messages, output_version="v2")
|
||||
mock_agenerate.assert_called_once()
|
||||
# Verify that _agenerate was called with output_version="v2"
|
||||
call_kwargs = mock_agenerate.call_args.kwargs
|
||||
assert call_kwargs.get("output_version") == "v2"
|
||||
|
||||
|
||||
def test_different_output_versions() -> None:
|
||||
"""Test that different `output_version` values are handled correctly."""
|
||||
messages = [HumanMessage(content="Hello")]
|
||||
model = MockChatModel(output_version="v0")
|
||||
|
||||
# Test with various output version strings
|
||||
test_versions = ["v0", "v1", "v2", "beta", "experimental", "1.0", "2025-01-01"]
|
||||
|
||||
for version in test_versions:
|
||||
model.invoke(messages, output_version=version)
|
||||
assert model.last_output_version == version
|
||||
|
||||
|
||||
def test_output_version_is_keyword_only() -> None:
|
||||
"""Test that `output_version` parameter is keyword-only in public methods."""
|
||||
messages = [HumanMessage(content="Hello")]
|
||||
model = MockChatModel()
|
||||
|
||||
# These should work (keyword argument)
|
||||
model.invoke(messages, output_version="v1")
|
||||
list(model.stream(messages, output_version="v1"))
|
||||
|
||||
# Passing output_version positionally should fail: it is declared after *
# in the signature, making it keyword-only.
|
||||
with pytest.raises(TypeError):
|
||||
model.invoke(messages, None, "v1") # type: ignore[arg-type,misc]
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
list(model.stream(messages, None, "v1")) # type: ignore[arg-type,misc]
|
||||
|
||||
|
||||
async def test_async_output_version_is_keyword_only() -> None:
|
||||
"""Test that `output_version` parameter is keyword-only in async public methods."""
|
||||
messages = [HumanMessage(content="Hello")]
|
||||
model = MockChatModel()
|
||||
|
||||
# These should work (keyword argument)
|
||||
await model.ainvoke(messages, output_version="v1")
|
||||
async for _ in model.astream(messages, output_version="v1"):
|
||||
pass
|
||||
|
||||
# Passing output_version positionally should fail (it is keyword-only)
|
||||
with pytest.raises(TypeError):
|
||||
await model.ainvoke(messages, None, "v1") # type: ignore[arg-type,misc]
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
async for _ in model.astream(messages, None, "v1"): # type: ignore[arg-type,misc]
|
||||
pass
|
||||
|
||||
|
||||
def test_output_version_inheritance() -> None:
|
||||
"""Test that subclasses properly inherit `output_version` functionality."""
|
||||
|
||||
class CustomChatModel(BaseChatModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
super().__init__(**kwargs)
|
||||
self.received_versions: list[str] = []
|
||||
|
||||
@override
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
*,
|
||||
output_version: str = "v0",
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
self.received_versions.append(output_version)
|
||||
message = AIMessage(content="response")
|
||||
generation = ChatGeneration(message=message)
|
||||
return ChatResult(generations=[generation])
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "custom-model"
|
||||
|
||||
messages = [HumanMessage(content="Hello")]
|
||||
model = CustomChatModel(output_version="default_v1")
|
||||
|
||||
# Test that default is used
|
||||
model.invoke(messages)
|
||||
assert model.received_versions[-1] == "default_v1"
|
||||
|
||||
# Test that override is used
|
||||
model.invoke(messages, output_version="override_v2")
|
||||
assert model.received_versions[-1] == "override_v2"
|
||||
|
||||
|
||||
def test_internal_output_version_parameter_in_signature() -> None:
|
||||
"""Test that internal methods have `output_version` in their signatures."""
|
||||
import inspect
|
||||
|
||||
model = MockChatModel()
|
||||
|
||||
# Check that the internal methods have output_version parameters
|
||||
generate_sig = inspect.signature(model._generate)
|
||||
assert "output_version" in generate_sig.parameters
|
||||
assert generate_sig.parameters["output_version"].default == "v0"
|
||||
|
||||
agenerate_sig = inspect.signature(model._agenerate)
|
||||
assert "output_version" in agenerate_sig.parameters
|
||||
assert agenerate_sig.parameters["output_version"].default == "v0"
|
||||
|
||||
stream_sig = inspect.signature(model._stream)
|
||||
assert "output_version" in stream_sig.parameters
|
||||
assert stream_sig.parameters["output_version"].default == "v0"
|
||||
|
||||
astream_sig = inspect.signature(model._astream)
|
||||
assert "output_version" in astream_sig.parameters
|
||||
assert astream_sig.parameters["output_version"].default == "v0"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("model_version", "expected"),
|
||||
[
|
||||
("v0", "v0"),
|
||||
("v1", "v1"),
|
||||
("responses/v1", "responses/v1"),
|
||||
],
|
||||
)
|
||||
def test_output_version_stored_in_additional_kwargs_invoke(
|
||||
messages: list[BaseMessage], model_version: str, expected: str
|
||||
) -> None:
|
||||
"""Test that output_version is stored in message additional_kwargs for invoke."""
|
||||
model = MockChatModel(output_version=model_version)
|
||||
response = model.invoke(messages)
|
||||
|
||||
assert "output_version" in response.additional_kwargs
|
||||
assert response.additional_kwargs["output_version"] == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("model_version", "override_version", "expected"),
|
||||
[
|
||||
("v0", None, "v0"),
|
||||
("v1", None, "v1"),
|
||||
("v0", "v2", "v2"),
|
||||
("v1", "v0", "v0"),
|
||||
],
|
||||
)
|
||||
async def test_output_version_ainvoke_with_override(
|
||||
messages: list[BaseMessage],
|
||||
model_version: str,
|
||||
override_version: str,
|
||||
expected: str,
|
||||
) -> None:
|
||||
"""Test ainvoke with output_version override."""
|
||||
model = MockChatModel(output_version=model_version)
|
||||
response = await model.ainvoke(messages, output_version=override_version)
|
||||
|
||||
assert "output_version" in response.additional_kwargs
|
||||
assert response.additional_kwargs["output_version"] == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("model_version", "override_version", "expected"),
|
||||
[
|
||||
("v0", None, "v0"),
|
||||
("v1", None, "v1"),
|
||||
("v0", "v2", "v2"),
|
||||
],
|
||||
)
|
||||
def test_output_version_stored_in_stream_chunks(
|
||||
messages: list[BaseMessage],
|
||||
model_version: str,
|
||||
override_version: str,
|
||||
expected: str,
|
||||
) -> None:
|
||||
"""Test that output_version is stored in streaming chunk additional_kwargs."""
|
||||
model = MockChatModel(output_version=model_version)
|
||||
chunks = list(model.stream(messages, output_version=override_version))
|
||||
|
||||
for chunk in chunks:
|
||||
assert "output_version" in chunk.additional_kwargs
|
||||
assert chunk.additional_kwargs["output_version"] == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("model_version", "override_version", "expected"),
|
||||
[
|
||||
("v0", None, "v0"),
|
||||
("v1", None, "v1"),
|
||||
("v0", "v2", "v2"),
|
||||
],
|
||||
)
|
||||
async def test_output_version_stored_in_astream_chunks(
|
||||
messages: list[BaseMessage],
|
||||
model_version: str,
|
||||
override_version: str,
|
||||
expected: str,
|
||||
) -> None:
|
||||
"""Test that output_version is stored in async streaming chunk additional_kwargs."""
|
||||
model = MockChatModel(output_version=model_version)
|
||||
chunks = [
|
||||
chunk
|
||||
async for chunk in model.astream(messages, output_version=override_version)
|
||||
]
|
||||
|
||||
for chunk in chunks:
|
||||
assert "output_version" in chunk.additional_kwargs
|
||||
assert chunk.additional_kwargs["output_version"] == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize("version", ["v0", "v1", "v2", "beta", "responses/v1"])
|
||||
def test_output_version_preserved_through_serialization(
|
||||
messages: list[BaseMessage], version: str
|
||||
) -> None:
|
||||
"""Test that output_version in additional_kwargs persists through serialization."""
|
||||
import json
|
||||
|
||||
model = MockChatModel(output_version="v0")
|
||||
response = model.invoke(messages, output_version=version)
|
||||
assert response.additional_kwargs["output_version"] == version
|
||||
|
||||
# Verify serialization preserves version
|
||||
message_dict = {"additional_kwargs": response.additional_kwargs}
|
||||
serialized = json.dumps(message_dict)
|
||||
deserialized = json.loads(serialized)
|
||||
assert deserialized["additional_kwargs"]["output_version"] == version
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("output_version", "content_type"),
|
||||
[
|
||||
("v0", str),
|
||||
("v1", list),
|
||||
],
|
||||
)
|
||||
def test_output_version_with_different_content_formats(
|
||||
messages: list[BaseMessage], output_version: str, content_type: type
|
||||
) -> None:
|
||||
"""Test output_version storage works with different content formats."""
|
||||
|
||||
class CustomChatModel(BaseChatModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage], # noqa: ARG002
|
||||
stop: Optional[list[str]] = None, # noqa: ARG002
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None, # noqa: ARG002
|
||||
*,
|
||||
output_version: str = "v0",
|
||||
**kwargs: Any, # noqa: ARG002
|
||||
) -> ChatResult:
|
||||
if output_version == "v0":
|
||||
content: Union[str, list[dict[str, Any]]] = "test response"
|
||||
else:
|
||||
content = [{"type": "text", "text": "test response"}]
|
||||
message = AIMessage(content=content) # type: ignore[arg-type]
|
||||
return ChatResult(generations=[ChatGeneration(message=message)])
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "custom-test-model"
|
||||
|
||||
model = CustomChatModel()
|
||||
response = model.invoke(messages, output_version=output_version)
|
||||
|
||||
assert response.additional_kwargs["output_version"] == output_version
|
||||
assert isinstance(response.content, content_type)
|
||||
|
||||
|
||||
def test_output_version_preserves_existing_additional_kwargs(
|
||||
messages: list[BaseMessage],
|
||||
) -> None:
|
||||
"""Test that output_version doesn't overwrite existing additional_kwargs."""
|
||||
|
||||
class ModelWithExistingKwargs(BaseChatModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage], # noqa: ARG002
|
||||
stop: Optional[list[str]] = None, # noqa: ARG002
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None, # noqa: ARG002
|
||||
*,
|
||||
output_version: str = "v0", # noqa: ARG002
|
||||
**kwargs: Any, # noqa: ARG002
|
||||
) -> ChatResult:
|
||||
message = AIMessage(
|
||||
content="test response",
|
||||
additional_kwargs={"model": "test-model", "temperature": 0.7},
|
||||
)
|
||||
return ChatResult(generations=[ChatGeneration(message=message)])
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "kwargs-test-model"
|
||||
|
||||
model = ModelWithExistingKwargs(output_version="v1")
|
||||
response = model.invoke(messages)
|
||||
|
||||
# Verify output_version was added and existing kwargs preserved
|
||||
assert response.additional_kwargs["output_version"] == "v1"
|
||||
assert response.additional_kwargs["model"] == "test-model"
|
||||
assert response.additional_kwargs["temperature"] == 0.7
|
||||
assert len(response.additional_kwargs) == 3
|
||||
@@ -216,7 +216,7 @@ def test_rate_limit_skips_cache() -> None:
|
||||
'[{"lc": 1, "type": "constructor", "id": ["langchain", "schema", '
|
||||
'"messages", '
|
||||
'"HumanMessage"], "kwargs": {"content": "foo", "type": "human"}}]',
|
||||
"[('_type', 'generic-fake-chat-model'), ('stop', None)]",
|
||||
"[('_output_version', 'v0'), ('_type', 'generic-fake-chat-model'), ('stop', None)]", # noqa: E501
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@@ -0,0 +1,231 @@
|
||||
from typing import Optional
|
||||
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk
|
||||
from langchain_core.messages import content as types
|
||||
|
||||
|
||||
def test_convert_to_v1_from_responses() -> None:
|
||||
message = AIMessage(
|
||||
[
|
||||
{"type": "reasoning", "id": "abc123", "summary": []},
|
||||
{
|
||||
"type": "reasoning",
|
||||
"id": "abc234",
|
||||
"summary": [
|
||||
{"type": "summary_text", "text": "foo bar"},
|
||||
{"type": "summary_text", "text": "baz"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"type": "function_call",
|
||||
"call_id": "call_123",
|
||||
"name": "get_weather",
|
||||
"arguments": '{"location": "San Francisco"}',
|
||||
},
|
||||
{
|
||||
"type": "function_call",
|
||||
"call_id": "call_234",
|
||||
"name": "get_weather_2",
|
||||
"arguments": '{"location": "New York"}',
|
||||
"id": "fc_123",
|
||||
},
|
||||
{"type": "text", "text": "Hello "},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "world",
|
||||
"annotations": [
|
||||
{"type": "url_citation", "url": "https://example.com"},
|
||||
{
|
||||
"type": "file_citation",
|
||||
"filename": "my doc",
|
||||
"index": 1,
|
||||
"file_id": "file_123",
|
||||
},
|
||||
{"bar": "baz"},
|
||||
],
|
||||
},
|
||||
{"type": "image_generation_call", "id": "ig_123", "result": "..."},
|
||||
{"type": "something_else", "foo": "bar"},
|
||||
],
|
||||
tool_calls=[
|
||||
{
|
||||
"type": "tool_call",
|
||||
"id": "call_123",
|
||||
"name": "get_weather",
|
||||
"args": {"location": "San Francisco"},
|
||||
},
|
||||
{
|
||||
"type": "tool_call",
|
||||
"id": "call_234",
|
||||
"name": "get_weather_2",
|
||||
"args": {"location": "New York"},
|
||||
},
|
||||
],
|
||||
response_metadata={"model_provider": "openai"},
|
||||
)
|
||||
expected_content: list[types.ContentBlock] = [
|
||||
{"type": "reasoning", "id": "abc123"},
|
||||
{"type": "reasoning", "id": "abc234", "reasoning": "foo bar"},
|
||||
{"type": "reasoning", "id": "abc234", "reasoning": "baz"},
|
||||
{
|
||||
"type": "tool_call",
|
||||
"id": "call_123",
|
||||
"name": "get_weather",
|
||||
"args": {"location": "San Francisco"},
|
||||
},
|
||||
{
|
||||
"type": "tool_call",
|
||||
"id": "call_234",
|
||||
"name": "get_weather_2",
|
||||
"args": {"location": "New York"},
|
||||
"extras": {"item_id": "fc_123"},
|
||||
},
|
||||
{"type": "text", "text": "Hello "},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "world",
|
||||
"annotations": [
|
||||
{"type": "citation", "url": "https://example.com"},
|
||||
{
|
||||
"type": "citation",
|
||||
"title": "my doc",
|
||||
"extras": {"file_id": "file_123", "index": 1},
|
||||
},
|
||||
{"type": "non_standard_annotation", "value": {"bar": "baz"}},
|
||||
],
|
||||
},
|
||||
{"type": "image", "base64": "...", "id": "ig_123"},
|
||||
{
|
||||
"type": "non_standard",
|
||||
"value": {"type": "something_else", "foo": "bar"},
|
||||
},
|
||||
]
|
||||
assert message.content_blocks == expected_content
|
||||
|
||||
# Check no mutation
|
||||
assert message.content != expected_content
|
||||
|
||||
|
||||
def test_convert_to_v1_from_responses_chunk() -> None:
|
||||
chunks = [
|
||||
AIMessageChunk(
|
||||
content=[{"type": "reasoning", "id": "abc123", "summary": [], "index": 0}],
|
||||
response_metadata={"model_provider": "openai"},
|
||||
),
|
||||
AIMessageChunk(
|
||||
content=[
|
||||
{
|
||||
"type": "reasoning",
|
||||
"id": "abc234",
|
||||
"summary": [
|
||||
{"type": "summary_text", "text": "foo ", "index": 0},
|
||||
],
|
||||
"index": 1,
|
||||
}
|
||||
],
|
||||
response_metadata={"model_provider": "openai"},
|
||||
),
|
||||
AIMessageChunk(
|
||||
content=[
|
||||
{
|
||||
"type": "reasoning",
|
||||
"id": "abc234",
|
||||
"summary": [
|
||||
{"type": "summary_text", "text": "bar", "index": 0},
|
||||
],
|
||||
"index": 1,
|
||||
}
|
||||
],
|
||||
response_metadata={"model_provider": "openai"},
|
||||
),
|
||||
AIMessageChunk(
|
||||
content=[
|
||||
{
|
||||
"type": "reasoning",
|
||||
"id": "abc234",
|
||||
"summary": [
|
||||
{"type": "summary_text", "text": "baz", "index": 1},
|
||||
],
|
||||
"index": 1,
|
||||
}
|
||||
],
|
||||
response_metadata={"model_provider": "openai"},
|
||||
),
|
||||
]
|
||||
expected_chunks = [
|
||||
AIMessageChunk(
|
||||
content=[{"type": "reasoning", "id": "abc123", "index": "lc_rs_305f30"}],
|
||||
response_metadata={"model_provider": "openai"},
|
||||
),
|
||||
AIMessageChunk(
|
||||
content=[
|
||||
{
|
||||
"type": "reasoning",
|
||||
"id": "abc234",
|
||||
"reasoning": "foo ",
|
||||
"index": "lc_rs_315f30",
|
||||
}
|
||||
],
|
||||
response_metadata={"model_provider": "openai"},
|
||||
),
|
||||
AIMessageChunk(
|
||||
content=[
|
||||
{
|
||||
"type": "reasoning",
|
||||
"id": "abc234",
|
||||
"reasoning": "bar",
|
||||
"index": "lc_rs_315f30",
|
||||
}
|
||||
],
|
||||
response_metadata={"model_provider": "openai"},
|
||||
),
|
||||
AIMessageChunk(
|
||||
content=[
|
||||
{
|
||||
"type": "reasoning",
|
||||
"id": "abc234",
|
||||
"reasoning": "baz",
|
||||
"index": "lc_rs_315f31",
|
||||
}
|
||||
],
|
||||
response_metadata={"model_provider": "openai"},
|
||||
),
|
||||
]
|
||||
for chunk, expected in zip(chunks, expected_chunks):
|
||||
assert chunk.content_blocks == expected.content_blocks
|
||||
|
||||
full: Optional[AIMessageChunk] = None
|
||||
for chunk in chunks:
|
||||
full = chunk if full is None else full + chunk # type: ignore[assignment]
|
||||
assert isinstance(full, AIMessageChunk)
|
||||
|
||||
expected_content = [
|
||||
{"type": "reasoning", "id": "abc123", "summary": [], "index": 0},
|
||||
{
|
||||
"type": "reasoning",
|
||||
"id": "abc234",
|
||||
"summary": [
|
||||
{"type": "summary_text", "text": "foo bar", "index": 0},
|
||||
{"type": "summary_text", "text": "baz", "index": 1},
|
||||
],
|
||||
"index": 1,
|
||||
},
|
||||
]
|
||||
assert full.content == expected_content
|
||||
|
||||
expected_content_blocks = [
|
||||
{"type": "reasoning", "id": "abc123", "index": "lc_rs_305f30"},
|
||||
{
|
||||
"type": "reasoning",
|
||||
"id": "abc234",
|
||||
"reasoning": "foo bar",
|
||||
"index": "lc_rs_315f30",
|
||||
},
|
||||
{
|
||||
"type": "reasoning",
|
||||
"id": "abc234",
|
||||
"reasoning": "baz",
|
||||
"index": "lc_rs_315f31",
|
||||
},
|
||||
]
|
||||
assert full.content_blocks == expected_content_blocks
|
||||
@@ -1,5 +1,6 @@
|
||||
from langchain_core.load import dumpd, load
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk
|
||||
from langchain_core.messages import content as types
|
||||
from langchain_core.messages.ai import (
|
||||
InputTokenDetails,
|
||||
OutputTokenDetails,
|
||||
@@ -196,3 +197,116 @@ def test_add_ai_message_chunks_usage() -> None:
|
||||
output_token_details=OutputTokenDetails(audio=1, reasoning=2),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def test_content_blocks() -> None:
|
||||
message = AIMessage(
|
||||
"",
|
||||
tool_calls=[
|
||||
{"type": "tool_call", "name": "foo", "args": {"a": "b"}, "id": "abc_123"}
|
||||
],
|
||||
)
|
||||
assert len(message.content_blocks) == 1
|
||||
assert message.content_blocks[0]["type"] == "tool_call"
|
||||
assert message.content_blocks == [
|
||||
{"type": "tool_call", "id": "abc_123", "name": "foo", "args": {"a": "b"}}
|
||||
]
|
||||
assert message.content == ""
|
||||
|
||||
message = AIMessage(
|
||||
"foo",
|
||||
tool_calls=[
|
||||
{"type": "tool_call", "name": "foo", "args": {"a": "b"}, "id": "abc_123"}
|
||||
],
|
||||
)
|
||||
assert len(message.content_blocks) == 2
|
||||
assert message.content_blocks[0]["type"] == "text"
|
||||
assert message.content_blocks[1]["type"] == "tool_call"
|
||||
assert message.content_blocks == [
|
||||
{"type": "text", "text": "foo"},
|
||||
{"type": "tool_call", "id": "abc_123", "name": "foo", "args": {"a": "b"}},
|
||||
]
|
||||
assert message.content == "foo"
|
||||
|
||||
# With standard blocks
|
||||
standard_content: list[types.ContentBlock] = [
|
||||
{"type": "reasoning", "reasoning": "foo"},
|
||||
{"type": "text", "text": "bar"},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "baz",
|
||||
"annotations": [{"type": "citation", "url": "http://example.com"}],
|
||||
},
|
||||
{
|
||||
"type": "image",
|
||||
"url": "http://example.com/image.png",
|
||||
"extras": {"foo": "bar"},
|
||||
},
|
||||
{
|
||||
"type": "non_standard",
|
||||
"value": {"custom_key": "custom_value", "another_key": 123},
|
||||
},
|
||||
{
|
||||
"type": "tool_call",
|
||||
"name": "foo",
|
||||
"args": {"a": "b"},
|
||||
"id": "abc_123",
|
||||
},
|
||||
]
|
||||
missing_tool_call: types.ToolCall = {
|
||||
"type": "tool_call",
|
||||
"name": "bar",
|
||||
"args": {"c": "d"},
|
||||
"id": "abc_234",
|
||||
}
|
||||
message = AIMessage(
|
||||
content_blocks=standard_content,
|
||||
tool_calls=[
|
||||
{"type": "tool_call", "name": "foo", "args": {"a": "b"}, "id": "abc_123"},
|
||||
missing_tool_call,
|
||||
],
|
||||
)
|
||||
assert message.content_blocks == [*standard_content, missing_tool_call]
|
||||
|
||||
# Check we auto-populate tool_calls
|
||||
standard_content = [
|
||||
{"type": "text", "text": "foo"},
|
||||
{
|
||||
"type": "tool_call",
|
||||
"name": "foo",
|
||||
"args": {"a": "b"},
|
||||
"id": "abc_123",
|
||||
},
|
||||
missing_tool_call,
|
||||
]
|
||||
message = AIMessage(content_blocks=standard_content)
|
||||
assert message.tool_calls == [
|
||||
{"type": "tool_call", "name": "foo", "args": {"a": "b"}, "id": "abc_123"},
|
||||
missing_tool_call,
|
||||
]
|
||||
|
||||
# Chunks
|
||||
message = AIMessageChunk(
|
||||
content="",
|
||||
tool_call_chunks=[
|
||||
{
|
||||
"type": "tool_call_chunk",
|
||||
"name": "foo",
|
||||
"args": "",
|
||||
"id": "abc_123",
|
||||
"index": 0,
|
||||
}
|
||||
],
|
||||
)
|
||||
assert len(message.content_blocks) == 1
|
||||
assert message.content_blocks[0]["type"] == "tool_call_chunk"
|
||||
assert message.content_blocks == [
|
||||
{
|
||||
"type": "tool_call_chunk",
|
||||
"name": "foo",
|
||||
"args": "",
|
||||
"id": "abc_123",
|
||||
"index": 0,
|
||||
}
|
||||
]
|
||||
assert message.content == ""
|
||||
|
||||
@@ -5,26 +5,51 @@ EXPECTED_ALL = [
|
||||
"_message_from_dict",
|
||||
"AIMessage",
|
||||
"AIMessageChunk",
|
||||
"Annotation",
|
||||
"AnyMessage",
|
||||
"AudioContentBlock",
|
||||
"BaseMessage",
|
||||
"BaseMessageChunk",
|
||||
"ContentBlock",
|
||||
"ChatMessage",
|
||||
"ChatMessageChunk",
|
||||
"Citation",
|
||||
"CodeInterpreterCall",
|
||||
"CodeInterpreterOutput",
|
||||
"CodeInterpreterResult",
|
||||
"DataContentBlock",
|
||||
"FileContentBlock",
|
||||
"FunctionMessage",
|
||||
"FunctionMessageChunk",
|
||||
"HumanMessage",
|
||||
"HumanMessageChunk",
|
||||
"ImageContentBlock",
|
||||
"InvalidToolCall",
|
||||
"LC_AUTO_PREFIX",
|
||||
"LC_ID_PREFIX",
|
||||
"NonStandardAnnotation",
|
||||
"NonStandardContentBlock",
|
||||
"PlainTextContentBlock",
|
||||
"SystemMessage",
|
||||
"SystemMessageChunk",
|
||||
"TextContentBlock",
|
||||
"ToolCall",
|
||||
"ToolCallChunk",
|
||||
"ToolMessage",
|
||||
"ToolMessageChunk",
|
||||
"VideoContentBlock",
|
||||
"WebSearchCall",
|
||||
"WebSearchResult",
|
||||
"ReasoningContentBlock",
|
||||
"RemoveMessage",
|
||||
"convert_to_messages",
|
||||
"ensure_id",
|
||||
"get_buffer_string",
|
||||
"is_data_content_block",
|
||||
"is_reasoning_block",
|
||||
"is_text_block",
|
||||
"is_tool_call_block",
|
||||
"is_tool_call_chunk",
|
||||
"merge_content",
|
||||
"message_chunk_to_message",
|
||||
"message_to_dict",
|
||||
|
||||
@@ -1215,13 +1215,14 @@ def test_convert_to_openai_messages_developer() -> None:
|
||||
|
||||
|
||||
def test_convert_to_openai_messages_multimodal() -> None:
|
||||
"""v0 and v1 content to OpenAI messages conversion."""
|
||||
messages = [
|
||||
HumanMessage(
|
||||
content=[
|
||||
# Prior v0 blocks
|
||||
{"type": "text", "text": "Text message"},
|
||||
{
|
||||
"type": "image",
|
||||
"source_type": "url",
|
||||
"url": "https://example.com/test.png",
|
||||
},
|
||||
{
|
||||
@@ -1238,6 +1239,7 @@ def test_convert_to_openai_messages_multimodal() -> None:
|
||||
"filename": "test.pdf",
|
||||
},
|
||||
{
|
||||
# OpenAI Chat Completions file format
|
||||
"type": "file",
|
||||
"file": {
|
||||
"filename": "draconomicon.pdf",
|
||||
@@ -1262,22 +1264,47 @@ def test_convert_to_openai_messages_multimodal() -> None:
|
||||
"format": "wav",
|
||||
},
|
||||
},
|
||||
# v1 Additions
|
||||
{
|
||||
"type": "image",
|
||||
"source_type": "url", # backward compatibility v0 block field
|
||||
"url": "https://example.com/test.png",
|
||||
},
|
||||
{
|
||||
"type": "image",
|
||||
"base64": "<base64 string>",
|
||||
"mime_type": "image/png",
|
||||
},
|
||||
{
|
||||
"type": "file",
|
||||
"base64": "<base64 string>",
|
||||
"mime_type": "application/pdf",
|
||||
"filename": "test.pdf", # backward compatibility v0 block field
|
||||
},
|
||||
{
|
||||
"type": "file",
|
||||
"file_id": "file-abc123",
|
||||
},
|
||||
{
|
||||
"type": "audio",
|
||||
"base64": "<base64 string>",
|
||||
"mime_type": "audio/wav",
|
||||
},
|
||||
]
|
||||
)
|
||||
]
|
||||
result = convert_to_openai_messages(messages, text_format="block")
|
||||
assert len(result) == 1
|
||||
message = result[0]
|
||||
assert len(message["content"]) == 8
|
||||
assert len(message["content"]) == 13
|
||||
|
||||
# Test adding filename
|
||||
# Test auto-adding filename
|
||||
messages = [
|
||||
HumanMessage(
|
||||
content=[
|
||||
{
|
||||
"type": "file",
|
||||
"source_type": "base64",
|
||||
"data": "<base64 string>",
|
||||
"base64": "<base64 string>",
|
||||
"mime_type": "application/pdf",
|
||||
},
|
||||
]
|
||||
@@ -1290,6 +1317,7 @@ def test_convert_to_openai_messages_multimodal() -> None:
|
||||
assert len(message["content"]) == 1
|
||||
block = message["content"][0]
|
||||
assert block == {
|
||||
# OpenAI Chat Completions file format
|
||||
"type": "file",
|
||||
"file": {
|
||||
"file_data": "data:application/pdf;base64,<base64 string>",
|
||||
|
||||
@@ -726,7 +726,7 @@
|
||||
'description': '''
|
||||
Allowance for errors made by LLM.
|
||||
|
||||
Here we add an `error` key to surface errors made during generation
|
||||
Here we add an ``error`` key to surface errors made during generation
|
||||
(e.g., invalid JSON arguments.)
|
||||
''',
|
||||
'properties': dict({
|
||||
@@ -752,6 +752,10 @@
|
||||
]),
|
||||
'title': 'Error',
|
||||
}),
|
||||
'extras': dict({
|
||||
'title': 'Extras',
|
||||
'type': 'object',
|
||||
}),
|
||||
'id': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
@@ -763,6 +767,17 @@
|
||||
]),
|
||||
'title': 'Id',
|
||||
}),
|
||||
'index': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'integer',
|
||||
}),
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
]),
|
||||
'title': 'Index',
|
||||
}),
|
||||
'name': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
@@ -781,9 +796,10 @@
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'type',
|
||||
'id',
|
||||
'name',
|
||||
'args',
|
||||
'id',
|
||||
'error',
|
||||
]),
|
||||
'title': 'InvalidToolCall',
|
||||
@@ -998,12 +1014,23 @@
|
||||
|
||||
This represents a request to call the tool named "foo" with arguments {"a": 1}
|
||||
and an identifier of "123".
|
||||
|
||||
.. note::
|
||||
``create_tool_call`` may also be used as a factory to create a
|
||||
``ToolCall``. Benefits include:
|
||||
|
||||
* Automatic ID generation (when not provided)
|
||||
* Required arguments strictly validated at creation time
|
||||
''',
|
||||
'properties': dict({
|
||||
'args': dict({
|
||||
'title': 'Args',
|
||||
'type': 'object',
|
||||
}),
|
||||
'extras': dict({
|
||||
'title': 'Extras',
|
||||
'type': 'object',
|
||||
}),
|
||||
'id': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
@@ -1015,6 +1042,17 @@
|
||||
]),
|
||||
'title': 'Id',
|
||||
}),
|
||||
'index': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'integer',
|
||||
}),
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
]),
|
||||
'title': 'Index',
|
||||
}),
|
||||
'name': dict({
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
@@ -1026,9 +1064,10 @@
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'type',
|
||||
'id',
|
||||
'name',
|
||||
'args',
|
||||
'id',
|
||||
]),
|
||||
'title': 'ToolCall',
|
||||
'type': 'object',
|
||||
@@ -2158,7 +2197,7 @@
|
||||
'description': '''
|
||||
Allowance for errors made by LLM.
|
||||
|
||||
Here we add an `error` key to surface errors made during generation
|
||||
Here we add an ``error`` key to surface errors made during generation
|
||||
(e.g., invalid JSON arguments.)
|
||||
''',
|
||||
'properties': dict({
|
||||
@@ -2184,6 +2223,10 @@
|
||||
]),
|
||||
'title': 'Error',
|
||||
}),
|
||||
'extras': dict({
|
||||
'title': 'Extras',
|
||||
'type': 'object',
|
||||
}),
|
||||
'id': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
@@ -2195,6 +2238,17 @@
|
||||
]),
|
||||
'title': 'Id',
|
||||
}),
|
||||
'index': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'integer',
|
||||
}),
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
]),
|
||||
'title': 'Index',
|
||||
}),
|
||||
'name': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
@@ -2213,9 +2267,10 @@
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'type',
|
||||
'id',
|
||||
'name',
|
||||
'args',
|
||||
'id',
|
||||
'error',
|
||||
]),
|
||||
'title': 'InvalidToolCall',
|
||||
@@ -2430,12 +2485,23 @@
|
||||
|
||||
This represents a request to call the tool named "foo" with arguments {"a": 1}
|
||||
and an identifier of "123".
|
||||
|
||||
.. note::
|
||||
``create_tool_call`` may also be used as a factory to create a
|
||||
``ToolCall``. Benefits include:
|
||||
|
||||
* Automatic ID generation (when not provided)
|
||||
* Required arguments strictly validated at creation time
|
||||
''',
|
||||
'properties': dict({
|
||||
'args': dict({
|
||||
'title': 'Args',
|
||||
'type': 'object',
|
||||
}),
|
||||
'extras': dict({
|
||||
'title': 'Extras',
|
||||
'type': 'object',
|
||||
}),
|
||||
'id': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
@@ -2447,6 +2513,17 @@
|
||||
]),
|
||||
'title': 'Id',
|
||||
}),
|
||||
'index': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'integer',
|
||||
}),
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
]),
|
||||
'title': 'Index',
|
||||
}),
|
||||
'name': dict({
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
@@ -2458,9 +2535,10 @@
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'type',
|
||||
'id',
|
||||
'name',
|
||||
'args',
|
||||
'id',
|
||||
]),
|
||||
'title': 'ToolCall',
|
||||
'type': 'object',
|
||||
|
||||
@@ -1129,7 +1129,7 @@
|
||||
'description': '''
|
||||
Allowance for errors made by LLM.
|
||||
|
||||
Here we add an `error` key to surface errors made during generation
|
||||
Here we add an ``error`` key to surface errors made during generation
|
||||
(e.g., invalid JSON arguments.)
|
||||
''',
|
||||
'properties': dict({
|
||||
@@ -1155,6 +1155,10 @@
|
||||
]),
|
||||
'title': 'Error',
|
||||
}),
|
||||
'extras': dict({
|
||||
'title': 'Extras',
|
||||
'type': 'object',
|
||||
}),
|
||||
'id': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
@@ -1166,6 +1170,17 @@
|
||||
]),
|
||||
'title': 'Id',
|
||||
}),
|
||||
'index': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'integer',
|
||||
}),
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
]),
|
||||
'title': 'Index',
|
||||
}),
|
||||
'name': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
@@ -1184,9 +1199,10 @@
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'type',
|
||||
'id',
|
||||
'name',
|
||||
'args',
|
||||
'id',
|
||||
'error',
|
||||
]),
|
||||
'title': 'InvalidToolCall',
|
||||
@@ -1401,12 +1417,23 @@
|
||||
|
||||
This represents a request to call the tool named "foo" with arguments {"a": 1}
|
||||
and an identifier of "123".
|
||||
|
||||
.. note::
|
||||
``create_tool_call`` may also be used as a factory to create a
|
||||
``ToolCall``. Benefits include:
|
||||
|
||||
* Automatic ID generation (when not provided)
|
||||
* Required arguments strictly validated at creation time
|
||||
''',
|
||||
'properties': dict({
|
||||
'args': dict({
|
||||
'title': 'Args',
|
||||
'type': 'object',
|
||||
}),
|
||||
'extras': dict({
|
||||
'title': 'Extras',
|
||||
'type': 'object',
|
||||
}),
|
||||
'id': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
@@ -1418,6 +1445,17 @@
|
||||
]),
|
||||
'title': 'Id',
|
||||
}),
|
||||
'index': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'integer',
|
||||
}),
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
]),
|
||||
'title': 'Index',
|
||||
}),
|
||||
'name': dict({
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
@@ -1429,9 +1467,10 @@
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'type',
|
||||
'id',
|
||||
'name',
|
||||
'args',
|
||||
'id',
|
||||
]),
|
||||
'title': 'ToolCall',
|
||||
'type': 'object',
|
||||
|
||||
@@ -2674,7 +2674,7 @@
|
||||
'description': '''
|
||||
Allowance for errors made by LLM.
|
||||
|
||||
Here we add an `error` key to surface errors made during generation
|
||||
Here we add an ``error`` key to surface errors made during generation
|
||||
(e.g., invalid JSON arguments.)
|
||||
''',
|
||||
'properties': dict({
|
||||
@@ -2700,6 +2700,10 @@
|
||||
]),
|
||||
'title': 'Error',
|
||||
}),
|
||||
'extras': dict({
|
||||
'title': 'Extras',
|
||||
'type': 'object',
|
||||
}),
|
||||
'id': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
@@ -2711,6 +2715,17 @@
|
||||
]),
|
||||
'title': 'Id',
|
||||
}),
|
||||
'index': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'integer',
|
||||
}),
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
]),
|
||||
'title': 'Index',
|
||||
}),
|
||||
'name': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
@@ -2728,9 +2743,10 @@
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'type',
|
||||
'id',
|
||||
'name',
|
||||
'args',
|
||||
'id',
|
||||
'error',
|
||||
]),
|
||||
'title': 'InvalidToolCall',
|
||||
@@ -2943,12 +2959,23 @@
|
||||
|
||||
This represents a request to call the tool named "foo" with arguments {"a": 1}
|
||||
and an identifier of "123".
|
||||
|
||||
.. note::
|
||||
``create_tool_call`` may also be used as a factory to create a
|
||||
``ToolCall``. Benefits include:
|
||||
|
||||
* Automatic ID generation (when not provided)
|
||||
* Required arguments strictly validated at creation time
|
||||
''',
|
||||
'properties': dict({
|
||||
'args': dict({
|
||||
'title': 'Args',
|
||||
'type': 'object',
|
||||
}),
|
||||
'extras': dict({
|
||||
'title': 'Extras',
|
||||
'type': 'object',
|
||||
}),
|
||||
'id': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
@@ -2960,6 +2987,17 @@
|
||||
]),
|
||||
'title': 'Id',
|
||||
}),
|
||||
'index': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'integer',
|
||||
}),
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
]),
|
||||
'title': 'Index',
|
||||
}),
|
||||
'name': dict({
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
@@ -2970,9 +3008,10 @@
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'type',
|
||||
'id',
|
||||
'name',
|
||||
'args',
|
||||
'id',
|
||||
]),
|
||||
'title': 'ToolCall',
|
||||
'type': 'object',
|
||||
@@ -4150,7 +4189,7 @@
|
||||
'description': '''
|
||||
Allowance for errors made by LLM.
|
||||
|
||||
Here we add an `error` key to surface errors made during generation
|
||||
Here we add an ``error`` key to surface errors made during generation
|
||||
(e.g., invalid JSON arguments.)
|
||||
''',
|
||||
'properties': dict({
|
||||
@@ -4176,6 +4215,10 @@
|
||||
]),
|
||||
'title': 'Error',
|
||||
}),
|
||||
'extras': dict({
|
||||
'title': 'Extras',
|
||||
'type': 'object',
|
||||
}),
|
||||
'id': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
@@ -4187,6 +4230,17 @@
|
||||
]),
|
||||
'title': 'Id',
|
||||
}),
|
||||
'index': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'integer',
|
||||
}),
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
]),
|
||||
'title': 'Index',
|
||||
}),
|
||||
'name': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
@@ -4204,9 +4258,10 @@
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'type',
|
||||
'id',
|
||||
'name',
|
||||
'args',
|
||||
'id',
|
||||
'error',
|
||||
]),
|
||||
'title': 'InvalidToolCall',
|
||||
@@ -4438,12 +4493,23 @@
|
||||
|
||||
This represents a request to call the tool named "foo" with arguments {"a": 1}
|
||||
and an identifier of "123".
|
||||
|
||||
.. note::
|
||||
``create_tool_call`` may also be used as a factory to create a
|
||||
``ToolCall``. Benefits include:
|
||||
|
||||
* Automatic ID generation (when not provided)
|
||||
* Required arguments strictly validated at creation time
|
||||
''',
|
||||
'properties': dict({
|
||||
'args': dict({
|
||||
'title': 'Args',
|
||||
'type': 'object',
|
||||
}),
|
||||
'extras': dict({
|
||||
'title': 'Extras',
|
||||
'type': 'object',
|
||||
}),
|
||||
'id': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
@@ -4455,6 +4521,17 @@
|
||||
]),
|
||||
'title': 'Id',
|
||||
}),
|
||||
'index': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'integer',
|
||||
}),
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
]),
|
||||
'title': 'Index',
|
||||
}),
|
||||
'name': dict({
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
@@ -4465,9 +4542,10 @@
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'type',
|
||||
'id',
|
||||
'name',
|
||||
'args',
|
||||
'id',
|
||||
]),
|
||||
'title': 'ToolCall',
|
||||
'type': 'object',
|
||||
@@ -5657,7 +5735,7 @@
|
||||
'description': '''
|
||||
Allowance for errors made by LLM.
|
||||
|
||||
Here we add an `error` key to surface errors made during generation
|
||||
Here we add an ``error`` key to surface errors made during generation
|
||||
(e.g., invalid JSON arguments.)
|
||||
''',
|
||||
'properties': dict({
|
||||
@@ -5683,6 +5761,10 @@
|
||||
]),
|
||||
'title': 'Error',
|
||||
}),
|
||||
'extras': dict({
|
||||
'title': 'Extras',
|
||||
'type': 'object',
|
||||
}),
|
||||
'id': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
@@ -5694,6 +5776,17 @@
|
||||
]),
|
||||
'title': 'Id',
|
||||
}),
|
||||
'index': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'integer',
|
||||
}),
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
]),
|
||||
'title': 'Index',
|
||||
}),
|
||||
'name': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
@@ -5711,9 +5804,10 @@
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'type',
|
||||
'id',
|
||||
'name',
|
||||
'args',
|
||||
'id',
|
||||
'error',
|
||||
]),
|
||||
'title': 'InvalidToolCall',
|
||||
@@ -5945,12 +6039,23 @@
|
||||
|
||||
This represents a request to call the tool named "foo" with arguments {"a": 1}
|
||||
and an identifier of "123".
|
||||
|
||||
.. note::
|
||||
``create_tool_call`` may also be used as a factory to create a
|
||||
``ToolCall``. Benefits include:
|
||||
|
||||
* Automatic ID generation (when not provided)
|
||||
* Required arguments strictly validated at creation time
|
||||
''',
|
||||
'properties': dict({
|
||||
'args': dict({
|
||||
'title': 'Args',
|
||||
'type': 'object',
|
||||
}),
|
||||
'extras': dict({
|
||||
'title': 'Extras',
|
||||
'type': 'object',
|
||||
}),
|
||||
'id': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
@@ -5962,6 +6067,17 @@
|
||||
]),
|
||||
'title': 'Id',
|
||||
}),
|
||||
'index': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'integer',
|
||||
}),
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
]),
|
||||
'title': 'Index',
|
||||
}),
|
||||
'name': dict({
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
@@ -5972,9 +6088,10 @@
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'type',
|
||||
'id',
|
||||
'name',
|
||||
'args',
|
||||
'id',
|
||||
]),
|
||||
'title': 'ToolCall',
|
||||
'type': 'object',
|
||||
@@ -7039,7 +7156,7 @@
|
||||
'description': '''
|
||||
Allowance for errors made by LLM.
|
||||
|
||||
Here we add an `error` key to surface errors made during generation
|
||||
Here we add an ``error`` key to surface errors made during generation
|
||||
(e.g., invalid JSON arguments.)
|
||||
''',
|
||||
'properties': dict({
|
||||
@@ -7065,6 +7182,10 @@
|
||||
]),
|
||||
'title': 'Error',
|
||||
}),
|
||||
'extras': dict({
|
||||
'title': 'Extras',
|
||||
'type': 'object',
|
||||
}),
|
||||
'id': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
@@ -7076,6 +7197,17 @@
|
||||
]),
|
||||
'title': 'Id',
|
||||
}),
|
||||
'index': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'integer',
|
||||
}),
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
]),
|
||||
'title': 'Index',
|
||||
}),
|
||||
'name': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
@@ -7093,9 +7225,10 @@
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'type',
|
||||
'id',
|
||||
'name',
|
||||
'args',
|
||||
'id',
|
||||
'error',
|
||||
]),
|
||||
'title': 'InvalidToolCall',
|
||||
@@ -7308,12 +7441,23 @@
|
||||
|
||||
This represents a request to call the tool named "foo" with arguments {"a": 1}
|
||||
and an identifier of "123".
|
||||
|
||||
.. note::
|
||||
``create_tool_call`` may also be used as a factory to create a
|
||||
``ToolCall``. Benefits include:
|
||||
|
||||
* Automatic ID generation (when not provided)
|
||||
* Required arguments strictly validated at creation time
|
||||
''',
|
||||
'properties': dict({
|
||||
'args': dict({
|
||||
'title': 'Args',
|
||||
'type': 'object',
|
||||
}),
|
||||
'extras': dict({
|
||||
'title': 'Extras',
|
||||
'type': 'object',
|
||||
}),
|
||||
'id': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
@@ -7325,6 +7469,17 @@
|
||||
]),
|
||||
'title': 'Id',
|
||||
}),
|
||||
'index': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'integer',
|
||||
}),
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
]),
|
||||
'title': 'Index',
|
||||
}),
|
||||
'name': dict({
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
@@ -7335,9 +7490,10 @@
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'type',
|
||||
'id',
|
||||
'name',
|
||||
'args',
|
||||
'id',
|
||||
]),
|
||||
'title': 'ToolCall',
|
||||
'type': 'object',
|
||||
@@ -8557,7 +8713,7 @@
|
||||
'description': '''
|
||||
Allowance for errors made by LLM.
|
||||
|
||||
Here we add an `error` key to surface errors made during generation
|
||||
Here we add an ``error`` key to surface errors made during generation
|
||||
(e.g., invalid JSON arguments.)
|
||||
''',
|
||||
'properties': dict({
|
||||
@@ -8583,6 +8739,10 @@
|
||||
]),
|
||||
'title': 'Error',
|
||||
}),
|
||||
'extras': dict({
|
||||
'title': 'Extras',
|
||||
'type': 'object',
|
||||
}),
|
||||
'id': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
@@ -8594,6 +8754,17 @@
|
||||
]),
|
||||
'title': 'Id',
|
||||
}),
|
||||
'index': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'integer',
|
||||
}),
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
]),
|
||||
'title': 'Index',
|
||||
}),
|
||||
'name': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
@@ -8611,9 +8782,10 @@
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'type',
|
||||
'id',
|
||||
'name',
|
||||
'args',
|
||||
'id',
|
||||
'error',
|
||||
]),
|
||||
'title': 'InvalidToolCall',
|
||||
@@ -8845,12 +9017,23 @@
|
||||
|
||||
This represents a request to call the tool named "foo" with arguments {"a": 1}
|
||||
and an identifier of "123".
|
||||
|
||||
.. note::
|
||||
``create_tool_call`` may also be used as a factory to create a
|
||||
``ToolCall``. Benefits include:
|
||||
|
||||
* Automatic ID generation (when not provided)
|
||||
* Required arguments strictly validated at creation time
|
||||
''',
|
||||
'properties': dict({
|
||||
'args': dict({
|
||||
'title': 'Args',
|
||||
'type': 'object',
|
||||
}),
|
||||
'extras': dict({
|
||||
'title': 'Extras',
|
||||
'type': 'object',
|
||||
}),
|
||||
'id': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
@@ -8862,6 +9045,17 @@
|
||||
]),
|
||||
'title': 'Id',
|
||||
}),
|
||||
'index': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'integer',
|
||||
}),
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
]),
|
||||
'title': 'Index',
|
||||
}),
|
||||
'name': dict({
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
@@ -8872,9 +9066,10 @@
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'type',
|
||||
'id',
|
||||
'name',
|
||||
'args',
|
||||
'id',
|
||||
]),
|
||||
'title': 'ToolCall',
|
||||
'type': 'object',
|
||||
@@ -9984,7 +10179,7 @@
|
||||
'description': '''
|
||||
Allowance for errors made by LLM.
|
||||
|
||||
Here we add an `error` key to surface errors made during generation
|
||||
Here we add an ``error`` key to surface errors made during generation
|
||||
(e.g., invalid JSON arguments.)
|
||||
''',
|
||||
'properties': dict({
|
||||
@@ -10010,6 +10205,10 @@
|
||||
]),
|
||||
'title': 'Error',
|
||||
}),
|
||||
'extras': dict({
|
||||
'title': 'Extras',
|
||||
'type': 'object',
|
||||
}),
|
||||
'id': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
@@ -10021,6 +10220,17 @@
|
||||
]),
|
||||
'title': 'Id',
|
||||
}),
|
||||
'index': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'integer',
|
||||
}),
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
]),
|
||||
'title': 'Index',
|
||||
}),
|
||||
'name': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
@@ -10038,9 +10248,10 @@
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'type',
|
||||
'id',
|
||||
'name',
|
||||
'args',
|
||||
'id',
|
||||
'error',
|
||||
]),
|
||||
'title': 'InvalidToolCall',
|
||||
@@ -10253,12 +10464,23 @@
|
||||
|
||||
This represents a request to call the tool named "foo" with arguments {"a": 1}
|
||||
and an identifier of "123".
|
||||
|
||||
.. note::
|
||||
``create_tool_call`` may also be used as a factory to create a
|
||||
``ToolCall``. Benefits include:
|
||||
|
||||
* Automatic ID generation (when not provided)
|
||||
* Required arguments strictly validated at creation time
|
||||
''',
|
||||
'properties': dict({
|
||||
'args': dict({
|
||||
'title': 'Args',
|
||||
'type': 'object',
|
||||
}),
|
||||
'extras': dict({
|
||||
'title': 'Extras',
|
||||
'type': 'object',
|
||||
}),
|
||||
'id': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
@@ -10270,6 +10492,17 @@
|
||||
]),
|
||||
'title': 'Id',
|
||||
}),
|
||||
'index': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'integer',
|
||||
}),
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
]),
|
||||
'title': 'Index',
|
||||
}),
|
||||
'name': dict({
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
@@ -10280,9 +10513,10 @@
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'type',
|
||||
'id',
|
||||
'name',
|
||||
'args',
|
||||
'id',
|
||||
]),
|
||||
'title': 'ToolCall',
|
||||
'type': 'object',
|
||||
@@ -11410,7 +11644,7 @@
|
||||
'description': '''
|
||||
Allowance for errors made by LLM.
|
||||
|
||||
Here we add an `error` key to surface errors made during generation
|
||||
Here we add an ``error`` key to surface errors made during generation
|
||||
(e.g., invalid JSON arguments.)
|
||||
''',
|
||||
'properties': dict({
|
||||
@@ -11436,6 +11670,10 @@
|
||||
]),
|
||||
'title': 'Error',
|
||||
}),
|
||||
'extras': dict({
|
||||
'title': 'Extras',
|
||||
'type': 'object',
|
||||
}),
|
||||
'id': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
@@ -11447,6 +11685,17 @@
|
||||
]),
|
||||
'title': 'Id',
|
||||
}),
|
||||
'index': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'integer',
|
||||
}),
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
]),
|
||||
'title': 'Index',
|
||||
}),
|
||||
'name': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
@@ -11464,9 +11713,10 @@
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'type',
|
||||
'id',
|
||||
'name',
|
||||
'args',
|
||||
'id',
|
||||
'error',
|
||||
]),
|
||||
'title': 'InvalidToolCall',
|
||||
@@ -11709,12 +11959,23 @@
|
||||
|
||||
This represents a request to call the tool named "foo" with arguments {"a": 1}
|
||||
and an identifier of "123".
|
||||
|
||||
.. note::
|
||||
``create_tool_call`` may also be used as a factory to create a
|
||||
``ToolCall``. Benefits include:
|
||||
|
||||
* Automatic ID generation (when not provided)
|
||||
* Required arguments strictly validated at creation time
|
||||
''',
|
||||
'properties': dict({
|
||||
'args': dict({
|
||||
'title': 'Args',
|
||||
'type': 'object',
|
||||
}),
|
||||
'extras': dict({
|
||||
'title': 'Extras',
|
||||
'type': 'object',
|
||||
}),
|
||||
'id': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
@@ -11726,6 +11987,17 @@
|
||||
]),
|
||||
'title': 'Id',
|
||||
}),
|
||||
'index': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'integer',
|
||||
}),
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
]),
|
||||
'title': 'Index',
|
||||
}),
|
||||
'name': dict({
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
@@ -11736,9 +12008,10 @@
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'type',
|
||||
'id',
|
||||
'name',
|
||||
'args',
|
||||
'id',
|
||||
]),
|
||||
'title': 'ToolCall',
|
||||
'type': 'object',
|
||||
@@ -12878,7 +13151,7 @@
|
||||
'description': '''
|
||||
Allowance for errors made by LLM.
|
||||
|
||||
Here we add an `error` key to surface errors made during generation
|
||||
Here we add an ``error`` key to surface errors made during generation
|
||||
(e.g., invalid JSON arguments.)
|
||||
''',
|
||||
'properties': dict({
|
||||
@@ -12904,6 +13177,10 @@
|
||||
]),
|
||||
'title': 'Error',
|
||||
}),
|
||||
'extras': dict({
|
||||
'title': 'Extras',
|
||||
'type': 'object',
|
||||
}),
|
||||
'id': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
@@ -12915,6 +13192,17 @@
|
||||
]),
|
||||
'title': 'Id',
|
||||
}),
|
||||
'index': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'integer',
|
||||
}),
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
]),
|
||||
'title': 'Index',
|
||||
}),
|
||||
'name': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
@@ -12932,9 +13220,10 @@
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'type',
|
||||
'id',
|
||||
'name',
|
||||
'args',
|
||||
'id',
|
||||
'error',
|
||||
]),
|
||||
'title': 'InvalidToolCall',
|
||||
@@ -13166,12 +13455,23 @@
|
||||
|
||||
This represents a request to call the tool named "foo" with arguments {"a": 1}
|
||||
and an identifier of "123".
|
||||
|
||||
.. note::
|
||||
``create_tool_call`` may also be used as a factory to create a
|
||||
``ToolCall``. Benefits include:
|
||||
|
||||
* Automatic ID generation (when not provided)
|
||||
* Required arguments strictly validated at creation time
|
||||
''',
|
||||
'properties': dict({
|
||||
'args': dict({
|
||||
'title': 'Args',
|
||||
'type': 'object',
|
||||
}),
|
||||
'extras': dict({
|
||||
'title': 'Extras',
|
||||
'type': 'object',
|
||||
}),
|
||||
'id': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
@@ -13183,6 +13483,17 @@
|
||||
]),
|
||||
'title': 'Id',
|
||||
}),
|
||||
'index': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'type': 'integer',
|
||||
}),
|
||||
dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
]),
|
||||
'title': 'Index',
|
||||
}),
|
||||
'name': dict({
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
@@ -13193,9 +13504,10 @@
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'type',
|
||||
'id',
|
||||
'name',
|
||||
'args',
|
||||
'id',
|
||||
]),
|
||||
'title': 'ToolCall',
|
||||
'type': 'object',
|
||||
|
||||
@@ -29,6 +29,11 @@ def _any_id_document(**kwargs: Any) -> Document:

def _any_id_ai_message(**kwargs: Any) -> AIMessage:
"""Create an AI message with an arbitrary id field."""
# Add default output_version if not specified and no additional_kwargs provided
if "additional_kwargs" not in kwargs:
kwargs["additional_kwargs"] = {"output_version": "v0"}
elif "output_version" not in kwargs.get("additional_kwargs", {}):
kwargs["additional_kwargs"]["output_version"] = "v0"
message = AIMessage(**kwargs)
message.id = AnyStr()
return message
@@ -36,6 +41,11 @@ def _any_id_ai_message(**kwargs: Any) -> AIMessage:

def _any_id_ai_message_chunk(**kwargs: Any) -> AIMessageChunk:
"""Create an AI message chunk with an arbitrary id field."""
# Add default output_version if not specified and no additional_kwargs provided
if "additional_kwargs" not in kwargs:
kwargs["additional_kwargs"] = {"output_version": "v0"}
elif "output_version" not in kwargs.get("additional_kwargs", {}):
kwargs["additional_kwargs"]["output_version"] = "v0"
message = AIMessageChunk(**kwargs)
message.id = AnyStr()
return message
@@ -3,6 +3,7 @@ import uuid
|
||||
from typing import Optional, Union
|
||||
|
||||
import pytest
|
||||
from typing_extensions import get_args
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.load import dumpd, load
|
||||
@@ -30,6 +31,7 @@ from langchain_core.messages import (
|
||||
messages_from_dict,
|
||||
messages_to_dict,
|
||||
)
|
||||
from langchain_core.messages.content import KNOWN_BLOCK_TYPES, ContentBlock
|
||||
from langchain_core.messages.tool import invalid_tool_call as create_invalid_tool_call
|
||||
from langchain_core.messages.tool import tool_call as create_tool_call
|
||||
from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
|
||||
@@ -178,21 +180,23 @@ def test_message_chunks() -> None:
|
||||
assert AIMessageChunk(content="") + left == left
|
||||
assert right + AIMessageChunk(content="") == right
|
||||
|
||||
default_id = "lc_run--abc123"
|
||||
meaningful_id = "msg_def456"
|
||||
|
||||
# Test ID order of precedence
|
||||
null_id = AIMessageChunk(content="", id=None)
|
||||
default_id = AIMessageChunk(
|
||||
content="", id="run-abc123"
|
||||
null_id_chunk = AIMessageChunk(content="", id=None)
|
||||
default_id_chunk = AIMessageChunk(
|
||||
content="", id=default_id
|
||||
) # LangChain-assigned run ID
|
||||
meaningful_id = AIMessageChunk(content="", id="msg_def456") # provider-assigned ID
|
||||
provider_chunk = AIMessageChunk(
|
||||
content="", id=meaningful_id
|
||||
) # provided ID (either by user or provider)
|
||||
|
||||
assert (null_id + default_id).id == "run-abc123"
|
||||
assert (default_id + null_id).id == "run-abc123"
|
||||
assert (null_id_chunk + default_id_chunk).id == default_id
|
||||
assert (null_id_chunk + provider_chunk).id == meaningful_id
|
||||
|
||||
assert (null_id + meaningful_id).id == "msg_def456"
|
||||
assert (meaningful_id + null_id).id == "msg_def456"
|
||||
|
||||
assert (default_id + meaningful_id).id == "msg_def456"
|
||||
assert (meaningful_id + default_id).id == "msg_def456"
|
||||
# Provider assigned IDs have highest precedence
|
||||
assert (default_id_chunk + provider_chunk).id == meaningful_id
|
||||
|
||||
|
||||
def test_chat_message_chunks() -> None:
|
||||
@@ -207,7 +211,7 @@ def test_chat_message_chunks() -> None:
|
||||
):
|
||||
ChatMessageChunk(role="User", content="I am") + ChatMessageChunk(
|
||||
role="Assistant", content=" indeed."
|
||||
)
|
||||
) # type: ignore[reportUnusedExpression, unused-ignore]
|
||||
|
||||
assert ChatMessageChunk(role="User", content="I am") + AIMessageChunk(
|
||||
content=" indeed."
|
||||
@@ -316,7 +320,7 @@ def test_function_message_chunks() -> None:
|
||||
):
|
||||
FunctionMessageChunk(name="hello", content="I am") + FunctionMessageChunk(
|
||||
name="bye", content=" indeed."
|
||||
)
|
||||
) # type: ignore[reportUnusedExpression, unused-ignore]
|
||||
|
||||
|
||||
def test_ai_message_chunks() -> None:
|
||||
@@ -332,7 +336,7 @@ def test_ai_message_chunks() -> None:
|
||||
):
|
||||
AIMessageChunk(example=True, content="I am") + AIMessageChunk(
|
||||
example=False, content=" indeed."
|
||||
)
|
||||
) # type: ignore[reportUnusedExpression, unused-ignore]
|
||||
|
||||
|
||||
class TestGetBufferString(unittest.TestCase):
|
||||
@@ -1038,12 +1042,13 @@ def test_tool_message_content() -> None:
|
||||
ToolMessage(["foo"], tool_call_id="1")
|
||||
ToolMessage([{"foo": "bar"}], tool_call_id="1")
|
||||
|
||||
assert ToolMessage(("a", "b", "c"), tool_call_id="1").content == ["a", "b", "c"] # type: ignore[arg-type]
|
||||
assert ToolMessage(5, tool_call_id="1").content == "5" # type: ignore[arg-type]
|
||||
assert ToolMessage(5.1, tool_call_id="1").content == "5.1" # type: ignore[arg-type]
|
||||
assert ToolMessage({"foo": "bar"}, tool_call_id="1").content == "{'foo': 'bar'}" # type: ignore[arg-type]
|
||||
# Ignoring since we're testing that tuples get converted to lists in `coerce_args`
|
||||
assert ToolMessage(("a", "b", "c"), tool_call_id="1").content == ["a", "b", "c"] # type: ignore[call-overload]
|
||||
assert ToolMessage(5, tool_call_id="1").content == "5" # type: ignore[call-overload]
|
||||
assert ToolMessage(5.1, tool_call_id="1").content == "5.1" # type: ignore[call-overload]
|
||||
assert ToolMessage({"foo": "bar"}, tool_call_id="1").content == "{'foo': 'bar'}" # type: ignore[call-overload]
|
||||
assert (
|
||||
ToolMessage(Document("foo"), tool_call_id="1").content == "page_content='foo'" # type: ignore[arg-type]
|
||||
ToolMessage(Document("foo"), tool_call_id="1").content == "page_content='foo'" # type: ignore[call-overload]
|
||||
)
|
||||
|
||||
|
||||
@@ -1113,26 +1118,45 @@ def test_message_text() -> None:
|
||||
|
||||
|
||||
def test_is_data_content_block() -> None:
|
||||
# Test all DataContentBlock types with various data fields
|
||||
|
||||
# Image blocks
|
||||
assert is_data_content_block({"type": "image", "url": "https://..."})
|
||||
assert is_data_content_block(
|
||||
{
|
||||
"type": "image",
|
||||
"source_type": "url",
|
||||
"url": "https://...",
|
||||
}
|
||||
{"type": "image", "base64": "<base64 data>", "mime_type": "image/jpeg"}
|
||||
)
|
||||
|
||||
# Video blocks
|
||||
assert is_data_content_block({"type": "video", "url": "https://video.mp4"})
|
||||
assert is_data_content_block(
|
||||
{
|
||||
"type": "image",
|
||||
"source_type": "base64",
|
||||
"data": "<base64 data>",
|
||||
"mime_type": "image/jpeg",
|
||||
}
|
||||
{"type": "video", "base64": "<base64 video>", "mime_type": "video/mp4"}
|
||||
)
|
||||
assert is_data_content_block({"type": "video", "file_id": "vid_123"})
|
||||
|
||||
# Audio blocks
|
||||
assert is_data_content_block({"type": "audio", "url": "https://audio.mp3"})
|
||||
assert is_data_content_block(
|
||||
{"type": "audio", "base64": "<base64 audio>", "mime_type": "audio/mp3"}
|
||||
)
|
||||
assert is_data_content_block({"type": "audio", "file_id": "aud_123"})
|
||||
|
||||
# Plain text blocks
|
||||
assert is_data_content_block({"type": "text-plain", "text": "document content"})
|
||||
assert is_data_content_block({"type": "text-plain", "url": "https://doc.txt"})
|
||||
assert is_data_content_block({"type": "text-plain", "file_id": "txt_123"})
|
||||
|
||||
# File blocks
|
||||
assert is_data_content_block({"type": "file", "url": "https://file.pdf"})
|
||||
assert is_data_content_block(
|
||||
{"type": "file", "base64": "<base64 file>", "mime_type": "application/pdf"}
|
||||
)
|
||||
assert is_data_content_block({"type": "file", "file_id": "file_123"})
|
||||
|
||||
# Blocks with additional metadata (should still be valid)
|
||||
assert is_data_content_block(
|
||||
{
|
||||
"type": "image",
|
||||
"source_type": "base64",
|
||||
"data": "<base64 data>",
|
||||
"base64": "<base64 data>",
|
||||
"mime_type": "image/jpeg",
|
||||
"cache_control": {"type": "ephemeral"},
|
||||
}
|
||||
@@ -1140,65 +1164,145 @@ def test_is_data_content_block() -> None:
|
||||
assert is_data_content_block(
|
||||
{
|
||||
"type": "image",
|
||||
"source_type": "base64",
|
||||
"data": "<base64 data>",
|
||||
"base64": "<base64 data>",
|
||||
"mime_type": "image/jpeg",
|
||||
"metadata": {"cache_control": {"type": "ephemeral"}},
|
||||
}
|
||||
)
|
||||
|
||||
assert not is_data_content_block(
|
||||
assert is_data_content_block(
|
||||
{
|
||||
"type": "text",
|
||||
"text": "foo",
|
||||
"type": "image",
|
||||
"base64": "<base64 data>",
|
||||
"mime_type": "image/jpeg",
|
||||
"extras": "hi",
|
||||
}
|
||||
)
|
||||
|
||||
# Invalid cases - wrong type
|
||||
assert not is_data_content_block({"type": "text", "text": "foo"})
|
||||
assert not is_data_content_block(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "https://..."},
|
||||
}
|
||||
)
|
||||
assert not is_data_content_block(
|
||||
{
|
||||
"type": "image",
|
||||
"source_type": "base64",
|
||||
}
|
||||
)
|
||||
assert not is_data_content_block(
|
||||
{
|
||||
"type": "image",
|
||||
"source": "<base64 data>",
|
||||
}
|
||||
} # This is OpenAI Chat Completions
|
||||
)
|
||||
assert not is_data_content_block({"type": "tool_call", "name": "func", "args": {}})
|
||||
assert not is_data_content_block({"type": "invalid", "url": "something"})
|
||||
|
||||
# Invalid cases - valid type but no data or `source_type` fields
|
||||
assert not is_data_content_block({"type": "image"})
|
||||
assert not is_data_content_block({"type": "video", "mime_type": "video/mp4"})
|
||||
assert not is_data_content_block({"type": "audio", "extras": {"key": "value"}})
|
||||
|
||||
# Invalid cases - valid type but wrong data field name
|
||||
assert not is_data_content_block({"type": "image", "source": "<base64 data>"})
|
||||
assert not is_data_content_block({"type": "video", "data": "video_data"})
|
||||
|
||||
# Edge cases - empty or missing values
|
||||
assert not is_data_content_block({})
|
||||
assert not is_data_content_block({"url": "https://..."}) # missing type
|
||||
|
||||
|
||||
def test_convert_to_openai_image_block() -> None:
|
||||
input_block = {
|
||||
"type": "image",
|
||||
"source_type": "url",
|
||||
"url": "https://...",
|
||||
"cache_control": {"type": "ephemeral"},
|
||||
}
|
||||
expected = {
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "https://..."},
|
||||
}
|
||||
result = convert_to_openai_image_block(input_block)
|
||||
assert result == expected
|
||||
|
||||
input_block = {
|
||||
"type": "image",
|
||||
"source_type": "base64",
|
||||
"data": "<base64 data>",
|
||||
"mime_type": "image/jpeg",
|
||||
"cache_control": {"type": "ephemeral"},
|
||||
}
|
||||
expected = {
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "data:image/jpeg;base64,<base64 data>",
|
||||
for input_block in [
|
||||
{
|
||||
"type": "image",
|
||||
"url": "https://...",
|
||||
"cache_control": {"type": "ephemeral"},
|
||||
},
|
||||
{
|
||||
"type": "image",
|
||||
"source_type": "url",
|
||||
"url": "https://...",
|
||||
"cache_control": {"type": "ephemeral"},
|
||||
},
|
||||
]:
|
||||
expected = {
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "https://..."},
|
||||
}
|
||||
result = convert_to_openai_image_block(input_block)
|
||||
assert result == expected
|
||||
|
||||
for input_block in [
|
||||
{
|
||||
"type": "image",
|
||||
"base64": "<base64 data>",
|
||||
"mime_type": "image/jpeg",
|
||||
"cache_control": {"type": "ephemeral"},
|
||||
},
|
||||
{
|
||||
"type": "image",
|
||||
"source_type": "base64",
|
||||
"data": "<base64 data>",
|
||||
"mime_type": "image/jpeg",
|
||||
"cache_control": {"type": "ephemeral"},
|
||||
},
|
||||
]:
|
||||
expected = {
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "data:image/jpeg;base64,<base64 data>",
|
||||
},
|
||||
}
|
||||
result = convert_to_openai_image_block(input_block)
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_known_block_types() -> None:
|
||||
expected = {
|
||||
bt
|
||||
for bt in get_args(ContentBlock)
|
||||
for bt in get_args(bt.__annotations__["type"])
|
||||
}
|
||||
result = convert_to_openai_image_block(input_block)
|
||||
assert result == expected
|
||||
# Normalize any Literal[...] types in block types to their string values.
|
||||
# This ensures all entries are plain strings, not Literal objects.
|
||||
expected = {
|
||||
t
|
||||
if isinstance(t, str)
|
||||
else t.__args__[0]
|
||||
if hasattr(t, "__args__") and len(t.__args__) == 1
|
||||
else t
|
||||
for t in expected
|
||||
}
|
||||
assert expected == KNOWN_BLOCK_TYPES
|
||||
|
||||
|
||||
def test_typed_init() -> None:
|
||||
ai_message = AIMessage(content_blocks=[{"type": "text", "text": "Hello"}])
|
||||
assert ai_message.content == [{"type": "text", "text": "Hello"}]
|
||||
assert ai_message.content_blocks == ai_message.content
|
||||
|
||||
human_message = HumanMessage(content_blocks=[{"type": "text", "text": "Hello"}])
|
||||
assert human_message.content == [{"type": "text", "text": "Hello"}]
|
||||
assert human_message.content_blocks == human_message.content
|
||||
|
||||
system_message = SystemMessage(content_blocks=[{"type": "text", "text": "Hello"}])
|
||||
assert system_message.content == [{"type": "text", "text": "Hello"}]
|
||||
assert system_message.content_blocks == system_message.content
|
||||
|
||||
tool_message = ToolMessage(
|
||||
content_blocks=[{"type": "text", "text": "Hello"}],
|
||||
tool_call_id="abc123",
|
||||
)
|
||||
assert tool_message.content == [{"type": "text", "text": "Hello"}]
|
||||
assert tool_message.content_blocks == tool_message.content
|
||||
|
||||
for message_class in [AIMessage, HumanMessage, SystemMessage]:
|
||||
message = message_class("Hello")
|
||||
assert message.content == "Hello"
|
||||
assert message.content_blocks == [{"type": "text", "text": "Hello"}]
|
||||
|
||||
message = message_class(content="Hello")
|
||||
assert message.content == "Hello"
|
||||
assert message.content_blocks == [{"type": "text", "text": "Hello"}]
|
||||
|
||||
# Test we get type errors for malformed blocks (type checker will complain if
|
||||
# below type-ignores are unused).
|
||||
_ = AIMessage(content_blocks=[{"type": "text", "bad": "Hello"}]) # type: ignore[list-item]
|
||||
_ = HumanMessage(content_blocks=[{"type": "text", "bad": "Hello"}]) # type: ignore[list-item]
|
||||
_ = SystemMessage(content_blocks=[{"type": "text", "bad": "Hello"}]) # type: ignore[list-item]
|
||||
_ = ToolMessage(
|
||||
content_blocks=[{"type": "text", "bad": "Hello"}], # type: ignore[list-item]
|
||||
tool_call_id="abc123",
|
||||
)
|
||||
|
||||
@@ -2281,7 +2281,7 @@ def test_tool_injected_tool_call_id() -> None:
|
||||
@tool
|
||||
def foo(x: int, tool_call_id: Annotated[str, InjectedToolCallId]) -> ToolMessage:
|
||||
"""Foo."""
|
||||
return ToolMessage(x, tool_call_id=tool_call_id) # type: ignore[arg-type]
|
||||
return ToolMessage(x, tool_call_id=tool_call_id) # type: ignore[call-overload]
|
||||
|
||||
assert foo.invoke(
|
||||
{
|
||||
@@ -2290,7 +2290,7 @@ def test_tool_injected_tool_call_id() -> None:
|
||||
"name": "foo",
|
||||
"id": "bar",
|
||||
}
|
||||
) == ToolMessage(0, tool_call_id="bar") # type: ignore[arg-type]
|
||||
) == ToolMessage(0, tool_call_id="bar") # type: ignore[call-overload]
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
@@ -2302,7 +2302,7 @@ def test_tool_injected_tool_call_id() -> None:
|
||||
@tool
|
||||
def foo2(x: int, tool_call_id: Annotated[str, InjectedToolCallId()]) -> ToolMessage:
|
||||
"""Foo."""
|
||||
return ToolMessage(x, tool_call_id=tool_call_id) # type: ignore[arg-type]
|
||||
return ToolMessage(x, tool_call_id=tool_call_id) # type: ignore[call-overload]
|
||||
|
||||
assert foo2.invoke(
|
||||
{
|
||||
@@ -2311,14 +2311,14 @@ def test_tool_injected_tool_call_id() -> None:
|
||||
"name": "foo",
|
||||
"id": "bar",
|
||||
}
|
||||
) == ToolMessage(0, tool_call_id="bar") # type: ignore[arg-type]
|
||||
) == ToolMessage(0, tool_call_id="bar") # type: ignore[call-overload]
|
||||
|
||||
|
||||
def test_tool_uninjected_tool_call_id() -> None:
|
||||
@tool
|
||||
def foo(x: int, tool_call_id: str) -> ToolMessage:
|
||||
"""Foo."""
|
||||
return ToolMessage(x, tool_call_id=tool_call_id) # type: ignore[arg-type]
|
||||
return ToolMessage(x, tool_call_id=tool_call_id) # type: ignore[call-overload]
|
||||
|
||||
with pytest.raises(ValueError, match="1 validation error for foo"):
|
||||
foo.invoke({"type": "tool_call", "args": {"x": 0}, "name": "foo", "id": "bar"})
|
||||
@@ -2330,7 +2330,7 @@ def test_tool_uninjected_tool_call_id() -> None:
|
||||
"name": "foo",
|
||||
"id": "bar",
|
||||
}
|
||||
) == ToolMessage(0, tool_call_id="zap") # type: ignore[arg-type]
|
||||
) == ToolMessage(0, tool_call_id="zap") # type: ignore[call-overload]
|
||||
|
||||
|
||||
def test_tool_return_output_mixin() -> None:
|
||||
|
||||
@@ -47,7 +47,12 @@ def parse_ai_message_to_tool_action(
try:
args = json.loads(function["arguments"] or "{}")
tool_calls.append(
ToolCall(name=function_name, args=args, id=tool_call["id"]),
ToolCall(
type="tool_call",
name=function_name,
args=args,
id=tool_call["id"],
),
)
except JSONDecodeError as e:
msg = (
@@ -258,10 +258,11 @@ def test_configurable_with_default() -> None:
|
||||
"name": None,
|
||||
"bound": {
|
||||
"name": None,
|
||||
"output_version": "v0",
|
||||
"disable_streaming": False,
|
||||
"model": "claude-3-7-sonnet-20250219",
|
||||
"mcp_servers": None,
|
||||
"max_tokens": 1024,
|
||||
"max_tokens": 64000,
|
||||
"temperature": None,
|
||||
"thinking": None,
|
||||
"top_k": None,
|
||||
@@ -277,6 +278,7 @@ def test_configurable_with_default() -> None:
|
||||
"model_kwargs": {},
|
||||
"streaming": False,
|
||||
"stream_usage": True,
|
||||
"output_version": "v0",
|
||||
},
|
||||
"kwargs": {
|
||||
"tools": [{"name": "foo", "description": "foo", "input_schema": {}}],
|
||||
|
||||
@@ -258,10 +258,11 @@ def test_configurable_with_default() -> None:
|
||||
"name": None,
|
||||
"bound": {
|
||||
"name": None,
|
||||
"output_version": "v0",
|
||||
"disable_streaming": False,
|
||||
"model": "claude-3-7-sonnet-20250219",
|
||||
"mcp_servers": None,
|
||||
"max_tokens": 1024,
|
||||
"max_tokens": 64000,
|
||||
"temperature": None,
|
||||
"thinking": None,
|
||||
"top_k": None,
|
||||
@@ -277,6 +278,7 @@ def test_configurable_with_default() -> None:
|
||||
"model_kwargs": {},
|
||||
"streaming": False,
|
||||
"stream_usage": True,
|
||||
"output_version": "v0",
|
||||
},
|
||||
"kwargs": {
|
||||
"tools": [{"name": "foo", "description": "foo", "input_schema": {}}],
|
||||
|
||||
@@ -7,7 +7,7 @@ import warnings
from collections.abc import AsyncIterator, Iterator, Mapping, Sequence
from functools import cached_property
from operator import itemgetter
from typing import Any, Callable, Literal, Optional, Union, cast
from typing import Any, Callable, Final, Literal, Optional, Union, cast

import anthropic
from langchain_core._api import beta, deprecated
@@ -61,6 +61,32 @@ _message_type_lookups = {
}


_MODEL_DEFAULT_MAX_OUTPUT_TOKENS: Final[dict[str, int]] = {
"claude-opus-4-1": 32000,
"claude-opus-4": 32000,
"claude-sonnet-4": 64000,
"claude-3-7-sonnet": 64000,
"claude-3-5-sonnet": 8192,
"claude-3-5-haiku": 8192,
"claude-3-haiku": 4096,
}
_FALLBACK_MAX_OUTPUT_TOKENS: Final[int] = 4096


def _default_max_tokens_for(model: str | None) -> int:
"""Return the default max output tokens for an Anthropic model (with fallback).

Max token limits are listed here: https://docs.anthropic.com/en/docs/about-claude/models/overview#model-comparison-table
"""
if not model:
return _FALLBACK_MAX_OUTPUT_TOKENS

parts = model.split("-")
family = "-".join(parts[:-1]) if len(parts) > 1 else model

return _MODEL_DEFAULT_MAX_OUTPUT_TOKENS.get(family, _FALLBACK_MAX_OUTPUT_TOKENS)
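For orientation, a brief sketch of how this family lookup resolves defaults; the model strings below are illustrative, but the results follow directly from the mapping and fallback shown above:

# "claude-3-5-sonnet-latest" -> family "claude-3-5-sonnet" -> 8192
assert _default_max_tokens_for("claude-3-5-sonnet-latest") == 8192
# "claude-3-7-sonnet-20250219" -> family "claude-3-7-sonnet" -> 64000
assert _default_max_tokens_for("claude-3-7-sonnet-20250219") == 64000
# Unknown families and missing model names fall back to 4096
assert _default_max_tokens_for("claude-unknown-model") == _FALLBACK_MAX_OUTPUT_TOKENS
assert _default_max_tokens_for(None) == _FALLBACK_MAX_OUTPUT_TOKENS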

class AnthropicTool(TypedDict):
"""Anthropic tool definition."""

@@ -1229,7 +1255,7 @@ class ChatAnthropic(BaseChatModel):
model: str = Field(alias="model_name")
"""Model name to use."""

max_tokens: int = Field(default=1024, alias="max_tokens_to_sample")
max_tokens: Optional[int] = Field(default=None, alias="max_tokens_to_sample")
"""Denotes the number of tokens to predict per generation."""

temperature: Optional[float] = None
@@ -1367,6 +1393,15 @@ class ChatAnthropic(BaseChatModel):
ls_params["ls_stop"] = ls_stop
return ls_params

@model_validator(mode="before")
@classmethod
def set_default_max_tokens(cls, values: dict[str, Any]) -> Any:
"""Set default max_tokens."""
if values.get("max_tokens") is None:
model = values.get("model") or values.get("model_name")
values["max_tokens"] = _default_max_tokens_for(model)
return values

@model_validator(mode="before")
@classmethod
def build_extra(cls, values: dict) -> Any:
Binary file not shown.
@@ -901,7 +901,10 @@ def test_image_tool_calling() -> None:

@pytest.mark.vcr
def test_web_search() -> None:
llm = ChatAnthropic(model="claude-3-5-sonnet-latest")  # type: ignore[call-arg]
llm = ChatAnthropic(
model="claude-3-5-sonnet-latest",  # type: ignore[call-arg]
max_tokens=1024,
)

tool = {"type": "web_search_20250305", "name": "web_search", "max_uses": 1}
llm_with_tools = llm.bind_tools([tool])

@@ -111,6 +111,45 @@ def test_anthropic_proxy_from_environment() -> None:
assert llm.anthropic_proxy == explicit_proxy


def test_set_default_max_tokens() -> None:
"""Test the set_default_max_tokens function."""
# Test claude-opus-4 models
llm = ChatAnthropic(model="claude-opus-4-20250514", anthropic_api_key="test")
assert llm.max_tokens == 32000

# Test claude-sonnet-4 models
llm = ChatAnthropic(model="claude-sonnet-4-latest", anthropic_api_key="test")
assert llm.max_tokens == 64000

# Test claude-3-7-sonnet models
llm = ChatAnthropic(model="claude-3-7-sonnet-latest", anthropic_api_key="test")
assert llm.max_tokens == 64000

# Test claude-3-5-sonnet models
llm = ChatAnthropic(model="claude-3-5-sonnet-latest", anthropic_api_key="test")
assert llm.max_tokens == 8192

# Test claude-3-5-haiku models
llm = ChatAnthropic(model="claude-3-5-haiku-latest", anthropic_api_key="test")
assert llm.max_tokens == 8192

# Test claude-3-haiku models (should default to 4096)
llm = ChatAnthropic(model="claude-3-haiku-latest", anthropic_api_key="test")
assert llm.max_tokens == 4096

# Test that existing max_tokens values are preserved
llm = ChatAnthropic(
model="claude-3-5-sonnet-latest", max_tokens=2048, anthropic_api_key="test"
)
assert llm.max_tokens == 2048

# Test that explicitly set max_tokens values are preserved
llm = ChatAnthropic(
model="claude-3-5-sonnet-latest", max_tokens=4096, anthropic_api_key="test"
)
assert llm.max_tokens == 4096


@pytest.mark.requires("anthropic")
def test_anthropic_model_name_param() -> None:
llm = ChatAnthropic(model_name="foo")  # type: ignore[call-arg, call-arg]
@@ -1,7 +1,10 @@
"""
This module converts between AIMessage output formats for the Responses API.
This module converts between AIMessage output formats, which are governed by the
``output_version`` attribute on ChatOpenAI. Supported values are ``"v0"`` and
``"responses/v1"``.

ChatOpenAI v0.3 stores reasoning and tool outputs in AIMessage.additional_kwargs:
``"v0"`` corresponds to the format as of ChatOpenAI v0.3. For the Responses API, it
stores reasoning and tool outputs in AIMessage.additional_kwargs:

.. code-block:: python

@@ -28,8 +31,9 @@ ChatOpenAI v0.3 stores reasoning and tool outputs in AIMessage.additional_kwargs
id="msg_123",
)

To retain information about response item sequencing (and to accommodate multiple
reasoning items), ChatOpenAI now stores these items in the content sequence:
``"responses/v1"`` is only applicable to the Responses API. It retains information
about response item sequencing and accommodates multiple reasoning items by
representing these items in the content sequence:

.. code-block:: python

@@ -57,18 +61,20 @@ There are other, small improvements as well-- e.g., we store message IDs on text
content blocks, rather than on the AIMessage.id, which now stores the response ID.

For backwards compatibility, this module provides functions to convert between the
old and new formats. The functions are used internally by ChatOpenAI.

formats. The functions are used internally by ChatOpenAI.
"""  # noqa: E501
|
||||
|
||||
import json
|
||||
from typing import Union
|
||||
from collections.abc import Iterable, Iterator
|
||||
from typing import Any, Literal, Union, cast
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.messages import AIMessage, is_data_content_block
|
||||
from langchain_core.messages import content as types
|
||||
|
||||
_FUNCTION_CALL_IDS_MAP_KEY = "__openai_function_call_ids__"
|
||||
|
||||
|
||||
# v0.3 / Responses
|
||||
def _convert_to_v03_ai_message(
|
||||
message: AIMessage, has_reasoning: bool = False
|
||||
) -> AIMessage:
|
||||
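A minimal usage sketch of the two formats described in the docstring above (model name and keyword defaults are assumed from the tests in this diff, not from released documentation):

    # Hedged sketch: choosing an AIMessage output format on ChatOpenAI.
    from langchain_openai import ChatOpenAI

    llm = ChatOpenAI(model="gpt-4.1", use_responses_api=True, output_version="responses/v1")
    msg = llm.invoke("hi")
    # With "responses/v1", reasoning and tool output items appear as blocks in
    # msg.content; with the default "v0", they are kept in msg.additional_kwargs.
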
@@ -253,3 +259,241 @@ def _convert_from_v03_ai_message(message: AIMessage) -> AIMessage:
|
||||
},
|
||||
deep=False,
|
||||
)
|
||||
|
||||
|
||||
# v1 / Chat Completions
|
||||
def _convert_from_v1_to_chat_completions(message: AIMessage) -> AIMessage:
|
||||
"""Convert a v1 message to the Chat Completions format."""
|
||||
if isinstance(message.content, list):
|
||||
new_content: list = []
|
||||
for block in message.content:
|
||||
if isinstance(block, dict):
|
||||
block_type = block.get("type")
|
||||
if block_type == "text":
|
||||
# Strip annotations
|
||||
new_content.append({"type": "text", "text": block["text"]})
|
||||
elif block_type in ("reasoning", "tool_call"):
|
||||
pass
|
||||
else:
|
||||
new_content.append(block)
|
||||
else:
|
||||
new_content.append(block)
|
||||
return message.model_copy(update={"content": new_content})
|
||||
|
||||
return message
|
||||
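A hedged illustration of the conversion above, using block shapes taken from this diff (an assumed example, not a published API):

    from langchain_core.messages import AIMessage

    msg = AIMessage(content=[
        {"type": "reasoning", "reasoning": "think step by step"},
        {
            "type": "text",
            "text": "hi",
            "annotations": [{"type": "citation", "url": "https://example.com"}],
        },
    ])
    out = _convert_from_v1_to_chat_completions(msg)
    # Reasoning blocks are dropped and annotations are stripped:
    # out.content == [{"type": "text", "text": "hi"}]
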
|
||||
|
||||
# v1 / Responses
|
||||
def _convert_annotation_from_v1(annotation: types.Annotation) -> dict[str, Any]:
|
||||
if annotation["type"] == "citation":
|
||||
new_ann: dict[str, Any] = {}
|
||||
for field in ("end_index", "start_index"):
|
||||
if field in annotation:
|
||||
new_ann[field] = annotation[field]
|
||||
|
||||
if "url" in annotation:
|
||||
# URL citation
|
||||
if "title" in annotation:
|
||||
new_ann["title"] = annotation["title"]
|
||||
new_ann["type"] = "url_citation"
|
||||
new_ann["url"] = annotation["url"]
|
||||
else:
|
||||
# Document citation
|
||||
new_ann["type"] = "file_citation"
|
||||
if "title" in annotation:
|
||||
new_ann["filename"] = annotation["title"]
|
||||
|
||||
if extra_fields := annotation.get("extras"):
|
||||
for field, value in extra_fields.items():
|
||||
new_ann[field] = value
|
||||
|
||||
return new_ann
|
||||
|
||||
elif annotation["type"] == "non_standard_annotation":
|
||||
return annotation["value"]
|
||||
|
||||
else:
|
||||
return dict(annotation)
|
||||
|
||||
|
||||
def _implode_reasoning_blocks(blocks: list[dict[str, Any]]) -> Iterable[dict[str, Any]]:
|
||||
i = 0
|
||||
n = len(blocks)
|
||||
|
||||
while i < n:
|
||||
block = blocks[i]
|
||||
|
||||
# Skip non-reasoning blocks or blocks already in Responses format
|
||||
if block.get("type") != "reasoning" or "summary" in block:
|
||||
yield dict(block)
|
||||
i += 1
|
||||
continue
|
||||
elif "reasoning" not in block and "summary" not in block:
|
||||
# {"type": "reasoning", "id": "rs_..."}
|
||||
oai_format = {**block, "summary": []}
|
||||
if "extras" in oai_format:
|
||||
oai_format.update(oai_format.pop("extras"))
|
||||
oai_format["type"] = oai_format.pop("type", "reasoning")
|
||||
if "encrypted_content" in oai_format:
|
||||
oai_format["encrypted_content"] = oai_format.pop("encrypted_content")
|
||||
yield oai_format
|
||||
i += 1
|
||||
continue
|
||||
else:
|
||||
pass
|
||||
|
||||
summary: list[dict[str, str]] = [
|
||||
{"type": "summary_text", "text": block.get("reasoning", "")}
|
||||
]
|
||||
# 'common' is every field except the exploded 'reasoning'
|
||||
common = {k: v for k, v in block.items() if k != "reasoning"}
|
||||
if "extras" in common:
|
||||
common.update(common.pop("extras"))
|
||||
|
||||
i += 1
|
||||
while i < n:
|
||||
next_ = blocks[i]
|
||||
if next_.get("type") == "reasoning" and "reasoning" in next_:
|
||||
summary.append(
|
||||
{"type": "summary_text", "text": next_.get("reasoning", "")}
|
||||
)
|
||||
i += 1
|
||||
else:
|
||||
break
|
||||
|
||||
merged = dict(common)
|
||||
merged["summary"] = summary
|
||||
merged["type"] = merged.pop("type", "reasoning")
|
||||
yield merged
|
||||
|
||||
|
||||
def _consolidate_calls(
|
||||
items: Iterable[dict[str, Any]],
|
||||
call_name: Literal["web_search_call", "code_interpreter_call"],
|
||||
result_name: Literal["web_search_result", "code_interpreter_result"],
|
||||
) -> Iterator[dict[str, Any]]:
|
||||
"""
|
||||
Generator that walks through *items* and, whenever it meets the pair
|
||||
|
||||
{"type": "web_search_call", "id": X, ...}
|
||||
{"type": "web_search_result", "id": X}
|
||||
|
||||
merges them into
|
||||
|
||||
{"id": X,
|
||||
"action": …,
|
||||
"status": …,
|
||||
"type": "web_search_call"}
|
||||
|
||||
keeping every other element untouched.
|
||||
"""
|
||||
items = iter(items) # make sure we have a true iterator
|
||||
for current in items:
|
||||
# Only a call can start a pair worth collapsing
|
||||
if current.get("type") != call_name:
|
||||
yield current
|
||||
continue
|
||||
|
||||
try:
|
||||
nxt = next(items) # look-ahead one element
|
||||
except StopIteration: # no “result” – just yield the call back
|
||||
yield current
|
||||
break
|
||||
|
||||
# If this really is the matching “result” – collapse
|
||||
if nxt.get("type") == result_name and nxt.get("id") == current.get("id"):
|
||||
if call_name == "web_search_call":
|
||||
collapsed = {"id": current["id"]}
|
||||
if "action" in current:
|
||||
collapsed["action"] = current["action"]
|
||||
collapsed["status"] = current["status"]
|
||||
collapsed["type"] = "web_search_call"
|
||||
|
||||
if call_name == "code_interpreter_call":
|
||||
collapsed = {"id": current["id"]}
|
||||
for key in ("code", "container_id"):
|
||||
if key in current:
|
||||
collapsed[key] = current[key]
|
||||
elif key in current.get("extras", {}):
|
||||
collapsed[key] = current["extras"][key]
|
||||
else:
|
||||
pass
|
||||
|
||||
for key in ("outputs", "status"):
|
||||
if key in nxt:
|
||||
collapsed[key] = nxt[key]
|
||||
elif key in nxt.get("extras", {}):
|
||||
collapsed[key] = nxt["extras"][key]
|
||||
else:
|
||||
pass
|
||||
collapsed["type"] = "code_interpreter_call"
|
||||
|
||||
yield collapsed
|
||||
|
||||
else:
|
||||
# Not a matching pair – emit both, in original order
|
||||
yield current
|
||||
yield nxt
|
||||
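A hedged example of the pairing behaviour documented above (the input shapes are assumed from this diff):

    items = [
        {"type": "web_search_call", "id": "ws_1", "action": {"query": "news"}, "status": "completed"},
        {"type": "web_search_result", "id": "ws_1"},
        {"type": "text", "text": "done"},
    ]
    merged = list(_consolidate_calls(items, "web_search_call", "web_search_result"))
    # merged[0] collapses the call/result pair into a single
    # {"id": "ws_1", "action": {...}, "status": "completed", "type": "web_search_call"};
    # the trailing text block passes through unchanged.
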
|
||||
|
||||
def _convert_from_v1_to_responses(
|
||||
content: list[types.ContentBlock], tool_calls: list[types.ToolCall]
|
||||
) -> list[dict[str, Any]]:
|
||||
new_content: list = []
|
||||
for block in content:
|
||||
if block["type"] == "text" and "annotations" in block:
|
||||
# Need a copy because we’re changing the annotations list
|
||||
new_block = dict(block)
|
||||
new_block["annotations"] = [
|
||||
_convert_annotation_from_v1(a) for a in block["annotations"]
|
||||
]
|
||||
new_content.append(new_block)
|
||||
elif block["type"] == "tool_call":
|
||||
new_block = {"type": "function_call", "call_id": block["id"]}
|
||||
if "extras" in block and "item_id" in block["extras"]:
|
||||
new_block["id"] = block["extras"]["item_id"]
|
||||
if "name" in block:
|
||||
new_block["name"] = block["name"]
|
||||
if "extras" in block and "arguments" in block["extras"]:
|
||||
new_block["arguments"] = block["extras"]["arguments"]
|
||||
if any(key not in block for key in ("name", "arguments")):
|
||||
matching_tool_calls = [
|
||||
call for call in tool_calls if call["id"] == block["id"]
|
||||
]
|
||||
if matching_tool_calls:
|
||||
tool_call = matching_tool_calls[0]
|
||||
if "name" not in block:
|
||||
new_block["name"] = tool_call["name"]
|
||||
if "arguments" not in block:
|
||||
new_block["arguments"] = json.dumps(tool_call["args"])
|
||||
new_content.append(new_block)
|
||||
elif (
|
||||
is_data_content_block(cast(dict, block))
|
||||
and block["type"] == "image"
|
||||
and "base64" in block
|
||||
and isinstance(block.get("id"), str)
|
||||
and block["id"].startswith("ig_")
|
||||
):
|
||||
new_block = {"type": "image_generation_call", "result": block["base64"]}
|
||||
for extra_key in ("id", "status"):
|
||||
if extra_key in block:
|
||||
new_block[extra_key] = block[extra_key] # type: ignore[typeddict-item]
|
||||
elif extra_key in block.get("extras", {}):
|
||||
new_block[extra_key] = block["extras"][extra_key]
|
||||
new_content.append(new_block)
|
||||
elif block["type"] == "non_standard" and "value" in block:
|
||||
new_content.append(block["value"])
|
||||
else:
|
||||
new_content.append(block)
|
||||
|
||||
new_content = list(_implode_reasoning_blocks(new_content))
|
||||
new_content = list(
|
||||
_consolidate_calls(new_content, "web_search_call", "web_search_result")
|
||||
)
|
||||
new_content = list(
|
||||
_consolidate_calls(
|
||||
new_content, "code_interpreter_call", "code_interpreter_result"
|
||||
)
|
||||
)
|
||||
|
||||
return new_content
|
||||
|
||||
@@ -69,6 +69,10 @@ from langchain_core.messages.ai import (
|
||||
OutputTokenDetails,
|
||||
UsageMetadata,
|
||||
)
|
||||
from langchain_core.messages.block_translators.openai import (
|
||||
translate_content,
|
||||
translate_content_chunk,
|
||||
)
|
||||
from langchain_core.messages.tool import tool_call_chunk
|
||||
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
|
||||
from langchain_core.output_parsers.openai_tools import (
|
||||
@@ -108,6 +112,8 @@ from langchain_openai.chat_models._client_utils import (
|
||||
)
|
||||
from langchain_openai.chat_models._compat import (
|
||||
_convert_from_v03_ai_message,
|
||||
_convert_from_v1_to_chat_completions,
|
||||
_convert_from_v1_to_responses,
|
||||
_convert_to_v03_ai_message,
|
||||
)
|
||||
|
||||
@@ -202,7 +208,7 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
||||
return ChatMessage(content=_dict.get("content", ""), role=role, id=id_) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def _format_message_content(content: Any) -> Any:
|
||||
def _format_message_content(content: Any, responses_ai_msg: bool = False) -> Any:
|
||||
"""Format message content."""
|
||||
if content and isinstance(content, list):
|
||||
formatted_content = []
|
||||
@@ -214,7 +220,13 @@ def _format_message_content(content: Any) -> Any:
|
||||
and block["type"] in ("tool_use", "thinking", "reasoning_content")
|
||||
):
|
||||
continue
|
||||
elif isinstance(block, dict) and is_data_content_block(block):
|
||||
elif (
|
||||
isinstance(block, dict)
|
||||
and is_data_content_block(block)
|
||||
# Responses API messages handled separately in _compat (parsed into
|
||||
# image generation calls)
|
||||
and not responses_ai_msg
|
||||
):
|
||||
formatted_content.append(convert_to_openai_data_block(block))
|
||||
# Anthropic image blocks
|
||||
elif (
|
||||
@@ -247,7 +259,9 @@ def _format_message_content(content: Any) -> Any:
|
||||
return formatted_content
|
||||
|
||||
|
||||
def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
def _convert_message_to_dict(
|
||||
message: BaseMessage, responses_ai_msg: bool = False
|
||||
) -> dict:
|
||||
"""Convert a LangChain message to a dictionary.
|
||||
|
||||
Args:
|
||||
@@ -256,7 +270,11 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
Returns:
|
||||
The dictionary.
|
||||
"""
|
||||
message_dict: dict[str, Any] = {"content": _format_message_content(message.content)}
|
||||
message_dict: dict[str, Any] = {
|
||||
"content": _format_message_content(
|
||||
message.content, responses_ai_msg=responses_ai_msg
|
||||
)
|
||||
}
|
||||
if (name := message.name or message.additional_kwargs.get("name")) is not None:
|
||||
message_dict["name"] = name
|
||||
|
||||
@@ -291,15 +309,25 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
if "function_call" in message_dict or "tool_calls" in message_dict:
|
||||
message_dict["content"] = message_dict["content"] or None
|
||||
|
||||
if "audio" in message.additional_kwargs:
|
||||
# openai doesn't support passing the data back - only the id
|
||||
# https://platform.openai.com/docs/guides/audio/multi-turn-conversations
|
||||
audio: Optional[dict[str, Any]] = None
|
||||
for block in message.content:
|
||||
if (
|
||||
isinstance(block, dict)
|
||||
and block.get("type") == "audio"
|
||||
and (id_ := block.get("id"))
|
||||
and not responses_ai_msg
|
||||
):
|
||||
# openai doesn't support passing the data back - only the id
|
||||
# https://platform.openai.com/docs/guides/audio/multi-turn-conversations
|
||||
audio = {"id": id_}
|
||||
if not audio and "audio" in message.additional_kwargs:
|
||||
raw_audio = message.additional_kwargs["audio"]
|
||||
audio = (
|
||||
{"id": message.additional_kwargs["audio"]["id"]}
|
||||
if "id" in raw_audio
|
||||
else raw_audio
|
||||
)
|
||||
if audio:
|
||||
message_dict["audio"] = audio
|
||||
elif isinstance(message, SystemMessage):
|
||||
message_dict["role"] = message.additional_kwargs.get(
|
||||
@@ -681,7 +709,7 @@ class BaseChatOpenAI(BaseChatModel):

         .. versionadded:: 0.3.9
     """

-    output_version: Literal["v0", "responses/v1"] = "v0"
+    output_version: str = "v0"
     """Version of AIMessage output format to use.

     This field is used to roll-out new output formats for chat model AIMessages
@@ -692,8 +720,9 @@ class BaseChatOpenAI(BaseChatModel):

     - ``'v0'``: AIMessage format as of langchain-openai 0.3.x.
     - ``'responses/v1'``: Formats Responses API output
       items into AIMessage content blocks.
+    - ``"v1"``: v1 of LangChain cross-provider standard.

-    Currently only impacts the Responses API. ``output_version='responses/v1'`` is
+    Currently only impacts the Responses API. ``output_version='v1'`` is
     recommended.

     .. versionadded:: 0.3.25

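A short sketch of what the ``"v1"`` setting implies elsewhere in this diff (behaviour assumed from the changes below, not from released docs):

    from langchain_openai import ChatOpenAI

    llm = ChatOpenAI(model="gpt-4.1", use_responses_api=True, output_version="v1")
    response = llm.invoke("hello")
    # Under "v1", content is a list of typed blocks and the version is recorded:
    assert response.response_metadata["output_version"] == "v1"
    for block in response.content_blocks:
        print(block["type"])
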
@@ -896,6 +925,10 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
message=default_chunk_class(content="", usage_metadata=usage_metadata),
|
||||
generation_info=base_generation_info,
|
||||
)
|
||||
if self.output_version == "v1":
|
||||
generation_chunk.message.content = []
|
||||
generation_chunk.message.response_metadata["output_version"] = "v1"
|
||||
|
||||
return generation_chunk
|
||||
|
||||
choice = choices[0]
|
||||
@@ -908,6 +941,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
generation_info = {**base_generation_info} if base_generation_info else {}
|
||||
|
||||
if finish_reason := choice.get("finish_reason"):
|
||||
generation_info["model_provider"] = "openai"
|
||||
generation_info["finish_reason"] = finish_reason
|
||||
if model_name := chunk.get("model"):
|
||||
generation_info["model_name"] = model_name
|
||||
@@ -923,6 +957,13 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
if usage_metadata and isinstance(message_chunk, AIMessageChunk):
|
||||
message_chunk.usage_metadata = usage_metadata
|
||||
|
||||
if self.output_version == "v1":
|
||||
message_chunk.content = cast(
|
||||
"Union[str, list[Union[str, dict]]]",
|
||||
translate_content_chunk(cast(AIMessageChunk, message_chunk)),
|
||||
)
|
||||
message_chunk.response_metadata["output_version"] = "v1"
|
||||
|
||||
generation_chunk = ChatGenerationChunk(
|
||||
message=message_chunk, generation_info=generation_info or None
|
||||
)
|
||||
@@ -1216,7 +1257,12 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
else:
|
||||
payload = _construct_responses_api_payload(messages, payload)
|
||||
else:
|
||||
payload["messages"] = [_convert_message_to_dict(m) for m in messages]
|
||||
payload["messages"] = [
|
||||
_convert_message_to_dict(_convert_from_v1_to_chat_completions(m))
|
||||
if isinstance(m, AIMessage)
|
||||
else _convert_message_to_dict(m)
|
||||
for m in messages
|
||||
]
|
||||
return payload
|
||||
|
||||
def _create_chat_result(
|
||||
@@ -1265,6 +1311,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
generations.append(gen)
|
||||
llm_output = {
|
||||
"token_usage": token_usage,
|
||||
"model_provider": "openai",
|
||||
"model_name": response_dict.get("model", self.model_name),
|
||||
"system_fingerprint": response_dict.get("system_fingerprint", ""),
|
||||
}
|
||||
@@ -1282,6 +1329,13 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
if hasattr(message, "refusal"):
|
||||
generations[0].message.additional_kwargs["refusal"] = message.refusal
|
||||
|
||||
if self.output_version == "v1":
|
||||
generations[0].message.content = cast(
|
||||
Union[str, list[Union[str, dict]]],
|
||||
translate_content(cast(AIMessage, generations[0].message)),
|
||||
)
|
||||
generations[0].message.response_metadata["output_version"] = "v1"
|
||||
|
||||
return ChatResult(generations=generations, llm_output=llm_output)
|
||||
|
||||
async def _astream(
|
||||
@@ -1496,7 +1550,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
|
||||
def get_num_tokens_from_messages(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
messages: Sequence[BaseMessage],
|
||||
tools: Optional[
|
||||
Sequence[Union[dict[str, Any], type, Callable, BaseTool]]
|
||||
] = None,
|
||||
@@ -3384,6 +3438,20 @@ def _oai_structured_outputs_parser(
|
||||
return parsed
|
||||
elif ai_msg.additional_kwargs.get("refusal"):
|
||||
raise OpenAIRefusalError(ai_msg.additional_kwargs["refusal"])
|
||||
elif any(
|
||||
isinstance(block, dict)
|
||||
and block.get("type") == "non_standard"
|
||||
and "refusal" in block["value"]
|
||||
for block in ai_msg.content
|
||||
):
|
||||
refusal = next(
|
||||
block["value"]["refusal"]
|
||||
for block in ai_msg.content
|
||||
if isinstance(block, dict)
|
||||
and block["type"] == "non_standard"
|
||||
and "refusal" in block["value"]
|
||||
)
|
||||
raise OpenAIRefusalError(refusal)
|
||||
elif ai_msg.tool_calls:
|
||||
return None
|
||||
else:
|
||||
@@ -3500,7 +3568,7 @@ def _get_last_messages(
|
||||
msg = messages[i]
|
||||
if isinstance(msg, AIMessage):
|
||||
response_id = msg.response_metadata.get("id")
|
||||
if response_id:
|
||||
if response_id and response_id.startswith("resp_"):
|
||||
return messages[i + 1 :], response_id
|
||||
else:
|
||||
return messages, None
|
||||
@@ -3609,23 +3677,45 @@ def _construct_responses_api_payload(
|
||||
return payload
|
||||
|
||||
|
||||
def _make_computer_call_output_from_message(message: ToolMessage) -> dict:
|
||||
computer_call_output: dict = {
|
||||
"call_id": message.tool_call_id,
|
||||
"type": "computer_call_output",
|
||||
}
|
||||
def _make_computer_call_output_from_message(
|
||||
message: ToolMessage,
|
||||
) -> Optional[dict[str, Any]]:
|
||||
computer_call_output: Optional[dict[str, Any]] = None
|
||||
if isinstance(message.content, list):
|
||||
# Use first input_image block
|
||||
output = next(
|
||||
block
|
||||
for block in message.content
|
||||
if cast(dict, block)["type"] == "input_image"
|
||||
)
|
||||
for block in message.content:
|
||||
if (
|
||||
message.additional_kwargs.get("type") == "computer_call_output"
|
||||
and isinstance(block, dict)
|
||||
and block.get("type") == "input_image"
|
||||
):
|
||||
# Use first input_image block
|
||||
computer_call_output = {
|
||||
"call_id": message.tool_call_id,
|
||||
"type": "computer_call_output",
|
||||
"output": block,
|
||||
}
|
||||
break
|
||||
elif (
|
||||
isinstance(block, dict)
|
||||
and block.get("type") == "non_standard"
|
||||
and block.get("value", {}).get("type") == "computer_call_output"
|
||||
):
|
||||
computer_call_output = block["value"]
|
||||
break
|
||||
else:
|
||||
pass
|
||||
else:
|
||||
# string, assume image_url
|
||||
output = {"type": "input_image", "image_url": message.content}
|
||||
computer_call_output["output"] = output
|
||||
if "acknowledged_safety_checks" in message.additional_kwargs:
|
||||
if message.additional_kwargs.get("type") == "computer_call_output":
|
||||
# string, assume image_url
|
||||
computer_call_output = {
|
||||
"call_id": message.tool_call_id,
|
||||
"type": "computer_call_output",
|
||||
"output": {"type": "input_image", "image_url": message.content},
|
||||
}
|
||||
if (
|
||||
computer_call_output is not None
|
||||
and "acknowledged_safety_checks" in message.additional_kwargs
|
||||
):
|
||||
computer_call_output["acknowledged_safety_checks"] = message.additional_kwargs[
|
||||
"acknowledged_safety_checks"
|
||||
]
|
||||
@@ -3642,6 +3732,15 @@ def _make_custom_tool_output_from_message(message: ToolMessage) -> Optional[dict
|
||||
"output": block.get("output") or "",
|
||||
}
|
||||
break
|
||||
elif (
|
||||
isinstance(block, dict)
|
||||
and block.get("type") == "non_standard"
|
||||
and block.get("value", {}).get("type") == "custom_tool_call_output"
|
||||
):
|
||||
custom_tool_output = block["value"]
|
||||
break
|
||||
else:
|
||||
pass
|
||||
|
||||
return custom_tool_output
|
||||
|
||||
@@ -3666,20 +3765,33 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
|
||||
for lc_msg in messages:
|
||||
if isinstance(lc_msg, AIMessage):
|
||||
lc_msg = _convert_from_v03_ai_message(lc_msg)
|
||||
msg = _convert_message_to_dict(lc_msg)
|
||||
msg = _convert_message_to_dict(lc_msg, responses_ai_msg=True)
|
||||
if isinstance(msg.get("content"), list) and all(
|
||||
isinstance(block, dict) for block in msg["content"]
|
||||
):
|
||||
msg["content"] = _convert_from_v1_to_responses(
|
||||
msg["content"], lc_msg.tool_calls
|
||||
)
|
||||
else:
|
||||
msg = _convert_message_to_dict(lc_msg)
|
||||
# Get content from non-standard content blocks
|
||||
if isinstance(msg["content"], list):
|
||||
for i, block in enumerate(msg["content"]):
|
||||
if isinstance(block, dict) and block.get("type") == "non_standard":
|
||||
msg["content"][i] = block["value"]
|
||||
# "name" parameter unsupported
|
||||
if "name" in msg:
|
||||
msg.pop("name")
|
||||
if msg["role"] == "tool":
|
||||
tool_output = msg["content"]
|
||||
computer_call_output = _make_computer_call_output_from_message(
|
||||
cast(ToolMessage, lc_msg)
|
||||
)
|
||||
custom_tool_output = _make_custom_tool_output_from_message(lc_msg) # type: ignore[arg-type]
|
||||
if custom_tool_output:
|
||||
input_.append(custom_tool_output)
|
||||
elif lc_msg.additional_kwargs.get("type") == "computer_call_output":
|
||||
computer_call_output = _make_computer_call_output_from_message(
|
||||
cast(ToolMessage, lc_msg)
|
||||
)
|
||||
if computer_call_output:
|
||||
input_.append(computer_call_output)
|
||||
elif custom_tool_output:
|
||||
input_.append(custom_tool_output)
|
||||
else:
|
||||
if not isinstance(tool_output, str):
|
||||
tool_output = _stringify(tool_output)
|
||||
@@ -3828,7 +3940,7 @@ def _construct_lc_result_from_responses_api(
|
||||
response: Response,
|
||||
schema: Optional[type[_BM]] = None,
|
||||
metadata: Optional[dict] = None,
|
||||
output_version: Literal["v0", "responses/v1"] = "v0",
|
||||
output_version: str = "v0",
|
||||
) -> ChatResult:
|
||||
"""Construct ChatResponse from OpenAI Response API response."""
|
||||
if response.error:
|
||||
@@ -3855,6 +3967,7 @@ def _construct_lc_result_from_responses_api(
|
||||
if metadata:
|
||||
response_metadata.update(metadata)
|
||||
# for compatibility with chat completion calls.
|
||||
response_metadata["model_provider"] = "openai"
|
||||
response_metadata["model_name"] = response_metadata.get("model")
|
||||
if response.usage:
|
||||
usage_metadata = _create_usage_metadata_responses(response.usage.model_dump())
|
||||
@@ -3966,6 +4079,7 @@ def _construct_lc_result_from_responses_api(
|
||||
additional_kwargs["parsed"] = parsed
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
message = AIMessage(
|
||||
content=content_blocks,
|
||||
id=response.id,
|
||||
@@ -3977,6 +4091,11 @@ def _construct_lc_result_from_responses_api(
|
||||
)
|
||||
if output_version == "v0":
|
||||
message = _convert_to_v03_ai_message(message)
|
||||
elif output_version == "v1":
|
||||
message.content = cast(
|
||||
Union[str, list[Union[str, dict]]], translate_content(message)
|
||||
)
|
||||
message.response_metadata["output_version"] = "v1"
|
||||
else:
|
||||
pass
|
||||
return ChatResult(generations=[ChatGeneration(message=message)])
|
||||
@@ -3990,7 +4109,7 @@ def _convert_responses_chunk_to_generation_chunk(
|
||||
schema: Optional[type[_BM]] = None,
|
||||
metadata: Optional[dict] = None,
|
||||
has_reasoning: bool = False,
|
||||
output_version: Literal["v0", "responses/v1"] = "v0",
|
||||
output_version: str = "v0",
|
||||
) -> tuple[int, int, int, Optional[ChatGenerationChunk]]:
|
||||
def _advance(output_idx: int, sub_idx: Optional[int] = None) -> None:
|
||||
"""Advance indexes tracked during streaming.
|
||||
@@ -4056,9 +4175,12 @@ def _convert_responses_chunk_to_generation_chunk(
|
||||
annotation = chunk.annotation
|
||||
else:
|
||||
annotation = chunk.annotation.model_dump(exclude_none=True, mode="json")
|
||||
content.append({"annotations": [annotation], "index": current_index})
|
||||
|
||||
content.append(
|
||||
{"type": "text", "annotations": [annotation], "index": current_index}
|
||||
)
|
||||
elif chunk.type == "response.output_text.done":
|
||||
content.append({"id": chunk.item_id, "index": current_index})
|
||||
content.append({"type": "text", "id": chunk.item_id, "index": current_index})
|
||||
elif chunk.type == "response.created":
|
||||
id = chunk.response.id
|
||||
response_metadata["id"] = chunk.response.id # Backwards compatibility
|
||||
@@ -4151,6 +4273,7 @@ def _convert_responses_chunk_to_generation_chunk(
|
||||
content.append({"type": "refusal", "refusal": chunk.refusal})
|
||||
elif chunk.type == "response.output_item.added" and chunk.item.type == "reasoning":
|
||||
_advance(chunk.output_index)
|
||||
current_sub_index = 0
|
||||
reasoning = chunk.item.model_dump(exclude_none=True, mode="json")
|
||||
reasoning["index"] = current_index
|
||||
content.append(reasoning)
|
||||
@@ -4164,6 +4287,7 @@ def _convert_responses_chunk_to_generation_chunk(
|
||||
],
|
||||
"index": current_index,
|
||||
"type": "reasoning",
|
||||
"id": chunk.item_id,
|
||||
}
|
||||
)
|
||||
elif chunk.type == "response.image_generation_call.partial_image":
|
||||
@@ -4200,6 +4324,11 @@ def _convert_responses_chunk_to_generation_chunk(
|
||||
AIMessageChunk,
|
||||
_convert_to_v03_ai_message(message, has_reasoning=has_reasoning),
|
||||
)
|
||||
elif output_version == "v1":
|
||||
message.content = cast(
|
||||
Union[str, list[Union[str, dict]]], translate_content_chunk(message)
|
||||
)
|
||||
message.response_metadata["output_version"] = "v1"
|
||||
else:
|
||||
pass
|
||||
return (
|
||||
|
||||
@@ -28,8 +28,9 @@ def _check_response(response: Optional[BaseMessage]) -> None:
|
||||
for block in response.content:
|
||||
assert isinstance(block, dict)
|
||||
if block["type"] == "text":
|
||||
assert isinstance(block["text"], str)
|
||||
for annotation in block["annotations"]:
|
||||
assert isinstance(block.get("text"), str)
|
||||
annotations = block.get("annotations", [])
|
||||
for annotation in annotations:
|
||||
if annotation["type"] == "file_citation":
|
||||
assert all(
|
||||
key in annotation
|
||||
@@ -40,8 +41,12 @@ def _check_response(response: Optional[BaseMessage]) -> None:
|
||||
key in annotation
|
||||
for key in ["end_index", "start_index", "title", "type", "url"]
|
||||
)
|
||||
|
||||
text_content = response.text()
|
||||
elif annotation["type"] == "citation":
|
||||
assert all(key in annotation for key in ["title", "type"])
|
||||
if "url" in annotation:
|
||||
assert "start_index" in annotation
|
||||
assert "end_index" in annotation
|
||||
text_content = response.text() # type: ignore[operator,misc]
|
||||
assert isinstance(text_content, str)
|
||||
assert text_content
|
||||
assert response.usage_metadata
|
||||
@@ -49,12 +54,14 @@ def _check_response(response: Optional[BaseMessage]) -> None:
|
||||
assert response.usage_metadata["output_tokens"] > 0
|
||||
assert response.usage_metadata["total_tokens"] > 0
|
||||
assert response.response_metadata["model_name"]
|
||||
assert response.response_metadata["service_tier"]
|
||||
assert response.response_metadata["service_tier"] # type: ignore[typeddict-item]
|
||||
|
||||
|
||||
@pytest.mark.default_cassette("test_web_search.yaml.gz")
|
||||
@pytest.mark.vcr
|
||||
def test_web_search() -> None:
|
||||
llm = ChatOpenAI(model=MODEL_NAME, output_version="responses/v1")
|
||||
@pytest.mark.parametrize("output_version", ["responses/v1", "v1"])
|
||||
def test_web_search(output_version: Literal["responses/v1", "v1"]) -> None:
|
||||
llm = ChatOpenAI(model=MODEL_NAME, output_version=output_version)
|
||||
first_response = llm.invoke(
|
||||
"What was a positive news story from today?",
|
||||
tools=[{"type": "web_search_preview"}],
|
||||
@@ -82,20 +89,9 @@ def test_web_search() -> None:
|
||||
# Manually pass in chat history
|
||||
response = llm.invoke(
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What was a positive news story from today?",
|
||||
}
|
||||
],
|
||||
},
|
||||
{"role": "user", "content": "What was a positive news story from today?"},
|
||||
first_response,
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": "what about a negative one"}],
|
||||
},
|
||||
{"role": "user", "content": "what about a negative one"},
|
||||
],
|
||||
tools=[{"type": "web_search_preview"}],
|
||||
)
|
||||
@@ -108,9 +104,12 @@ def test_web_search() -> None:
|
||||
_check_response(response)
|
||||
|
||||
for msg in [first_response, full, response]:
|
||||
assert isinstance(msg, AIMessage)
|
||||
assert msg is not None
|
||||
block_types = [block["type"] for block in msg.content] # type: ignore[index]
|
||||
assert block_types == ["web_search_call", "text"]
|
||||
if output_version == "responses/v1":
|
||||
assert block_types == ["web_search_call", "text"]
|
||||
else:
|
||||
assert block_types == ["web_search_call", "web_search_result", "text"]
|
||||
|
||||
|
||||
@pytest.mark.flaky(retries=3, delay=1)
|
||||
@@ -141,13 +140,15 @@ async def test_web_search_async() -> None:
|
||||
assert tool_output["type"] == "web_search_call"
|
||||
|
||||
|
||||
@pytest.mark.flaky(retries=3, delay=1)
|
||||
def test_function_calling() -> None:
|
||||
@pytest.mark.default_cassette("test_function_calling.yaml.gz")
|
||||
@pytest.mark.vcr
|
||||
@pytest.mark.parametrize("output_version", ["v0", "responses/v1", "v1"])
|
||||
def test_function_calling(output_version: Literal["v0", "responses/v1", "v1"]) -> None:
|
||||
def multiply(x: int, y: int) -> int:
|
||||
"""return x * y"""
|
||||
return x * y
|
||||
|
||||
llm = ChatOpenAI(model=MODEL_NAME)
|
||||
llm = ChatOpenAI(model=MODEL_NAME, output_version=output_version)
|
||||
bound_llm = llm.bind_tools([multiply, {"type": "web_search_preview"}])
|
||||
ai_msg = cast(AIMessage, bound_llm.invoke("whats 5 * 4"))
|
||||
assert len(ai_msg.tool_calls) == 1
|
||||
@@ -174,8 +175,15 @@ class FooDict(TypedDict):
|
||||
response: str
|
||||
|
||||
|
||||
def test_parsed_pydantic_schema() -> None:
|
||||
llm = ChatOpenAI(model=MODEL_NAME, use_responses_api=True)
|
||||
@pytest.mark.default_cassette("test_parsed_pydantic_schema.yaml.gz")
|
||||
@pytest.mark.vcr
|
||||
@pytest.mark.parametrize("output_version", ["v0", "responses/v1", "v1"])
|
||||
def test_parsed_pydantic_schema(
|
||||
output_version: Literal["v0", "responses/v1", "v1"],
|
||||
) -> None:
|
||||
llm = ChatOpenAI(
|
||||
model=MODEL_NAME, use_responses_api=True, output_version=output_version
|
||||
)
|
||||
response = llm.invoke("how are ya", response_format=Foo)
|
||||
parsed = Foo(**json.loads(response.text()))
|
||||
assert parsed == response.additional_kwargs["parsed"]
|
||||
@@ -297,8 +305,8 @@ def test_function_calling_and_structured_output() -> None:
|
||||
|
||||
@pytest.mark.default_cassette("test_reasoning.yaml.gz")
|
||||
@pytest.mark.vcr
|
||||
@pytest.mark.parametrize("output_version", ["v0", "responses/v1"])
|
||||
def test_reasoning(output_version: Literal["v0", "responses/v1"]) -> None:
|
||||
@pytest.mark.parametrize("output_version", ["v0", "responses/v1", "v1"])
|
||||
def test_reasoning(output_version: Literal["v0", "responses/v1", "v1"]) -> None:
|
||||
llm = ChatOpenAI(
|
||||
model="o4-mini", use_responses_api=True, output_version=output_version
|
||||
)
|
||||
@@ -358,27 +366,32 @@ def test_computer_calls() -> None:
|
||||
|
||||
def test_file_search() -> None:
|
||||
pytest.skip() # TODO: set up infra
|
||||
llm = ChatOpenAI(model=MODEL_NAME)
|
||||
llm = ChatOpenAI(model=MODEL_NAME, use_responses_api=True)
|
||||
tool = {
|
||||
"type": "file_search",
|
||||
"vector_store_ids": [os.environ["OPENAI_VECTOR_STORE_ID"]],
|
||||
}
|
||||
response = llm.invoke("What is deep research by OpenAI?", tools=[tool])
|
||||
|
||||
input_message = {"role": "user", "content": "What is deep research by OpenAI?"}
|
||||
response = llm.invoke([input_message], tools=[tool])
|
||||
_check_response(response)
|
||||
|
||||
full: Optional[BaseMessageChunk] = None
|
||||
for chunk in llm.stream("What is deep research by OpenAI?", tools=[tool]):
|
||||
for chunk in llm.stream([input_message], tools=[tool]):
|
||||
assert isinstance(chunk, AIMessageChunk)
|
||||
full = chunk if full is None else full + chunk
|
||||
assert isinstance(full, AIMessageChunk)
|
||||
_check_response(full)
|
||||
|
||||
next_message = {"role": "user", "content": "Thank you."}
|
||||
_ = llm.invoke([input_message, full, next_message])
|
||||
|
||||
|
||||
@pytest.mark.default_cassette("test_stream_reasoning_summary.yaml.gz")
|
||||
@pytest.mark.vcr
|
||||
@pytest.mark.parametrize("output_version", ["v0", "responses/v1"])
|
||||
@pytest.mark.parametrize("output_version", ["v0", "responses/v1", "v1"])
|
||||
def test_stream_reasoning_summary(
|
||||
output_version: Literal["v0", "responses/v1"],
|
||||
output_version: Literal["v0", "responses/v1", "v1"],
|
||||
) -> None:
|
||||
llm = ChatOpenAI(
|
||||
model="o4-mini",
|
||||
@@ -398,7 +411,14 @@ def test_stream_reasoning_summary(
|
||||
if output_version == "v0":
|
||||
reasoning = response_1.additional_kwargs["reasoning"]
|
||||
assert set(reasoning.keys()) == {"id", "type", "summary"}
|
||||
else:
|
||||
summary = reasoning["summary"]
|
||||
assert isinstance(summary, list)
|
||||
for block in summary:
|
||||
assert isinstance(block, dict)
|
||||
assert isinstance(block["type"], str)
|
||||
assert isinstance(block["text"], str)
|
||||
assert block["text"]
|
||||
elif output_version == "responses/v1":
|
||||
reasoning = next(
|
||||
block
|
||||
for block in response_1.content
|
||||
@@ -407,13 +427,27 @@ def test_stream_reasoning_summary(
|
||||
if isinstance(reasoning, str):
|
||||
reasoning = json.loads(reasoning)
|
||||
assert set(reasoning.keys()) == {"id", "type", "summary", "index"}
|
||||
summary = reasoning["summary"]
|
||||
assert isinstance(summary, list)
|
||||
for block in summary:
|
||||
assert isinstance(block, dict)
|
||||
assert isinstance(block["type"], str)
|
||||
assert isinstance(block["text"], str)
|
||||
assert block["text"]
|
||||
summary = reasoning["summary"]
|
||||
assert isinstance(summary, list)
|
||||
for block in summary:
|
||||
assert isinstance(block, dict)
|
||||
assert isinstance(block["type"], str)
|
||||
assert isinstance(block["text"], str)
|
||||
assert block["text"]
|
||||
else:
|
||||
# v1
|
||||
total_reasoning_blocks = 0
|
||||
for block in response_1.content_blocks:
|
||||
if block["type"] == "reasoning":
|
||||
total_reasoning_blocks += 1
|
||||
assert isinstance(block.get("id"), str) and block.get(
|
||||
"id", ""
|
||||
).startswith("rs_")
|
||||
assert isinstance(block.get("reasoning"), str)
|
||||
assert isinstance(block.get("index"), str)
|
||||
assert (
|
||||
total_reasoning_blocks > 1
|
||||
) # This query typically generates multiple reasoning blocks
|
||||
|
||||
# Check we can pass back summaries
|
||||
message_2 = {"role": "user", "content": "Thank you."}
|
||||
@@ -421,9 +455,13 @@ def test_stream_reasoning_summary(
|
||||
assert isinstance(response_2, AIMessage)
|
||||
|
||||
|
||||
@pytest.mark.default_cassette("test_code_interpreter.yaml.gz")
|
||||
@pytest.mark.vcr
|
||||
def test_code_interpreter() -> None:
|
||||
llm = ChatOpenAI(model="o4-mini", use_responses_api=True)
|
||||
@pytest.mark.parametrize("output_version", ["v0", "responses/v1", "v1"])
|
||||
def test_code_interpreter(output_version: Literal["v0", "responses/v1", "v1"]) -> None:
|
||||
llm = ChatOpenAI(
|
||||
model="o4-mini", use_responses_api=True, output_version=output_version
|
||||
)
|
||||
llm_with_tools = llm.bind_tools(
|
||||
[{"type": "code_interpreter", "container": {"type": "auto"}}]
|
||||
)
|
||||
@@ -432,16 +470,43 @@ def test_code_interpreter() -> None:
|
||||
"content": "Write and run code to answer the question: what is 3^3?",
|
||||
}
|
||||
response = llm_with_tools.invoke([input_message])
|
||||
assert isinstance(response, AIMessage)
|
||||
_check_response(response)
|
||||
tool_outputs = response.additional_kwargs["tool_outputs"]
|
||||
assert tool_outputs
|
||||
assert any(output["type"] == "code_interpreter_call" for output in tool_outputs)
|
||||
if output_version == "v0":
|
||||
tool_outputs = [
|
||||
item
|
||||
for item in response.additional_kwargs["tool_outputs"]
|
||||
if item["type"] == "code_interpreter_call"
|
||||
]
|
||||
assert len(tool_outputs) == 1
|
||||
elif output_version == "responses/v1":
|
||||
tool_outputs = [
|
||||
item
|
||||
for item in response.content
|
||||
if isinstance(item, dict) and item["type"] == "code_interpreter_call"
|
||||
]
|
||||
assert len(tool_outputs) == 1
|
||||
else:
|
||||
# v1
|
||||
tool_outputs = [
|
||||
item
|
||||
for item in response.content_blocks
|
||||
if item["type"] == "code_interpreter_call"
|
||||
]
|
||||
code_interpreter_result = next(
|
||||
item
|
||||
for item in response.content_blocks
|
||||
if item["type"] == "code_interpreter_result"
|
||||
)
|
||||
assert tool_outputs
|
||||
assert code_interpreter_result
|
||||
assert len(tool_outputs) == 1
|
||||
|
||||
# Test streaming
|
||||
# Use same container
|
||||
tool_outputs = response.additional_kwargs["tool_outputs"]
|
||||
assert len(tool_outputs) == 1
|
||||
container_id = tool_outputs[0]["container_id"]
|
||||
container_id = tool_outputs[0].get("container_id") or tool_outputs[0].get(
|
||||
"extras", {}
|
||||
).get("container_id")
|
||||
llm_with_tools = llm.bind_tools(
|
||||
[{"type": "code_interpreter", "container": container_id}]
|
||||
)
|
||||
@@ -451,9 +516,34 @@ def test_code_interpreter() -> None:
|
||||
assert isinstance(chunk, AIMessageChunk)
|
||||
full = chunk if full is None else full + chunk
|
||||
assert isinstance(full, AIMessageChunk)
|
||||
tool_outputs = full.additional_kwargs["tool_outputs"]
|
||||
assert tool_outputs
|
||||
assert any(output["type"] == "code_interpreter_call" for output in tool_outputs)
|
||||
if output_version == "v0":
|
||||
tool_outputs = [
|
||||
item
|
||||
for item in response.additional_kwargs["tool_outputs"]
|
||||
if item["type"] == "code_interpreter_call"
|
||||
]
|
||||
assert tool_outputs
|
||||
elif output_version == "responses/v1":
|
||||
tool_outputs = [
|
||||
item
|
||||
for item in response.content
|
||||
if isinstance(item, dict) and item["type"] == "code_interpreter_call"
|
||||
]
|
||||
assert tool_outputs
|
||||
else:
|
||||
# v1
|
||||
code_interpreter_call = next(
|
||||
item
|
||||
for item in full.content_blocks
|
||||
if item["type"] == "code_interpreter_call"
|
||||
)
|
||||
code_interpreter_result = next(
|
||||
item
|
||||
for item in full.content_blocks
|
||||
if item["type"] == "code_interpreter_result"
|
||||
)
|
||||
assert code_interpreter_call
|
||||
assert code_interpreter_result
|
||||
|
||||
# Test we can pass back in
|
||||
next_message = {"role": "user", "content": "Please add more comments to the code."}
|
||||
@@ -548,10 +638,69 @@ def test_mcp_builtin_zdr() -> None:
|
||||
_ = llm_with_tools.invoke([input_message, full, approval_message])
|
||||
|
||||
|
||||
@pytest.mark.vcr()
|
||||
def test_image_generation_streaming() -> None:
|
||||
@pytest.mark.default_cassette("test_mcp_builtin_zdr.yaml.gz")
|
||||
@pytest.mark.vcr
|
||||
def test_mcp_builtin_zdr_v1() -> None:
|
||||
llm = ChatOpenAI(
|
||||
model="o4-mini",
|
||||
output_version="v1",
|
||||
store=False,
|
||||
include=["reasoning.encrypted_content"],
|
||||
)
|
||||
|
||||
llm_with_tools = llm.bind_tools(
|
||||
[
|
||||
{
|
||||
"type": "mcp",
|
||||
"server_label": "deepwiki",
|
||||
"server_url": "https://mcp.deepwiki.com/mcp",
|
||||
"require_approval": {"always": {"tool_names": ["read_wiki_structure"]}},
|
||||
}
|
||||
]
|
||||
)
|
||||
input_message = {
|
||||
"role": "user",
|
||||
"content": (
|
||||
"What transport protocols does the 2025-03-26 version of the MCP spec "
|
||||
"support?"
|
||||
),
|
||||
}
|
||||
full: Optional[BaseMessageChunk] = None
|
||||
for chunk in llm_with_tools.stream([input_message]):
|
||||
assert isinstance(chunk, AIMessageChunk)
|
||||
full = chunk if full is None else full + chunk
|
||||
|
||||
assert isinstance(full, AIMessageChunk)
|
||||
assert all(isinstance(block, dict) for block in full.content)
|
||||
|
||||
approval_message = HumanMessage(
|
||||
[
|
||||
{
|
||||
"type": "non_standard",
|
||||
"value": {
|
||||
"type": "mcp_approval_response",
|
||||
"approve": True,
|
||||
"approval_request_id": block["value"]["id"], # type: ignore[index]
|
||||
},
|
||||
}
|
||||
for block in full.content_blocks
|
||||
if block["type"] == "non_standard"
|
||||
and block["value"]["type"] == "mcp_approval_request" # type: ignore[index]
|
||||
]
|
||||
)
|
||||
_ = llm_with_tools.invoke([input_message, full, approval_message])
|
||||
|
||||
|
||||
@pytest.mark.default_cassette("test_image_generation_streaming.yaml.gz")
|
||||
@pytest.mark.vcr
|
||||
@pytest.mark.parametrize("output_version", ["v0", "responses/v1"])
|
||||
def test_image_generation_streaming(
|
||||
output_version: Literal["v0", "responses/v1"],
|
||||
) -> None:
|
||||
"""Test image generation streaming."""
|
||||
llm = ChatOpenAI(model="gpt-4.1", use_responses_api=True)
|
||||
llm = ChatOpenAI(
|
||||
model="gpt-4.1", use_responses_api=True, output_version=output_version
|
||||
)
|
||||
tool = {
|
||||
"type": "image_generation",
|
||||
# For testing purposes let's keep the quality low, so the test runs faster.
|
||||
@@ -598,15 +747,69 @@ def test_image_generation_streaming() -> None:
|
||||
# At the moment, the streaming API does not pick up annotations fully.
|
||||
# So the following check is commented out.
|
||||
# _check_response(complete_ai_message)
|
||||
tool_output = complete_ai_message.additional_kwargs["tool_outputs"][0]
|
||||
assert set(tool_output.keys()).issubset(expected_keys)
|
||||
if output_version == "v0":
|
||||
assert complete_ai_message.additional_kwargs["tool_outputs"]
|
||||
tool_output = complete_ai_message.additional_kwargs["tool_outputs"][0]
|
||||
assert set(tool_output.keys()).issubset(expected_keys)
|
||||
else:
|
||||
# "responses/v1"
|
||||
tool_output = next(
|
||||
block
|
||||
for block in complete_ai_message.content
|
||||
if isinstance(block, dict) and block["type"] == "image_generation_call"
|
||||
)
|
||||
assert set(tool_output.keys()).issubset(expected_keys)
|
||||
|
||||
|
||||
@pytest.mark.vcr()
|
||||
def test_image_generation_multi_turn() -> None:
|
||||
@pytest.mark.default_cassette("test_image_generation_streaming.yaml.gz")
|
||||
@pytest.mark.vcr
|
||||
def test_image_generation_streaming_v1() -> None:
|
||||
"""Test image generation streaming."""
|
||||
llm = ChatOpenAI(model="gpt-4.1", use_responses_api=True, output_version="v1")
|
||||
tool = {
|
||||
"type": "image_generation",
|
||||
"quality": "low",
|
||||
"output_format": "jpeg",
|
||||
"output_compression": 100,
|
||||
"size": "1024x1024",
|
||||
}
|
||||
|
||||
standard_keys = {"type", "base64", "mime_type", "id", "index"}
|
||||
extra_keys = {
|
||||
"background",
|
||||
"output_format",
|
||||
"quality",
|
||||
"revised_prompt",
|
||||
"size",
|
||||
"status",
|
||||
}
|
||||
|
||||
full: Optional[BaseMessageChunk] = None
|
||||
for chunk in llm.stream("Draw a random short word in green font.", tools=[tool]):
|
||||
assert isinstance(chunk, AIMessageChunk)
|
||||
full = chunk if full is None else full + chunk
|
||||
complete_ai_message = cast(AIMessageChunk, full)
|
||||
|
||||
tool_output = next(
|
||||
block
|
||||
for block in complete_ai_message.content
|
||||
if isinstance(block, dict) and block["type"] == "image"
|
||||
)
|
||||
assert set(standard_keys).issubset(tool_output.keys())
|
||||
assert set(extra_keys).issubset(tool_output["extras"].keys())
|
||||
|
||||
|
||||
@pytest.mark.default_cassette("test_image_generation_multi_turn.yaml.gz")
|
||||
@pytest.mark.vcr
|
||||
@pytest.mark.parametrize("output_version", ["v0", "responses/v1"])
|
||||
def test_image_generation_multi_turn(
|
||||
output_version: Literal["v0", "responses/v1"],
|
||||
) -> None:
|
||||
"""Test multi-turn editing of image generation by passing in history."""
|
||||
# Test multi-turn
|
||||
llm = ChatOpenAI(model="gpt-4.1", use_responses_api=True)
|
||||
llm = ChatOpenAI(
|
||||
model="gpt-4.1", use_responses_api=True, output_version=output_version
|
||||
)
|
||||
# Test invocation
|
||||
tool = {
|
||||
"type": "image_generation",
|
||||
@@ -622,10 +825,41 @@ def test_image_generation_multi_turn() -> None:
|
||||
{"role": "user", "content": "Draw a random short word in green font."}
|
||||
]
|
||||
ai_message = llm_with_tools.invoke(chat_history)
|
||||
assert isinstance(ai_message, AIMessage)
|
||||
_check_response(ai_message)
|
||||
tool_output = ai_message.additional_kwargs["tool_outputs"][0]
|
||||
|
||||
# Example tool output for an image
|
||||
expected_keys = {
|
||||
"id",
|
||||
"background",
|
||||
"output_format",
|
||||
"quality",
|
||||
"result",
|
||||
"revised_prompt",
|
||||
"size",
|
||||
"status",
|
||||
"type",
|
||||
}
|
||||
|
||||
if output_version == "v0":
|
||||
tool_output = ai_message.additional_kwargs["tool_outputs"][0]
|
||||
assert set(tool_output.keys()).issubset(expected_keys)
|
||||
elif output_version == "responses/v1":
|
||||
tool_output = next(
|
||||
block
|
||||
for block in ai_message.content
|
||||
if isinstance(block, dict) and block["type"] == "image_generation_call"
|
||||
)
|
||||
assert set(tool_output.keys()).issubset(expected_keys)
|
||||
else:
|
||||
standard_keys = {"type", "base64", "id", "status"}
|
||||
tool_output = next(
|
||||
block
|
||||
for block in ai_message.content
|
||||
if isinstance(block, dict) and block["type"] == "image"
|
||||
)
|
||||
assert set(standard_keys).issubset(tool_output.keys())
|
||||
|
||||
# Example tool output for an image (v0)
|
||||
# {
|
||||
# "background": "opaque",
|
||||
# "id": "ig_683716a8ddf0819888572b20621c7ae4029ec8c11f8dacf8",
|
||||
@@ -641,20 +875,6 @@ def test_image_generation_multi_turn() -> None:
|
||||
# "result": # base64 encode image data
|
||||
# }
|
||||
|
||||
expected_keys = {
|
||||
"id",
|
||||
"background",
|
||||
"output_format",
|
||||
"quality",
|
||||
"result",
|
||||
"revised_prompt",
|
||||
"size",
|
||||
"status",
|
||||
"type",
|
||||
}
|
||||
|
||||
assert set(tool_output.keys()).issubset(expected_keys)
|
||||
|
||||
chat_history.extend(
|
||||
[
|
||||
# AI message with tool output
|
||||
@@ -671,9 +891,89 @@ def test_image_generation_multi_turn() -> None:
|
||||
)
|
||||
|
||||
ai_message2 = llm_with_tools.invoke(chat_history)
|
||||
assert isinstance(ai_message2, AIMessage)
|
||||
_check_response(ai_message2)
|
||||
tool_output2 = ai_message2.additional_kwargs["tool_outputs"][0]
|
||||
assert set(tool_output2.keys()).issubset(expected_keys)
|
||||
|
||||
if output_version == "v0":
|
||||
tool_output = ai_message2.additional_kwargs["tool_outputs"][0]
|
||||
assert set(tool_output.keys()).issubset(expected_keys)
|
||||
else:
|
||||
# "responses/v1"
|
||||
tool_output = next(
|
||||
block
|
||||
for block in ai_message2.content
|
||||
if isinstance(block, dict) and block["type"] == "image_generation_call"
|
||||
)
|
||||
assert set(tool_output.keys()).issubset(expected_keys)
|
||||
|
||||
|
||||
@pytest.mark.default_cassette("test_image_generation_multi_turn.yaml.gz")
|
||||
@pytest.mark.vcr
|
||||
def test_image_generation_multi_turn_v1() -> None:
|
||||
"""Test multi-turn editing of image generation by passing in history."""
|
||||
# Test multi-turn
|
||||
llm = ChatOpenAI(model="gpt-4.1", use_responses_api=True, output_version="v1")
|
||||
# Test invocation
|
||||
tool = {
|
||||
"type": "image_generation",
|
||||
"quality": "low",
|
||||
"output_format": "jpeg",
|
||||
"output_compression": 100,
|
||||
"size": "1024x1024",
|
||||
}
|
||||
llm_with_tools = llm.bind_tools([tool])
|
||||
|
||||
chat_history: list[MessageLikeRepresentation] = [
|
||||
{"role": "user", "content": "Draw a random short word in green font."}
|
||||
]
|
||||
ai_message = llm_with_tools.invoke(chat_history)
|
||||
assert isinstance(ai_message, AIMessage)
|
||||
_check_response(ai_message)
|
||||
|
||||
standard_keys = {"type", "base64", "mime_type", "id"}
|
||||
extra_keys = {
|
||||
"background",
|
||||
"output_format",
|
||||
"quality",
|
||||
"revised_prompt",
|
||||
"size",
|
||||
"status",
|
||||
}
|
||||
|
||||
tool_output = next(
|
||||
block
|
||||
for block in ai_message.content
|
||||
if isinstance(block, dict) and block["type"] == "image"
|
||||
)
|
||||
assert set(standard_keys).issubset(tool_output.keys())
|
||||
assert set(extra_keys).issubset(tool_output["extras"].keys())
|
||||
|
||||
chat_history.extend(
|
||||
[
|
||||
# AI message with tool output
|
||||
ai_message,
|
||||
# New request
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
"Now, change the font to blue. Keep the word and everything else "
|
||||
"the same."
|
||||
),
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
ai_message2 = llm_with_tools.invoke(chat_history)
|
||||
assert isinstance(ai_message2, AIMessage)
|
||||
_check_response(ai_message2)
|
||||
|
||||
tool_output = next(
|
||||
block
|
||||
for block in ai_message2.content
|
||||
if isinstance(block, dict) and block["type"] == "image"
|
||||
)
|
||||
assert set(standard_keys).issubset(tool_output.keys())
|
||||
assert set(extra_keys).issubset(tool_output["extras"].keys())
|
||||
|
||||
|
||||
def test_verbosity_parameter() -> None:
|
||||
@@ -689,14 +989,16 @@ def test_verbosity_parameter() -> None:
|
||||
assert response.content
|
||||
|
||||
|
||||
@pytest.mark.vcr()
|
||||
def test_custom_tool() -> None:
|
||||
@pytest.mark.default_cassette("test_custom_tool.yaml.gz")
|
||||
@pytest.mark.vcr
|
||||
@pytest.mark.parametrize("output_version", ["responses/v1", "v1"])
|
||||
def test_custom_tool(output_version: Literal["responses/v1", "v1"]) -> None:
|
||||
@custom_tool
|
||||
def execute_code(code: str) -> str:
|
||||
"""Execute python code."""
|
||||
return "27"
|
||||
|
||||
llm = ChatOpenAI(model="gpt-5", output_version="responses/v1").bind_tools(
|
||||
llm = ChatOpenAI(model="gpt-5", output_version=output_version).bind_tools(
|
||||
[execute_code]
|
||||
)
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ from langchain_core.runnables import RunnableLambda
|
||||
from langchain_core.tracers.base import BaseTracer
|
||||
from langchain_core.tracers.schemas import Run
|
||||
from openai.types.responses import ResponseOutputMessage, ResponseReasoningItem
|
||||
from openai.types.responses.response import IncompleteDetails, Response, ResponseUsage
|
||||
from openai.types.responses.response import IncompleteDetails, Response
|
||||
from openai.types.responses.response_error import ResponseError
|
||||
from openai.types.responses.response_file_search_tool_call import (
|
||||
ResponseFileSearchToolCall,
|
||||
@@ -43,6 +43,7 @@ from openai.types.responses.response_reasoning_item import Summary
|
||||
from openai.types.responses.response_usage import (
|
||||
InputTokensDetails,
|
||||
OutputTokensDetails,
|
||||
ResponseUsage,
|
||||
)
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
from typing_extensions import TypedDict
|
||||
@@ -51,6 +52,8 @@ from langchain_openai import ChatOpenAI
|
||||
from langchain_openai.chat_models._compat import (
|
||||
_FUNCTION_CALL_IDS_MAP_KEY,
|
||||
_convert_from_v03_ai_message,
|
||||
_convert_from_v1_to_chat_completions,
|
||||
_convert_from_v1_to_responses,
|
||||
_convert_to_v03_ai_message,
|
||||
)
|
||||
from langchain_openai.chat_models.base import (
|
||||
@@ -1231,7 +1234,7 @@ def test_structured_outputs_parser() -> None:
|
||||
serialized = dumps(llm_output)
|
||||
deserialized = loads(serialized)
|
||||
assert isinstance(deserialized, ChatGeneration)
|
||||
result = output_parser.invoke(deserialized.message)
|
||||
result = output_parser.invoke(cast(AIMessage, deserialized.message))
|
||||
assert result == parsed_response
|
||||
|
||||
|
||||
@@ -2374,7 +2377,7 @@ def test_mcp_tracing() -> None:
|
||||
assert payload["tools"][0]["headers"]["Authorization"] == "Bearer PLACEHOLDER"
|
||||
|
||||
|
||||
def test_compat() -> None:
|
||||
def test_compat_responses_v03() -> None:
|
||||
# Check compatibility with v0.3 message format
|
||||
message_v03 = AIMessage(
|
||||
content=[
|
||||
@@ -2435,6 +2438,152 @@ def test_compat() -> None:
|
||||
assert message_v03_output is not message_v03
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"message_v1, expected",
|
||||
[
|
||||
(
|
||||
AIMessage(
|
||||
[
|
||||
{"type": "reasoning", "reasoning": "Reasoning text"},
|
||||
{
|
||||
"type": "tool_call",
|
||||
"id": "call_123",
|
||||
"name": "get_weather",
|
||||
"args": {"location": "San Francisco"},
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Hello, world!",
|
||||
"annotations": [
|
||||
{"type": "citation", "url": "https://example.com"}
|
||||
],
|
||||
},
|
||||
],
|
||||
id="chatcmpl-123",
|
||||
response_metadata={"model_provider": "openai", "model_name": "gpt-4.1"},
|
||||
),
|
||||
AIMessage(
|
||||
[{"type": "text", "text": "Hello, world!"}],
|
||||
id="chatcmpl-123",
|
||||
response_metadata={"model_provider": "openai", "model_name": "gpt-4.1"},
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
def test_convert_from_v1_to_chat_completions(
|
||||
message_v1: AIMessage, expected: AIMessage
|
||||
) -> None:
|
||||
result = _convert_from_v1_to_chat_completions(message_v1)
|
||||
assert result == expected
|
||||
assert result.tool_calls == message_v1.tool_calls # tool calls remain cached
|
||||
|
||||
# Check no mutation
|
||||
assert message_v1 != result
|
||||
|
||||
|
@pytest.mark.parametrize(
    "message_v1, expected",
    [
        (
            AIMessage(
                content_blocks=[
                    {"type": "reasoning", "id": "abc123"},
                    {"type": "reasoning", "id": "abc234", "reasoning": "foo "},
                    {"type": "reasoning", "id": "abc234", "reasoning": "bar"},
                    {
                        "type": "tool_call",
                        "id": "call_123",
                        "name": "get_weather",
                        "args": {"location": "San Francisco"},
                    },
                    {
                        "type": "tool_call",
                        "id": "call_234",
                        "name": "get_weather_2",
                        "args": {"location": "New York"},
                        "extras": {"item_id": "fc_123"},
                    },
                    {"type": "text", "text": "Hello "},
                    {
                        "type": "text",
                        "text": "world",
                        "annotations": [
                            {"type": "citation", "url": "https://example.com"},
                            {
                                "type": "citation",
                                "title": "my doc",
                                "extras": {"file_id": "file_123", "index": 1},
                            },
                            {
                                "type": "non_standard_annotation",
                                "value": {"bar": "baz"},
                            },
                        ],
                    },
                    {"type": "image", "base64": "...", "id": "ig_123"},
                    {
                        "type": "non_standard",
                        "value": {"type": "something_else", "foo": "bar"},
                    },
                ],
                id="resp123",
            ),
            [
                {"type": "reasoning", "id": "abc123", "summary": []},
                {
                    "type": "reasoning",
                    "id": "abc234",
                    "summary": [
                        {"type": "summary_text", "text": "foo "},
                        {"type": "summary_text", "text": "bar"},
                    ],
                },
                {
                    "type": "function_call",
                    "call_id": "call_123",
                    "name": "get_weather",
                    "arguments": '{"location": "San Francisco"}',
                },
                {
                    "type": "function_call",
                    "call_id": "call_234",
                    "name": "get_weather_2",
                    "arguments": '{"location": "New York"}',
                    "id": "fc_123",
                },
                {"type": "text", "text": "Hello "},
                {
                    "type": "text",
                    "text": "world",
                    "annotations": [
                        {"type": "url_citation", "url": "https://example.com"},
                        {
                            "type": "file_citation",
                            "filename": "my doc",
                            "index": 1,
                            "file_id": "file_123",
                        },
                        {"bar": "baz"},
                    ],
                },
                {"type": "image_generation_call", "id": "ig_123", "result": "..."},
                {"type": "something_else", "foo": "bar"},
            ],
        )
    ],
)
def test_convert_from_v1_to_responses(
    message_v1: AIMessage, expected: list[dict[str, Any]]
) -> None:
    result = _convert_from_v1_to_responses(
        message_v1.content_blocks, message_v1.tool_calls
    )
    assert result == expected

    # Check no mutation
    assert message_v1 != result


def test_get_last_messages() -> None:
    messages: list[BaseMessage] = [HumanMessage("Hello")]
    last_messages, previous_response_id = _get_last_messages(messages)

@@ -1,6 +1,7 @@
from typing import Any, Optional
from unittest.mock import MagicMock, patch

import pytest
from langchain_core.messages import AIMessageChunk, BaseMessageChunk
from openai.types.responses import (
    ResponseCompletedEvent,
@@ -20,7 +21,7 @@ from openai.types.responses import (
    ResponseTextDeltaEvent,
    ResponseTextDoneEvent,
)
from openai.types.responses.response import Response, ResponseUsage
from openai.types.responses.response import Response
from openai.types.responses.response_output_text import ResponseOutputText
from openai.types.responses.response_reasoning_item import Summary
from openai.types.responses.response_reasoning_summary_part_added_event import (
@@ -32,6 +33,7 @@ from openai.types.responses.response_reasoning_summary_part_done_event import (
from openai.types.responses.response_usage import (
    InputTokensDetails,
    OutputTokensDetails,
    ResponseUsage,
)
from openai.types.shared.reasoning import Reasoning
from openai.types.shared.response_format_text import ResponseFormatText
@@ -337,7 +339,7 @@ responses_stream = [
            id="rs_234",
            summary=[],
            type="reasoning",
            encrypted_content=None,
            encrypted_content="encrypted-content",
            status=None,
        ),
        output_index=2,
@@ -416,7 +418,7 @@ responses_stream = [
                Summary(text="still more reasoning", type="summary_text"),
            ],
            type="reasoning",
            encrypted_content=None,
            encrypted_content="encrypted-content",
            status=None,
        ),
        output_index=2,
@@ -562,7 +564,7 @@ responses_stream = [
                Summary(text="still more reasoning", type="summary_text"),
            ],
            type="reasoning",
            encrypted_content=None,
            encrypted_content="encrypted-content",
            status=None,
        ),
        ResponseOutputMessage(
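# The fixture items above now carry encrypted reasoning content, so the streaming test
# below can check that it is surfaced: as a top-level "encrypted_content" key under
# "responses/v1", and under the reasoning block's "extras" when output_version="v1".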
@@ -620,8 +622,104 @@ def _strip_none(obj: Any) -> Any:
    return obj


def test_responses_stream() -> None:
    llm = ChatOpenAI(model="o4-mini", output_version="responses/v1")
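# The two cases below contrast the streamed content shapes: "responses/v1" keeps the
# Responses API item layout (reasoning items with a "summary" list and integer block
# indices), while "v1" emits standard content blocks (one reasoning block per summary
# part, string "lc_*" indices, and provider-specific fields tucked under "extras").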
@pytest.mark.parametrize(
    "output_version, expected_content",
    [
        (
            "responses/v1",
            [
                {
                    "id": "rs_123",
                    "summary": [
                        {
                            "index": 0,
                            "type": "summary_text",
                            "text": "reasoning block one",
                        },
                        {
                            "index": 1,
                            "type": "summary_text",
                            "text": "another reasoning block",
                        },
                    ],
                    "type": "reasoning",
                    "index": 0,
                },
                {"type": "text", "text": "text block one", "index": 1, "id": "msg_123"},
                {
                    "type": "text",
                    "text": "another text block",
                    "index": 2,
                    "id": "msg_123",
                },
                {
                    "id": "rs_234",
                    "summary": [
                        {"index": 0, "type": "summary_text", "text": "more reasoning"},
                        {
                            "index": 1,
                            "type": "summary_text",
                            "text": "still more reasoning",
                        },
                    ],
                    "encrypted_content": "encrypted-content",
                    "type": "reasoning",
                    "index": 3,
                },
                {"type": "text", "text": "more", "index": 4, "id": "msg_234"},
                {"type": "text", "text": "text", "index": 5, "id": "msg_234"},
            ],
        ),
        (
            "v1",
            [
                {
                    "type": "reasoning",
                    "reasoning": "reasoning block one",
                    "id": "rs_123",
                    "index": "lc_rs_305f30",
                },
                {
                    "type": "reasoning",
                    "reasoning": "another reasoning block",
                    "id": "rs_123",
                    "index": "lc_rs_305f31",
                },
                {
                    "type": "text",
                    "text": "text block one",
                    "index": "lc_txt_1",
                    "id": "msg_123",
                },
                {
                    "type": "text",
                    "text": "another text block",
                    "index": "lc_txt_2",
                    "id": "msg_123",
                },
                {
                    "type": "reasoning",
                    "reasoning": "more reasoning",
                    "id": "rs_234",
                    "extras": {"encrypted_content": "encrypted-content"},
                    "index": "lc_rs_335f30",
                },
                {
                    "type": "reasoning",
                    "reasoning": "still more reasoning",
                    "id": "rs_234",
                    "index": "lc_rs_335f31",
                },
                {"type": "text", "text": "more", "index": "lc_txt_4", "id": "msg_234"},
                {"type": "text", "text": "text", "index": "lc_txt_5", "id": "msg_234"},
            ],
        ),
    ],
)
def test_responses_stream(output_version: str, expected_content: list[dict]) -> None:
    llm = ChatOpenAI(
        model="o4-mini", use_responses_api=True, output_version=output_version
    )
    mock_client = MagicMock()

    def mock_create(*args: Any, **kwargs: Any) -> MockSyncContextManager:
@@ -630,36 +728,14 @@ def test_responses_stream() -> None:
    mock_client.responses.create = mock_create

    full: Optional[BaseMessageChunk] = None
    chunks = []
    with patch.object(llm, "root_client", mock_client):
        for chunk in llm.stream("test"):
            assert isinstance(chunk, AIMessageChunk)
            full = chunk if full is None else full + chunk
            chunks.append(chunk)
    assert isinstance(full, AIMessageChunk)

    expected_content = [
        {
            "id": "rs_123",
            "summary": [
                {"index": 0, "type": "summary_text", "text": "reasoning block one"},
                {"index": 1, "type": "summary_text", "text": "another reasoning block"},
            ],
            "type": "reasoning",
            "index": 0,
        },
        {"type": "text", "text": "text block one", "index": 1, "id": "msg_123"},
        {"type": "text", "text": "another text block", "index": 2, "id": "msg_123"},
        {
            "id": "rs_234",
            "summary": [
                {"index": 0, "type": "summary_text", "text": "more reasoning"},
                {"index": 1, "type": "summary_text", "text": "still more reasoning"},
            ],
            "type": "reasoning",
            "index": 3,
        },
        {"type": "text", "text": "more", "index": 4, "id": "msg_234"},
        {"type": "text", "text": "text", "index": 5, "id": "msg_234"},
    ]
    assert full.content == expected_content
    assert full.additional_kwargs == {}
    assert full.id == "resp_123"