Bagatur 2024-08-29 17:18:41 -07:00
parent fb002faba4
commit 992884b410
3 changed files with 825 additions and 127 deletions

View File

@@ -20,6 +20,7 @@ from typing import (
Callable,
Dict,
Iterable,
Iterator,
List,
Literal,
Optional,
@@ -28,7 +29,7 @@ from typing import (
Type,
Union,
cast,
overload, Iterator,
overload,
)
from langchain_core.messages.ai import AIMessage, AIMessageChunk
@@ -41,7 +42,9 @@ from langchain_core.messages.system import SystemMessage, SystemMessageChunk
from langchain_core.messages.tool import ToolMessage, ToolMessageChunk
from langchain_core.messages.tool import (
tool_call as create_tool_call,
tool_call_chunk as create_tool_call_chunk,
)
from langchain_core.messages.tool import (
tool_call_chunk as create_tool_call_chunk,
)
if TYPE_CHECKING:
@@ -335,50 +338,6 @@ def convert_to_messages(
return [_convert_to_message(m, copy=copy) for m in messages]
def _runnable_generator(func: Callable) -> Callable:
@overload
def wrapped(
messages: Literal[None] = None, **kwargs: Any
) -> Runnable[
Union[MessageLikeRepresentation, Sequence[MessageLikeRepresentation]],
Union[BaseMessage, List[BaseMessage]],
]: ...
@overload
def wrapped(
messages: Sequence[Union[BaseMessage, Dict, Tuple]], **kwargs: Any
) -> List[BaseMessage]: ...
@overload
def wrapped(messages: MessageLikeRepresentation, **kwargs: Any) -> BaseMessage: ...
def wrapped(
messages: Union[
MessageLikeRepresentation, Sequence[MessageLikeRepresentation], None
] = None,
**kwargs: Any,
) -> Union[
BaseMessage,
List[BaseMessage],
Runnable[
Union[MessageLikeRepresentation, Sequence[MessageLikeRepresentation]],
Union[BaseMessage, List[BaseMessage]],
],
]:
from langchain_core.runnables.base import RunnableGenerator
if messages is not None:
return func(messages, **kwargs)
else:
def transform(input_: Iterator, **kwargs: Any) -> Iterator:
for x in input_:
yield func(x, **kwargs)
return RunnableGenerator(partial(transform, **kwargs), name=func.__name__)
wrapped.__doc__ = func.__doc__
return wrapped
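The helper above follows a dual-use pattern: called with messages, it executes immediately; called with no positional argument, it returns a RunnableGenerator that lazily applies the function to each element of a stream. A simplified standalone sketch of the same pattern (overloads omitted; the `shout` helper is hypothetical, not part of this commit):

from functools import partial
from typing import Any, Callable, Iterator

from langchain_core.runnables import RunnableGenerator


def runnable_generator_sketch(func: Callable) -> Callable:
    def wrapped(value: Any = None, **kwargs: Any) -> Any:
        if value is not None:
            # Eager path: behaves like an ordinary function call.
            return func(value, **kwargs)

        # Declarative path: defer execution until the stream runs.
        def transform(input_: Iterator, **kwargs: Any) -> Iterator:
            for x in input_:
                yield func(x, **kwargs)

        return RunnableGenerator(partial(transform, **kwargs), name=func.__name__)

    return wrapped


@runnable_generator_sketch
def shout(text: str) -> str:
    return text.upper()


assert shout("hi") == "HI"           # direct call
assert shout().invoke("hi") == "HI"  # as a Runnable in a chain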
def _runnable_support(func: Callable) -> Callable:
@overload
def wrapped(
@@ -903,13 +862,84 @@ def trim_messages(
)
def _runnable_generator(func: Callable) -> Callable:
@overload
def wrapped(
messages: Literal[None] = None, **kwargs: Any
) -> Runnable[
Union[MessageLikeRepresentation, Sequence[MessageLikeRepresentation]],
Union[BaseMessage, List[BaseMessage]],
]: ...
@overload
def wrapped(
messages: Sequence[Union[BaseMessage, Dict, Tuple]], **kwargs: Any
) -> List[BaseMessage]: ...
@overload
def wrapped(messages: MessageLikeRepresentation, **kwargs: Any) -> BaseMessage: ...
def wrapped(
messages: Union[
MessageLikeRepresentation, Sequence[MessageLikeRepresentation], None
] = None,
**kwargs: Any,
) -> Union[
BaseMessage,
List[BaseMessage],
Runnable[
Union[MessageLikeRepresentation, Sequence[MessageLikeRepresentation]],
Union[BaseMessage, List[BaseMessage]],
],
]:
from langchain_core.runnables.base import RunnableGenerator
if messages is not None:
return func(messages, **kwargs)
else:
def transform(input_: Iterator, **kwargs: Any) -> Iterator:
block_indexes = set()
for x in input_:
msg = func(x, **kwargs)
# Special handling for transforming an OpenAI stream to an
# Anthropic stream.
if isinstance(msg, AIMessageChunk) and isinstance(
msg.content, list
):
tool_use_ct = 0
for block in msg.content:
if not isinstance(block, dict):
continue
if "index" in block:
block_indexes.add(block["index"])
elif block.get("type") == "tool_use":
block["index"] = max(len(block_indexes), *block_indexes)
msg.tool_call_chunks[tool_use_ct]["index"] = block[
"index"
]
else:
pass
if block.get("type") == "tool_use":
tool_use_ct += 1
else:
block_indexes.add(0)
yield msg
return RunnableGenerator(partial(transform, **kwargs), name=func.__name__)
wrapped.__doc__ = func.__doc__
return wrapped
@_runnable_generator
def format_messages_as(
messages: Union[MessageLikeRepresentation, Iterator[MessageLikeRepresentation]],
messages: Union[MessageLikeRepresentation, Sequence[MessageLikeRepresentation]],
*,
format: Literal["openai", "anthropic"],
text: Literal["string", "block"],
) -> Union[BaseMessage, Iterator[BaseMessage]]:
) -> Union[BaseMessage, List[BaseMessage]]:
"""Convert message contents into a standard format.
.. versionadded:: 0.2.36
@@ -975,61 +1005,63 @@ def format_messages_as(
from langchain_core.messages import format_messages_as
from langchain.chat_models import init_chat_model
formatter = format_messages_as(format="openai", text="block")
formatter = format_messages_as(format="openai", text="string")
llm = init_chat_model() | formatter
llm.invoke(
[{"role": "user", "content": "how are you"}],
config={"model": "gpt-4o"},
)
# -> AIMessage([{"type": "text", "text": ""}], ...)
# -> AIMessage(["I am good..."], ...)
llm.invoke(
[{"role": "user", "content": "whats your name"}],
config={"model": "claude-3-5-sonnet-20240620"})
# -> AIMessage([{"type": "text", "text": ""}], ...)
config={"model": "claude-3-5-sonnet-20240620"}
)
# -> AIMessage(["My name is...], ...)
.. note:: Doesn't support streaming
This util does not support formatting streamed chunks on the fly (i.e.
"transforming" chunks). This means if you pipe the outputs of a model to this
formatter in a chain, the chain will not have token-level streaming when
using ``chain.stream()/.astream()``. You'll still see the
token stream when using ``chain.astream_events()`` but the message chunks will
not yet be formatted.
def multiply(a: int, b: int) -> int:
'''Return product of a and b.'''
return a * b
.. code-block:: python
llm_with_tools = init_chat_model().bind_tools([multiply]) | formatter
from langchain_core.messages import format_messages_as
from langchain.chat_models import init_chat_model
formatter = format_messages_as(format="openai", text="block")
llm = init_chat_model() | formatter
# Will contain a single, completed chunk.
list(llm.stream(
[{"role": "user", "content": "how are you"}],
config={"model": "gpt-4o"},
))
# Will include token-level events, but the streamed chunks will not yet be
# formatted.
async for chunk in llm.astream_events(
[{"role": "user", "content": "how are you"}],
config={"model": "gpt-4o"},
version="v2",
for chunk in llm_with_tools.stream(
"what's 5 times 2", config={"model": "claude-3-5-sonnet-20240620"}
):
...
print(chunk)
# -> AIMessageChunk(content='', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75', usage_metadata={'input_tokens': 370, 'output_tokens': 0, 'total_tokens': 370}),
# AIMessageChunk(content='Certainly', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'),
# AIMessageChunk(content='! To', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'),
# AIMessageChunk(content=' calculate', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'),
# AIMessageChunk(content=' 5 times ', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'),
# AIMessageChunk(content='2, we can use', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'),
# AIMessageChunk(content=' the "', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'),
# AIMessageChunk(content='multiply" function that', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'),
# AIMessageChunk(content="'s", id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'),
# AIMessageChunk(content=' available to', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'),
# AIMessageChunk(content=' us.', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'),
# AIMessageChunk(content=' Let', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'),
# AIMessageChunk(content="'s use", id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'),
# AIMessageChunk(content=' this tool', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'),
# AIMessageChunk(content=' to', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'),
# AIMessageChunk(content=' get', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'),
# AIMessageChunk(content=' the result.', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'),
# AIMessageChunk(content='', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75', tool_calls=[{'name': 'multiply', 'args': {}, 'id': 'toolu_01PW8o6BkATCecjsJX8QgG6z', 'type': 'tool_call'}], tool_call_chunks=[{'name': 'multiply', 'args': '', 'id': 'toolu_01PW8o6BkATCecjsJX8QgG6z', 'index': 1, 'type': 'tool_call_chunk'}]),
# AIMessageChunk(content='', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75', tool_calls=[{'name': '', 'args': {}, 'id': None, 'type': 'tool_call'}], tool_call_chunks=[{'name': None, 'args': '', 'id': None, 'index': 1, 'type': 'tool_call_chunk'}]),
# AIMessageChunk(content='', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75', tool_calls=[{'name': '', 'args': {'a': 5}, 'id': None, 'type': 'tool_call'}], tool_call_chunks=[{'name': None, 'args': '{"a": 5', 'id': None, 'index': 1, 'type': 'tool_call_chunk'}]),
# AIMessageChunk(content='', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75', invalid_tool_calls=[{'name': None, 'args': ', "b": 2}', 'id': None, 'error': None, 'type': 'invalid_tool_call'}], tool_call_chunks=[{'name': None, 'args': ', "b": 2}', 'id': None, 'index': 1, 'type': 'tool_call_chunk'}]),
# AIMessageChunk(content='', response_metadata={'stop_reason': 'tool_use', 'stop_sequence': None}, id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75', usage_metadata={'input_tokens': 0, 'output_tokens': 104, 'total_tokens': 104})
""" # noqa: E501
if is_single := isinstance(messages, (BaseMessage, dict)):
messages = [messages]
messages = convert_to_messages(messages, copy=True)
if format.lower() == "openai":
formatted = _format_contents_as_openai(messages, text=text)
formatted = _format_messages_as_openai(messages, text=text)
elif format.lower() == "anthropic":
formatted = _format_contents_as_anthropic(messages, text=text)
formatted = _format_messages_as_anthropic(messages, text=text)
else:
raise ValueError(
f"Unrecognized {format=}. Expected one of ('openai', 'anthropic')."
@@ -1040,7 +1072,7 @@ def format_messages_as(
return formatted
def _format_contents_as_openai(
def _format_messages_as_openai(
messages: Sequence[BaseMessage], *, text: Literal["string", "block"]
) -> List[BaseMessage]:
"""Mutates messages so their contents match OpenAI messages API."""
@@ -1146,14 +1178,14 @@ def _format_contents_as_openai(
f"content block:\n\n{block}"
)
elif block.get("type") == "tool_use":
if not isinstance(message, BaseMessageChunk):
if not isinstance(message, AIMessageChunk):
if missing := [
k for k in ("id", "name", "input") if k not in block
]:
raise ValueError(
f"Unrecognized content block at "
f"messages[{i}].content[{j}] has 'type': 'tool_use', "
f"but is missing expected key(s) "
f"messages[{i}].content[{j}] has 'type': "
f"'tool_use', but is missing expected key(s) "
f"{missing}. Full content block:\n\n{block}"
)
if not any(
@@ -1169,7 +1201,7 @@ def _format_contents_as_openai(
)
else:
if not message.tool_call_chunks:
message.tool_call_chunks = [create_tool_call_chunk(id=block.get("id"), index=block.get("index"), args=block.get("partial_json"), name=block.get("name"))]
message.tool_call_chunks = [
create_tool_call_chunk(
id=block.get("id"),
index=block.get("index"),
args=block.get("partial_json"),
name=block.get("name"),
)
]
elif block.get("type") == "tool_result":
if missing := [
k for k in ("content", "tool_use_id") if k not in block
@@ -1187,7 +1226,7 @@ def _format_contents_as_openai(
)
# Recurse to make sure tool message contents are OpenAI format.
tool_messages.extend(
_format_contents_as_openai([tool_message], text=text)
_format_messages_as_openai([tool_message], text=text)
)
elif (block.get("type") == "json") or "json" in block:
if "json" not in block:
@@ -1253,15 +1292,20 @@ def _format_contents_as_openai(
f"Anthropic, Bedrock Converse, or VertexAI format. Full "
f"content block:\n\n{block}"
)
message.content = content # type: ignore[assignment]
if text == "string" and not any(
block["type"] != "text" for block in content
):
message.content = "\n".join(block["text"] for block in content)
else:
message.content = content # type: ignore[assignment]
updated_messages.extend([message, *tool_messages])
return updated_messages
_OPTIONAL_ANTHROPIC_KEYS = ("cache_control", "is_error")
_OPTIONAL_ANTHROPIC_KEYS = ("cache_control", "is_error", "index")
def _format_contents_as_anthropic(
def _format_messages_as_anthropic(
messages: Sequence[BaseMessage], *, text: Literal["string", "block"]
) -> List[BaseMessage]:
"""Mutates messages so their contents match Anthropic messages API."""
@@ -1289,7 +1333,10 @@ def _format_contents_as_anthropic(
if text == "string":
pass
else:
message.content = [{"type": "text", "text": message.content}]
text_block: dict = {"type": "text", "text": message.content}
if isinstance(message, AIMessageChunk):
text_block["index"] = 0
message.content = [text_block]
else:
if text == "string" and all(
isinstance(block, str)
@@ -1303,13 +1350,20 @@ def _format_contents_as_anthropic(
else:
content = []
for j, block in enumerate(message.content):
# OpenAI format
if isinstance(block, str):
content.append({"type": "text", "text": block})
elif block.get("type") == "text":
if isinstance(block, dict):
block_extra = {
k: block[k] for k in _OPTIONAL_ANTHROPIC_KEYS if k in block
}
else:
block_extra = {}
# OpenAI format
if isinstance(block, str):
text_block = {"type": "text", "text": block}
if isinstance(message, AIMessageChunk):
text_block["index"] = 0
content.append(text_block)
elif block.get("type") == "text":
if missing := [k for k in ("text",) if k not in block]:
raise ValueError(
f"Unrecognized content block at "
@@ -1380,35 +1434,72 @@ def _format_contents_as_anthropic(
f"content block:\n\n{block}"
)
elif block.get("type") == "tool_use":
if missing := [
k for k in ("id", "name", "input") if k not in block
]:
raise ValueError(
f"Unrecognized content block at "
f"messages[{i}].content[{j}] has 'type': 'tool_use', "
f"but is missing expected key(s) "
f"{missing}. Full content block:\n\n{block}"
)
content.append(
{
"type": "tool_use",
"name": block["name"],
"id": block["id"],
"input": block["input"],
**block_extra,
}
)
if not any(
tool_call["id"] == block["id"]
for tool_call in cast(AIMessage, message).tool_calls
):
cast(AIMessage, message).tool_calls.append(
create_tool_call(
name=block["name"],
id=block["id"],
args=block["input"],
if not isinstance(message, AIMessageChunk):
if missing := [
k for k in ("id", "name", "input") if k not in block
]:
raise ValueError(
f"Unrecognized content block at "
f"messages[{i}].content[{j}] has 'type': "
f"'tool_use', "
f"but is missing expected key(s) "
f"{missing}. Full content block:\n\n{block}"
)
content.append(
{
"type": "tool_use",
"name": block["name"],
"id": block["id"],
"input": block["input"],
**block_extra,
}
)
if not any(
tool_call["id"] == block["id"]
for tool_call in cast(AIMessage, message).tool_calls
):
cast(AIMessage, message).tool_calls.append(
create_tool_call(
name=block["name"],
id=block["id"],
args=block["input"],
)
)
else:
if (
not any(k in block for k in ("input", "partial_json"))
or "index" not in block
):
raise ValueError(
f"Unrecognized content block at "
f"message_chunks[{i}].content[{j}] has "
f"'type': 'tool_use', "
f"but is does not have either an 'input' or "
f"'partial_json' and an 'index' key. Full content "
f"block:\n\n{block}"
)
content.append(
{
"type": "tool_use",
"index": block["index"],
**{
k: block[k]
for k in ("name", "input", "partial_json", "id")
if k in block
},
}
)
if not message.tool_call_chunks:
message.tool_call_chunks = [
create_tool_call_chunk(
name=block.get("name"),
id=block.get("id"),
index=block["index"],
args=block["partial_json"]
if "partial_json" in block
else block["input"],
)
]
elif block.get("type") == "tool_result":
if missing := [
k for k in ("content", "tool_use_id") if k not in block
@@ -1489,7 +1580,35 @@ def _format_contents_as_anthropic(
)
message.content = content # type: ignore[assignment]
if isinstance(message, AIMessage) and message.tool_calls:
if isinstance(message, AIMessageChunk) and message.tool_call_chunks:
if isinstance(message.content, str):
if message.content:
message.content = [
{"type": "text", "text": message.content, "index": 0}
]
else:
message.content = []
if not any(
cast(dict, block).get("type") == "tool_use" for block in message.content
):
tool_use_blocks = [
# Note: we intentionally omit index so that it can be set by the
# stream handler, which can count how many blocks
# have been seen in preceding chunks
{
"type": "tool_use",
"partial_json": tc_chunk["args"],
"id": tc_chunk["id"],
"name": tc_chunk["name"],
}
for tc_chunk in message.tool_call_chunks
]
tool_use_blocks = [
{k: v for k, v in tu_block.items() if v is not None}
for tu_block in tool_use_blocks
]
message.content.extend(tool_use_blocks)
elif isinstance(message, AIMessage) and message.tool_calls:
if isinstance(message.content, str):
message.content = [{"type": "text", "text": message.content}]
for tool_call in message.tool_calls:
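Taken together, the changes in this file let format_messages_as both collapse all-text content to a plain string and reformat streamed chunks on the fly. A quick sketch of the string collapse via the public entry point (a sketch assuming text blocks pass through the OpenAI formatter unchanged):

from langchain_core.messages import AIMessage, format_messages_as

msg = AIMessage(
    content=[
        {"type": "text", "text": "Hello"},
        {"type": "text", "text": "world"},
    ]
)
out = format_messages_as([msg], format="openai", text="string")
# With text="string", an all-text block list is joined into one string.
assert out[0].content == "Hello\nworld"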

View File

@@ -1,11 +1,12 @@
import json
from typing import Dict, List, Type
from typing import Any, Callable, Dict, Iterator, List, Type
import pytest
from langchain_core.language_models.fake_chat_models import FakeChatModel
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
HumanMessage,
SystemMessage,
@@ -20,6 +21,7 @@ from langchain_core.messages.utils import (
merge_message_runs,
trim_messages,
)
from langchain_core.runnables import RunnableLambda
@pytest.mark.parametrize("msg_cls", [HumanMessage, AIMessage, SystemMessage])
@@ -810,3 +812,580 @@ def test_format_messages_as_declarative() -> None:
result = formatter.invoke(messages)
assert result[0].content[1]["type"] == "image_url"
assert result[0].content[1]["image_url"]["url"] == base64_image
def _stream_oai(input_: Any) -> Iterator:
chunks = [
AIMessageChunk(content=""),
AIMessageChunk(content="Certainly"),
AIMessageChunk(content="!"),
AIMessageChunk(
content="",
tool_calls=[
{
"name": "multiply",
"args": {},
"id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8",
"type": "tool_call",
}
],
tool_call_chunks=[
{
"name": "multiply",
"args": "",
"id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8",
"index": 0,
"type": "tool_call_chunk",
}
],
),
AIMessageChunk(
content="",
tool_calls=[{"name": "", "args": {}, "id": None, "type": "tool_call"}],
tool_call_chunks=[
{
"name": None,
"args": '{"',
"id": None,
"index": 0,
"type": "tool_call_chunk",
}
],
),
AIMessageChunk(
content="",
invalid_tool_calls=[
{
"name": None,
"args": 'a": 5, "b": 2',
"id": None,
"error": None,
"type": "invalid_tool_call",
}
],
tool_call_chunks=[
{
"name": None,
"args": 'a": 5, "b": 2',
"id": None,
"index": 0,
"type": "tool_call_chunk",
}
],
),
AIMessageChunk(
content="",
invalid_tool_calls=[
{
"name": None,
"args": "}",
"id": None,
"error": None,
"type": "invalid_tool_call",
}
],
tool_call_chunks=[
{
"name": None,
"args": "}",
"id": None,
"index": 0,
"type": "tool_call_chunk",
}
],
),
AIMessageChunk(
content="",
),
]
yield from chunks
def _stream_anthropic(input_: Any) -> Iterator:
chunks = [
AIMessageChunk(content=[]),
AIMessageChunk(content=[{"text": "Certainly", "type": "text", "index": 0}]),
AIMessageChunk(content=[{"text": "!", "type": "text", "index": 0}]),
AIMessageChunk(
content=[
{
"id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8",
"input": {},
"name": "multiply",
"type": "tool_use",
"index": 1,
}
],
tool_calls=[
{
"name": "multiply",
"args": {},
"id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8",
"type": "tool_call",
}
],
tool_call_chunks=[
{
"name": "multiply",
"args": "",
"id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8",
"index": 1,
"type": "tool_call_chunk",
}
],
),
AIMessageChunk(
content=[{"partial_json": '{"', "type": "tool_use", "index": 1}],
tool_calls=[{"name": "", "args": {}, "id": None, "type": "tool_call"}],
tool_call_chunks=[
{
"name": None,
"args": '{"',
"id": None,
"index": 1,
"type": "tool_call_chunk",
}
],
),
AIMessageChunk(
content=[{"partial_json": 'a": 5, "b": 2', "type": "tool_use", "index": 1}],
invalid_tool_calls=[
{
"name": None,
"args": 'a": 5, "b": 2',
"id": None,
"error": None,
"type": "invalid_tool_call",
}
],
tool_call_chunks=[
{
"name": None,
"args": 'a": 5, "b": 2',
"id": None,
"index": 1,
"type": "tool_call_chunk",
}
],
),
AIMessageChunk(
content=[{"partial_json": "}", "type": "tool_use", "index": 1}],
invalid_tool_calls=[
{
"name": None,
"args": "}",
"id": None,
"error": None,
"type": "invalid_tool_call",
}
],
tool_call_chunks=[
{
"name": None,
"args": "}",
"id": None,
"index": 1,
"type": "tool_call_chunk",
}
],
),
AIMessageChunk(content=""),
]
yield from chunks
@pytest.mark.parametrize("stream", [_stream_oai, _stream_anthropic])
def test_format_messages_openai_string_stream(stream: Callable) -> None:
formatter = format_messages_as(format="openai", text="string")
chain = RunnableLambda(stream) | formatter
tool_call_idx = 1 if stream is _stream_anthropic else 0
expected = [
AIMessageChunk(content=""),
AIMessageChunk(content="Certainly"),
AIMessageChunk(content="!"),
AIMessageChunk(
content="",
tool_calls=[
{
"name": "multiply",
"args": {},
"id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8",
"type": "tool_call",
}
],
tool_call_chunks=[
{
"name": "multiply",
"args": "",
"id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8",
"index": tool_call_idx,
"type": "tool_call_chunk",
}
],
),
AIMessageChunk(
content="",
tool_calls=[{"name": "", "args": {}, "id": None, "type": "tool_call"}],
tool_call_chunks=[
{
"name": None,
"args": '{"',
"id": None,
"index": tool_call_idx,
"type": "tool_call_chunk",
}
],
),
AIMessageChunk(
content="",
invalid_tool_calls=[
{
"name": None,
"args": 'a": 5, "b": 2',
"id": None,
"error": None,
"type": "invalid_tool_call",
}
],
tool_call_chunks=[
{
"name": None,
"args": 'a": 5, "b": 2',
"id": None,
"index": tool_call_idx,
"type": "tool_call_chunk",
}
],
),
AIMessageChunk(
content="",
invalid_tool_calls=[
{
"name": None,
"args": "}",
"id": None,
"error": None,
"type": "invalid_tool_call",
}
],
tool_call_chunks=[
{
"name": None,
"args": "}",
"id": None,
"index": tool_call_idx,
"type": "tool_call_chunk",
}
],
),
AIMessageChunk(
content="",
),
]
actual = list(chain.stream({}))
assert expected == actual
@pytest.mark.parametrize("stream", [_stream_oai, _stream_anthropic])
def test_format_messages_openai_block_stream(stream: Callable) -> None:
formatter = format_messages_as(format="openai", text="block")
chain = RunnableLambda(stream) | formatter
tool_call_idx = 1 if stream is _stream_anthropic else 0
expected = [
AIMessageChunk(content=[]),
AIMessageChunk(content=[{"type": "text", "text": "Certainly"}]),
AIMessageChunk(content=[{"type": "text", "text": "!"}]),
AIMessageChunk(
content=[],
tool_calls=[
{
"name": "multiply",
"args": {},
"id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8",
"type": "tool_call",
}
],
tool_call_chunks=[
{
"name": "multiply",
"args": "",
"id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8",
"index": tool_call_idx,
"type": "tool_call_chunk",
}
],
),
AIMessageChunk(
content=[],
tool_calls=[{"name": "", "args": {}, "id": None, "type": "tool_call"}],
tool_call_chunks=[
{
"name": None,
"args": '{"',
"id": None,
"index": tool_call_idx,
"type": "tool_call_chunk",
}
],
),
AIMessageChunk(
content=[],
invalid_tool_calls=[
{
"name": None,
"args": 'a": 5, "b": 2',
"id": None,
"error": None,
"type": "invalid_tool_call",
}
],
tool_call_chunks=[
{
"name": None,
"args": 'a": 5, "b": 2',
"id": None,
"index": tool_call_idx,
"type": "tool_call_chunk",
}
],
),
AIMessageChunk(
content=[],
invalid_tool_calls=[
{
"name": None,
"args": "}",
"id": None,
"error": None,
"type": "invalid_tool_call",
}
],
tool_call_chunks=[
{
"name": None,
"args": "}",
"id": None,
"index": tool_call_idx,
"type": "tool_call_chunk",
}
],
),
AIMessageChunk(
content=[],
),
]
actual = list(chain.stream({}))
assert expected == actual
@pytest.mark.parametrize("stream", [_stream_oai, _stream_anthropic])
def test_format_messages_anthropic_block_stream(stream: Callable) -> None:
formatter = format_messages_as(format="anthropic", text="block")
chain = RunnableLambda(stream) | formatter
expected = [
AIMessageChunk(content=[]),
AIMessageChunk(content=[{"text": "Certainly", "type": "text", "index": 0}]),
AIMessageChunk(content=[{"text": "!", "type": "text", "index": 0}]),
AIMessageChunk(
content=[
{
"id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8",
"name": "multiply",
"type": "tool_use",
"index": 1,
**(
{"input": {}}
if stream == _stream_anthropic
else {"partial_json": ""}
),
}
],
tool_calls=[
{
"name": "multiply",
"args": {},
"id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8",
"type": "tool_call",
}
],
tool_call_chunks=[
{
"name": "multiply",
"args": "",
"id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8",
"index": 1,
"type": "tool_call_chunk",
}
],
),
AIMessageChunk(
content=[{"partial_json": '{"', "type": "tool_use", "index": 1}],
tool_calls=[{"name": "", "args": {}, "id": None, "type": "tool_call"}],
tool_call_chunks=[
{
"name": None,
"args": '{"',
"id": None,
"index": 1,
"type": "tool_call_chunk",
}
],
),
AIMessageChunk(
content=[{"partial_json": 'a": 5, "b": 2', "type": "tool_use", "index": 1}],
invalid_tool_calls=[
{
"name": None,
"args": 'a": 5, "b": 2',
"id": None,
"error": None,
"type": "invalid_tool_call",
}
],
tool_call_chunks=[
{
"name": None,
"args": 'a": 5, "b": 2',
"id": None,
"index": 1,
"type": "tool_call_chunk",
}
],
),
AIMessageChunk(
content=[{"partial_json": "}", "type": "tool_use", "index": 1}],
invalid_tool_calls=[
{
"name": None,
"args": "}",
"id": None,
"error": None,
"type": "invalid_tool_call",
}
],
tool_call_chunks=[
{
"name": None,
"args": "}",
"id": None,
"index": 1,
"type": "tool_call_chunk",
}
],
),
AIMessageChunk(content=[]),
]
actual = list(chain.stream({}))
assert expected == actual
@pytest.mark.parametrize("stream", [_stream_oai, _stream_anthropic])
def test_format_messages_anthropic_string_stream(stream: Callable) -> None:
formatter = format_messages_as(format="anthropic", text="string")
chain = RunnableLambda(stream) | formatter
expected = [
AIMessageChunk(content=""),
AIMessageChunk(content="Certainly"),
AIMessageChunk(content="!"),
AIMessageChunk(
content=[
{
"type": "tool_use",
"id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8",
"name": "multiply",
"index": 1,
**(
{"input": {}}
if stream == _stream_anthropic
else {"partial_json": ""}
),
},
],
tool_calls=[
{
"name": "multiply",
"args": {},
"id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8",
"type": "tool_call",
}
],
tool_call_chunks=[
{
"name": "multiply",
"args": "",
"id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8",
"index": 1,
"type": "tool_call_chunk",
}
],
),
AIMessageChunk(
content=[
{"type": "tool_use", "partial_json": '{"', "index": 1},
],
tool_calls=[{"name": "", "args": {}, "id": None, "type": "tool_call"}],
tool_call_chunks=[
{
"name": None,
"args": '{"',
"id": None,
"index": 1,
"type": "tool_call_chunk",
}
],
),
AIMessageChunk(
content=[
{"type": "tool_use", "partial_json": 'a": 5, "b": 2', "index": 1},
],
invalid_tool_calls=[
{
"name": None,
"args": 'a": 5, "b": 2',
"id": None,
"error": None,
"type": "invalid_tool_call",
}
],
tool_call_chunks=[
{
"name": None,
"args": 'a": 5, "b": 2',
"id": None,
"index": 1,
"type": "tool_call_chunk",
}
],
),
AIMessageChunk(
content=[
{"type": "tool_use", "partial_json": "}", "index": 1},
],
invalid_tool_calls=[
{
"name": None,
"args": "}",
"id": None,
"error": None,
"type": "invalid_tool_call",
}
],
tool_call_chunks=[
{
"name": None,
"args": "}",
"id": None,
"index": 1,
"type": "tool_call_chunk",
}
],
),
AIMessageChunk(content=""),
]
actual = list(chain.stream({}))
assert expected == actual
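The index bookkeeping exercised above matters because AIMessageChunk addition merges tool_call_chunks that share an index. A short illustration of that aggregation (standard langchain_core behavior, not part of this diff):

from langchain_core.messages import AIMessageChunk

a = AIMessageChunk(
    content="",
    tool_call_chunks=[
        {
            "name": "multiply",
            "args": "",
            "id": "call_1",
            "index": 0,
            "type": "tool_call_chunk",
        }
    ],
)
b = AIMessageChunk(
    content="",
    tool_call_chunks=[
        {
            "name": None,
            "args": '{"a": 5, "b": 2}',
            "id": None,
            "index": 0,
            "type": "tool_call_chunk",
        }
    ],
)
# Chunks with the same index merge into a single accumulated tool call.
merged = a + b
(tc,) = merged.tool_calls
assert tc["name"] == "multiply"
assert tc["args"] == {"a": 5, "b": 2}
assert tc["id"] == "call_1"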

View File

@@ -31,7 +31,7 @@ from langchain_core.language_models.chat_models import (
)
from langchain_core.messages import AnyMessage, BaseMessage
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_core.runnables import Runnable, RunnableConfig, ensure_config
from langchain_core.runnables.schema import StreamEvent
from langchain_core.tools import BaseTool
from langchain_core.tracers import RunLog, RunLogPatch
@@ -512,7 +512,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
return model
def _model_params(self, config: Optional[RunnableConfig]) -> dict:
config = config or {}
config = ensure_config(config)
model_params = {
_remove_prefix(k, self._config_prefix): v
for k, v in config.get("configurable", {}).items()
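For reference, ensure_config normalizes a None or partial config into a fully populated RunnableConfig, so the configurable lookup above always sees well-formed defaults rather than a bare empty dict. A minimal sketch:

from langchain_core.runnables import ensure_config

cfg = ensure_config(None)
# Default fields such as recursion_limit are filled in.
assert cfg["recursion_limit"] == 25
assert cfg.get("configurable", {}) == {}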