mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-24 12:00:52 +00:00
fmt
This commit is contained in:
parent
fb002faba4
commit
992884b410
@ -20,6 +20,7 @@ from typing import (
|
|||||||
Callable,
|
Callable,
|
||||||
Dict,
|
Dict,
|
||||||
Iterable,
|
Iterable,
|
||||||
|
Iterator,
|
||||||
List,
|
List,
|
||||||
Literal,
|
Literal,
|
||||||
Optional,
|
Optional,
|
||||||
@ -28,7 +29,7 @@ from typing import (
|
|||||||
Type,
|
Type,
|
||||||
Union,
|
Union,
|
||||||
cast,
|
cast,
|
||||||
overload, Iterator,
|
overload,
|
||||||
)
|
)
|
||||||
|
|
||||||
from langchain_core.messages.ai import AIMessage, AIMessageChunk
|
from langchain_core.messages.ai import AIMessage, AIMessageChunk
|
||||||
@ -41,7 +42,9 @@ from langchain_core.messages.system import SystemMessage, SystemMessageChunk
|
|||||||
from langchain_core.messages.tool import ToolMessage, ToolMessageChunk
|
from langchain_core.messages.tool import ToolMessage, ToolMessageChunk
|
||||||
from langchain_core.messages.tool import (
|
from langchain_core.messages.tool import (
|
||||||
tool_call as create_tool_call,
|
tool_call as create_tool_call,
|
||||||
tool_call_chunk as create_tool_call_chunk
|
)
|
||||||
|
from langchain_core.messages.tool import (
|
||||||
|
tool_call_chunk as create_tool_call_chunk,
|
||||||
)
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -335,50 +338,6 @@ def convert_to_messages(
|
|||||||
return [_convert_to_message(m, copy=copy) for m in messages]
|
return [_convert_to_message(m, copy=copy) for m in messages]
|
||||||
|
|
||||||
|
|
||||||
def _runnable_generator(func: Callable) -> Callable:
|
|
||||||
@overload
|
|
||||||
def wrapped(
|
|
||||||
messages: Literal[None] = None, **kwargs: Any
|
|
||||||
) -> Runnable[
|
|
||||||
Union[MessageLikeRepresentation, Sequence[MessageLikeRepresentation]],
|
|
||||||
Union[BaseMessage, List[BaseMessage]],
|
|
||||||
]: ...
|
|
||||||
|
|
||||||
@overload
|
|
||||||
def wrapped(
|
|
||||||
messages: Sequence[Union[BaseMessage, Dict, Tuple]], **kwargs: Any
|
|
||||||
) -> List[BaseMessage]: ...
|
|
||||||
|
|
||||||
@overload
|
|
||||||
def wrapped(messages: MessageLikeRepresentation, **kwargs: Any) -> BaseMessage: ...
|
|
||||||
|
|
||||||
def wrapped(
|
|
||||||
messages: Union[
|
|
||||||
MessageLikeRepresentation, Sequence[MessageLikeRepresentation], None
|
|
||||||
] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> Union[
|
|
||||||
BaseMessage,
|
|
||||||
List[BaseMessage],
|
|
||||||
Runnable[
|
|
||||||
Union[MessageLikeRepresentation, Sequence[MessageLikeRepresentation]],
|
|
||||||
Union[BaseMessage, List[BaseMessage]],
|
|
||||||
],
|
|
||||||
]:
|
|
||||||
from langchain_core.runnables.base import RunnableGenerator
|
|
||||||
|
|
||||||
if messages is not None:
|
|
||||||
return func(messages, **kwargs)
|
|
||||||
else:
|
|
||||||
def transform(input_: Iterator, **kwargs: Any) -> Iterator:
|
|
||||||
for x in input_:
|
|
||||||
yield func(x, **kwargs)
|
|
||||||
return RunnableGenerator(partial(transform, **kwargs), name=func.__name__)
|
|
||||||
|
|
||||||
wrapped.__doc__ = func.__doc__
|
|
||||||
return wrapped
|
|
||||||
|
|
||||||
|
|
||||||
def _runnable_support(func: Callable) -> Callable:
|
def _runnable_support(func: Callable) -> Callable:
|
||||||
@overload
|
@overload
|
||||||
def wrapped(
|
def wrapped(
|
||||||
@ -903,13 +862,84 @@ def trim_messages(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _runnable_generator(func: Callable) -> Callable:
|
||||||
|
@overload
|
||||||
|
def wrapped(
|
||||||
|
messages: Literal[None] = None, **kwargs: Any
|
||||||
|
) -> Runnable[
|
||||||
|
Union[MessageLikeRepresentation, Sequence[MessageLikeRepresentation]],
|
||||||
|
Union[BaseMessage, List[BaseMessage]],
|
||||||
|
]: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def wrapped(
|
||||||
|
messages: Sequence[Union[BaseMessage, Dict, Tuple]], **kwargs: Any
|
||||||
|
) -> List[BaseMessage]: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def wrapped(messages: MessageLikeRepresentation, **kwargs: Any) -> BaseMessage: ...
|
||||||
|
|
||||||
|
def wrapped(
|
||||||
|
messages: Union[
|
||||||
|
MessageLikeRepresentation, Sequence[MessageLikeRepresentation], None
|
||||||
|
] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Union[
|
||||||
|
BaseMessage,
|
||||||
|
List[BaseMessage],
|
||||||
|
Runnable[
|
||||||
|
Union[MessageLikeRepresentation, Sequence[MessageLikeRepresentation]],
|
||||||
|
Union[BaseMessage, List[BaseMessage]],
|
||||||
|
],
|
||||||
|
]:
|
||||||
|
from langchain_core.runnables.base import RunnableGenerator
|
||||||
|
|
||||||
|
if messages is not None:
|
||||||
|
return func(messages, **kwargs)
|
||||||
|
else:
|
||||||
|
|
||||||
|
def transform(input_: Iterator, **kwargs: Any) -> Iterator:
|
||||||
|
block_indexes = set()
|
||||||
|
for x in input_:
|
||||||
|
msg = func(x, **kwargs)
|
||||||
|
# Special handling for transforming an OpenAI stream to an
|
||||||
|
# Anthropic stream.
|
||||||
|
if isinstance(msg, AIMessageChunk) and isinstance(
|
||||||
|
msg.content, list
|
||||||
|
):
|
||||||
|
tool_use_ct = 0
|
||||||
|
for block in msg.content:
|
||||||
|
if not isinstance(block, dict):
|
||||||
|
continue
|
||||||
|
if "index" in block:
|
||||||
|
block_indexes.add(block["index"])
|
||||||
|
elif block.get("type") == "tool_use":
|
||||||
|
block["index"] = max(len(block_indexes), *block_indexes)
|
||||||
|
msg.tool_call_chunks[tool_use_ct]["index"] = block[
|
||||||
|
"index"
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if block.get("type") == "tool_use":
|
||||||
|
tool_use_ct += 1
|
||||||
|
else:
|
||||||
|
block_indexes.add(0)
|
||||||
|
yield msg
|
||||||
|
|
||||||
|
return RunnableGenerator(partial(transform, **kwargs), name=func.__name__)
|
||||||
|
|
||||||
|
wrapped.__doc__ = func.__doc__
|
||||||
|
return wrapped
|
||||||
|
|
||||||
|
|
||||||
@_runnable_generator
|
@_runnable_generator
|
||||||
def format_messages_as(
|
def format_messages_as(
|
||||||
messages: Union[MessageLikeRepresentation, Iterator[MessageLikeRepresentation]],
|
messages: Union[MessageLikeRepresentation, Sequence[MessageLikeRepresentation]],
|
||||||
*,
|
*,
|
||||||
format: Literal["openai", "anthropic"],
|
format: Literal["openai", "anthropic"],
|
||||||
text: Literal["string", "block"],
|
text: Literal["string", "block"],
|
||||||
) -> Union[BaseMessage, Iterator[BaseMessage]]:
|
) -> Union[BaseMessage, List[BaseMessage]]:
|
||||||
"""Convert message contents into a standard format.
|
"""Convert message contents into a standard format.
|
||||||
|
|
||||||
.. versionadded:: 0.2.36
|
.. versionadded:: 0.2.36
|
||||||
@ -975,61 +1005,63 @@ def format_messages_as(
|
|||||||
from langchain_core.messages import format_messages_as
|
from langchain_core.messages import format_messages_as
|
||||||
from langchain.chat_models import init_chat_model
|
from langchain.chat_models import init_chat_model
|
||||||
|
|
||||||
formatter = format_messages_as(format="openai", text="block")
|
formatter = format_messages_as(format="openai", text="string")
|
||||||
llm = init_chat_model() | formatter
|
llm = init_chat_model() | formatter
|
||||||
|
|
||||||
llm.invoke(
|
llm.invoke(
|
||||||
[{"role": "user", "content": "how are you"}],
|
[{"role": "user", "content": "how are you"}],
|
||||||
config={"model": "gpt-4o"},
|
config={"model": "gpt-4o"},
|
||||||
)
|
)
|
||||||
# -> AIMessage([{"type": "text", "text": ""}], ...)
|
# -> AIMessage(["I am good..."], ...)
|
||||||
|
|
||||||
llm.invoke(
|
llm.invoke(
|
||||||
[{"role": "user", "content": "whats your name"}],
|
[{"role": "user", "content": "whats your name"}],
|
||||||
config={"model": "claude-3-5-sonnet-20240620"})
|
config={"model": "claude-3-5-sonnet-20240620"}
|
||||||
# -> AIMessage([{"type": "text", "text": ""}], ...)
|
)
|
||||||
|
# -> AIMessage(["My name is...], ...)
|
||||||
|
|
||||||
.. note:: Doesn't support streaming
|
|
||||||
|
|
||||||
This util does not support formatting streamed chunks on the fly (i.e.
|
def multiply(a: int, b: int) -> int:
|
||||||
"transforming" chunks). This means if you pipe the outputs of a model to this
|
'''Return product of a and b.'''
|
||||||
formatter in a chain, the chain will not have token-level streaming when
|
return a * b
|
||||||
using ``chain.stream()/.astream()``. You'll still see the
|
|
||||||
token stream when using ``chat.astream_events()`` but the message chunks will
|
|
||||||
not yet be formatted.
|
|
||||||
|
|
||||||
.. code-block:: python
|
llm_with_tools = init_chat_model().bind_tools([multiply]) | formatter
|
||||||
|
|
||||||
from langchain_core.messages import format_messages_as
|
for chunk in llm_with_tools.stream(
|
||||||
from langchain.chat_models import init_chat_model
|
"what's 5 times 2", config={"model": "claude-3-5-sonnet-20240620"}
|
||||||
|
|
||||||
formatter = format_messages_as(format="openai", text="block")
|
|
||||||
llm = init_chat_model() | formatter
|
|
||||||
|
|
||||||
# Will contain a single, completed chunk.
|
|
||||||
list(llm.stream(
|
|
||||||
[{"role": "user", "content": "how are you"}],
|
|
||||||
config={"model": "gpt-4o"},
|
|
||||||
))
|
|
||||||
|
|
||||||
# Will include token-level events, but the streamed chunks will not yet be
|
|
||||||
# formatted.
|
|
||||||
async for chunk in llm.astream_events(
|
|
||||||
[{"role": "user", "content": "how are you"}],
|
|
||||||
config={"model": "gpt-4o"},
|
|
||||||
version="v2",
|
|
||||||
):
|
):
|
||||||
...
|
print(chunk)
|
||||||
|
# -> AIMessageChunk(content='', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75', usage_metadata={'input_tokens': 370, 'output_tokens': 0, 'total_tokens': 370}),
|
||||||
|
# AIMessageChunk(content='Certainly', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'),
|
||||||
|
# AIMessageChunk(content='! To', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'),
|
||||||
|
# AIMessageChunk(content=' calculate', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'),
|
||||||
|
# AIMessageChunk(content=' 5 times ', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'),
|
||||||
|
# AIMessageChunk(content='2, we can use', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'),
|
||||||
|
# AIMessageChunk(content=' the "', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'),
|
||||||
|
# AIMessageChunk(content='multiply" function that', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'),
|
||||||
|
# AIMessageChunk(content="'s", id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'),
|
||||||
|
# AIMessageChunk(content=' available to', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'),
|
||||||
|
# AIMessageChunk(content=' us.', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'),
|
||||||
|
# AIMessageChunk(content=' Let', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'),
|
||||||
|
# AIMessageChunk(content="'s use", id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'),
|
||||||
|
# AIMessageChunk(content=' this tool', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'),
|
||||||
|
# AIMessageChunk(content=' to', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'),
|
||||||
|
# AIMessageChunk(content=' get', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'),
|
||||||
|
# AIMessageChunk(content=' the result.', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'),
|
||||||
|
# AIMessageChunk(content='', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75', tool_calls=[{'name': 'multiply', 'args': {}, 'id': 'toolu_01PW8o6BkATCecjsJX8QgG6z', 'type': 'tool_call'}], tool_call_chunks=[{'name': 'multiply', 'args': '', 'id': 'toolu_01PW8o6BkATCecjsJX8QgG6z', 'index': 1, 'type': 'tool_call_chunk'}]),
|
||||||
|
# AIMessageChunk(content='', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75', tool_calls=[{'name': '', 'args': {}, 'id': None, 'type': 'tool_call'}], tool_call_chunks=[{'name': None, 'args': '', 'id': None, 'index': 1, 'type': 'tool_call_chunk'}]),
|
||||||
|
# AIMessageChunk(content='', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75', tool_calls=[{'name': '', 'args': {'a': 5}, 'id': None, 'type': 'tool_call'}], tool_call_chunks=[{'name': None, 'args': '{"a": 5', 'id': None, 'index': 1, 'type': 'tool_call_chunk'}]),
|
||||||
|
# AIMessageChunk(content='', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75', invalid_tool_calls=[{'name': None, 'args': ', "b": 2}', 'id': None, 'error': None, 'type': 'invalid_tool_call'}], tool_call_chunks=[{'name': None, 'args': ', "b": 2}', 'id': None, 'index': 1, 'type': 'tool_call_chunk'}]),
|
||||||
|
# AIMessageChunk(content='', response_metadata={'stop_reason': 'tool_use', 'stop_sequence': None}, id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75', usage_metadata={'input_tokens': 0, 'output_tokens': 104, 'total_tokens': 104})
|
||||||
|
|
||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
if is_single := isinstance(messages, (BaseMessage, dict)):
|
if is_single := isinstance(messages, (BaseMessage, dict)):
|
||||||
messages = [messages]
|
messages = [messages]
|
||||||
messages = convert_to_messages(messages, copy=True)
|
messages = convert_to_messages(messages, copy=True)
|
||||||
if format.lower() == "openai":
|
if format.lower() == "openai":
|
||||||
formatted = _format_contents_as_openai(messages, text=text)
|
formatted = _format_messages_as_openai(messages, text=text)
|
||||||
elif format.lower() == "anthropic":
|
elif format.lower() == "anthropic":
|
||||||
formatted = _format_contents_as_anthropic(messages, text=text)
|
formatted = _format_messages_as_anthropic(messages, text=text)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unrecognized {format=}. Expected one of ('openai', 'anthropic')."
|
f"Unrecognized {format=}. Expected one of ('openai', 'anthropic')."
|
||||||
@ -1040,7 +1072,7 @@ def format_messages_as(
|
|||||||
return formatted
|
return formatted
|
||||||
|
|
||||||
|
|
||||||
def _format_contents_as_openai(
|
def _format_messages_as_openai(
|
||||||
messages: Sequence[BaseMessage], *, text: Literal["string", "block"]
|
messages: Sequence[BaseMessage], *, text: Literal["string", "block"]
|
||||||
) -> List[BaseMessage]:
|
) -> List[BaseMessage]:
|
||||||
"""Mutates messages so their contents match OpenAI messages API."""
|
"""Mutates messages so their contents match OpenAI messages API."""
|
||||||
@ -1146,14 +1178,14 @@ def _format_contents_as_openai(
|
|||||||
f"content block:\n\n{block}"
|
f"content block:\n\n{block}"
|
||||||
)
|
)
|
||||||
elif block.get("type") == "tool_use":
|
elif block.get("type") == "tool_use":
|
||||||
if not isinstance(message, BaseMessageChunk):
|
if not isinstance(message, AIMessageChunk):
|
||||||
if missing := [
|
if missing := [
|
||||||
k for k in ("id", "name", "input") if k not in block
|
k for k in ("id", "name", "input") if k not in block
|
||||||
]:
|
]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unrecognized content block at "
|
f"Unrecognized content block at "
|
||||||
f"messages[{i}].content[{j}] has 'type': 'tool_use', "
|
f"messages[{i}].content[{j}] has 'type': "
|
||||||
f"but is missing expected key(s) "
|
f"'tool_use', but is missing expected key(s) "
|
||||||
f"{missing}. Full content block:\n\n{block}"
|
f"{missing}. Full content block:\n\n{block}"
|
||||||
)
|
)
|
||||||
if not any(
|
if not any(
|
||||||
@ -1169,7 +1201,14 @@ def _format_contents_as_openai(
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if not message.tool_call_chunks:
|
if not message.tool_call_chunks:
|
||||||
message.tool_call_chunks = [create_tool_call_chunk(id=block.get("id"), index=block.get("index"), args=block.get("partial_json"), name=block.get("name"))]
|
message.tool_call_chunks = [
|
||||||
|
create_tool_call_chunk(
|
||||||
|
id=block.get("id"),
|
||||||
|
index=block.get("index"),
|
||||||
|
args=block.get("partial_json"),
|
||||||
|
name=block.get("name"),
|
||||||
|
)
|
||||||
|
]
|
||||||
elif block.get("type") == "tool_result":
|
elif block.get("type") == "tool_result":
|
||||||
if missing := [
|
if missing := [
|
||||||
k for k in ("content", "tool_use_id") if k not in block
|
k for k in ("content", "tool_use_id") if k not in block
|
||||||
@ -1187,7 +1226,7 @@ def _format_contents_as_openai(
|
|||||||
)
|
)
|
||||||
# Recurse to make sure tool message contents are OpenAI format.
|
# Recurse to make sure tool message contents are OpenAI format.
|
||||||
tool_messages.extend(
|
tool_messages.extend(
|
||||||
_format_contents_as_openai([tool_message], text=text)
|
_format_messages_as_openai([tool_message], text=text)
|
||||||
)
|
)
|
||||||
elif (block.get("type") == "json") or "json" in block:
|
elif (block.get("type") == "json") or "json" in block:
|
||||||
if "json" not in block:
|
if "json" not in block:
|
||||||
@ -1253,15 +1292,20 @@ def _format_contents_as_openai(
|
|||||||
f"Anthropic, Bedrock Converse, or VertexAI format. Full "
|
f"Anthropic, Bedrock Converse, or VertexAI format. Full "
|
||||||
f"content block:\n\n{block}"
|
f"content block:\n\n{block}"
|
||||||
)
|
)
|
||||||
message.content = content # type: ignore[assignment]
|
if text == "string" and not any(
|
||||||
|
block["type"] != "text" for block in content
|
||||||
|
):
|
||||||
|
message.content = "\n".join(block["text"] for block in content)
|
||||||
|
else:
|
||||||
|
message.content = content # type: ignore[assignment]
|
||||||
updated_messages.extend([message, *tool_messages])
|
updated_messages.extend([message, *tool_messages])
|
||||||
return updated_messages
|
return updated_messages
|
||||||
|
|
||||||
|
|
||||||
_OPTIONAL_ANTHROPIC_KEYS = ("cache_control", "is_error")
|
_OPTIONAL_ANTHROPIC_KEYS = ("cache_control", "is_error", "index")
|
||||||
|
|
||||||
|
|
||||||
def _format_contents_as_anthropic(
|
def _format_messages_as_anthropic(
|
||||||
messages: Sequence[BaseMessage], *, text: Literal["string", "block"]
|
messages: Sequence[BaseMessage], *, text: Literal["string", "block"]
|
||||||
) -> List[BaseMessage]:
|
) -> List[BaseMessage]:
|
||||||
"""Mutates messages so their contents match Anthropic messages API."""
|
"""Mutates messages so their contents match Anthropic messages API."""
|
||||||
@ -1289,7 +1333,10 @@ def _format_contents_as_anthropic(
|
|||||||
if text == "string":
|
if text == "string":
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
message.content = [{"type": "text", "text": message.content}]
|
text_block: dict = {"type": "text", "text": message.content}
|
||||||
|
if isinstance(message, AIMessageChunk):
|
||||||
|
text_block["index"] = 0
|
||||||
|
message.content = [text_block]
|
||||||
else:
|
else:
|
||||||
if text == "string" and all(
|
if text == "string" and all(
|
||||||
isinstance(block, str)
|
isinstance(block, str)
|
||||||
@ -1303,13 +1350,20 @@ def _format_contents_as_anthropic(
|
|||||||
else:
|
else:
|
||||||
content = []
|
content = []
|
||||||
for j, block in enumerate(message.content):
|
for j, block in enumerate(message.content):
|
||||||
# OpenAI format
|
if isinstance(block, dict):
|
||||||
if isinstance(block, str):
|
|
||||||
content.append({"type": "text", "text": block})
|
|
||||||
elif block.get("type") == "text":
|
|
||||||
block_extra = {
|
block_extra = {
|
||||||
k: block[k] for k in _OPTIONAL_ANTHROPIC_KEYS if k in block
|
k: block[k] for k in _OPTIONAL_ANTHROPIC_KEYS if k in block
|
||||||
}
|
}
|
||||||
|
else:
|
||||||
|
block_extra = {}
|
||||||
|
|
||||||
|
# OpenAI format
|
||||||
|
if isinstance(block, str):
|
||||||
|
text_block = {"type": "text", "text": block}
|
||||||
|
if isinstance(message, AIMessageChunk):
|
||||||
|
text_block["index"] = 0
|
||||||
|
content.append(text_block)
|
||||||
|
elif block.get("type") == "text":
|
||||||
if missing := [k for k in ("text",) if k not in block]:
|
if missing := [k for k in ("text",) if k not in block]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unrecognized content block at "
|
f"Unrecognized content block at "
|
||||||
@ -1380,35 +1434,72 @@ def _format_contents_as_anthropic(
|
|||||||
f"content block:\n\n{block}"
|
f"content block:\n\n{block}"
|
||||||
)
|
)
|
||||||
elif block.get("type") == "tool_use":
|
elif block.get("type") == "tool_use":
|
||||||
if missing := [
|
if not isinstance(message, AIMessageChunk):
|
||||||
k for k in ("id", "name", "input") if k not in block
|
if missing := [
|
||||||
]:
|
k for k in ("id", "name", "input") if k not in block
|
||||||
raise ValueError(
|
]:
|
||||||
f"Unrecognized content block at "
|
raise ValueError(
|
||||||
f"messages[{i}].content[{j}] has 'type': 'tool_use', "
|
f"Unrecognized content block at "
|
||||||
f"but is missing expected key(s) "
|
f"messages[{i}].content[{j}] has 'type': "
|
||||||
f"{missing}. Full content block:\n\n{block}"
|
f"'tool_use', "
|
||||||
)
|
f"but is missing expected key(s) "
|
||||||
content.append(
|
f"{missing}. Full content block:\n\n{block}"
|
||||||
{
|
|
||||||
"type": "tool_use",
|
|
||||||
"name": block["name"],
|
|
||||||
"id": block["id"],
|
|
||||||
"input": block["input"],
|
|
||||||
**block_extra,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
if not any(
|
|
||||||
tool_call["id"] == block["id"]
|
|
||||||
for tool_call in cast(AIMessage, message).tool_calls
|
|
||||||
):
|
|
||||||
cast(AIMessage, message).tool_calls.append(
|
|
||||||
create_tool_call(
|
|
||||||
name=block["name"],
|
|
||||||
id=block["id"],
|
|
||||||
args=block["input"],
|
|
||||||
)
|
)
|
||||||
|
content.append(
|
||||||
|
{
|
||||||
|
"type": "tool_use",
|
||||||
|
"name": block["name"],
|
||||||
|
"id": block["id"],
|
||||||
|
"input": block["input"],
|
||||||
|
**block_extra,
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
if not any(
|
||||||
|
tool_call["id"] == block["id"]
|
||||||
|
for tool_call in cast(AIMessage, message).tool_calls
|
||||||
|
):
|
||||||
|
cast(AIMessage, message).tool_calls.append(
|
||||||
|
create_tool_call(
|
||||||
|
name=block["name"],
|
||||||
|
id=block["id"],
|
||||||
|
args=block["input"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if (
|
||||||
|
not any(k in block for k in ("input", "partial_json"))
|
||||||
|
or "index" not in block
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"Unrecognized content block at "
|
||||||
|
f"message_chunks[{i}].content[{j}] has "
|
||||||
|
f"'type': 'tool_use', "
|
||||||
|
f"but is does not have either an 'input' or "
|
||||||
|
f"'partial_json' and an 'index' key. Full content "
|
||||||
|
f"block:\n\n{block}"
|
||||||
|
)
|
||||||
|
content.append(
|
||||||
|
{
|
||||||
|
"type": "tool_use",
|
||||||
|
"index": block["index"],
|
||||||
|
**{
|
||||||
|
k: block[k]
|
||||||
|
for k in ("name", "input", "partial_json", "id")
|
||||||
|
if k in block
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if not message.tool_call_chunks:
|
||||||
|
message.tool_call_chunks = [
|
||||||
|
create_tool_call_chunk(
|
||||||
|
name=block.get("name"),
|
||||||
|
id=block.get("id"),
|
||||||
|
index=block["index"],
|
||||||
|
args=block["partial_json"]
|
||||||
|
if "partial_json" in block
|
||||||
|
else block["input"],
|
||||||
|
)
|
||||||
|
]
|
||||||
elif block.get("type") == "tool_result":
|
elif block.get("type") == "tool_result":
|
||||||
if missing := [
|
if missing := [
|
||||||
k for k in ("content", "tool_use_id") if k not in block
|
k for k in ("content", "tool_use_id") if k not in block
|
||||||
@ -1489,7 +1580,35 @@ def _format_contents_as_anthropic(
|
|||||||
)
|
)
|
||||||
message.content = content # type: ignore[assignment]
|
message.content = content # type: ignore[assignment]
|
||||||
|
|
||||||
if isinstance(message, AIMessage) and message.tool_calls:
|
if isinstance(message, AIMessageChunk) and message.tool_call_chunks:
|
||||||
|
if isinstance(message.content, str):
|
||||||
|
if message.content:
|
||||||
|
message.content = [
|
||||||
|
{"type": "text", "text": message.content, "index": 0}
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
message.content = []
|
||||||
|
if not any(
|
||||||
|
cast(dict, block).get("type") == "tool_use" for block in message.content
|
||||||
|
):
|
||||||
|
tool_use_blocks = [
|
||||||
|
# Note: we intentionally omit index so that it can be set by the
|
||||||
|
# stream handler, which can count how many blocks
|
||||||
|
# have been seen in preceding chunks
|
||||||
|
{
|
||||||
|
"type": "tool_use",
|
||||||
|
"partial_json": tc_chunk["args"],
|
||||||
|
"id": tc_chunk["id"],
|
||||||
|
"name": tc_chunk["name"],
|
||||||
|
}
|
||||||
|
for i, tc_chunk in enumerate(message.tool_call_chunks)
|
||||||
|
]
|
||||||
|
tool_use_blocks = [
|
||||||
|
{k: v for k, v in tu_block.items() if v is not None}
|
||||||
|
for tu_block in tool_use_blocks
|
||||||
|
]
|
||||||
|
message.content.extend(tool_use_blocks)
|
||||||
|
elif isinstance(message, AIMessage) and message.tool_calls:
|
||||||
if isinstance(message.content, str):
|
if isinstance(message.content, str):
|
||||||
message.content = [{"type": "text", "text": message.content}]
|
message.content = [{"type": "text", "text": message.content}]
|
||||||
for tool_call in message.tool_calls:
|
for tool_call in message.tool_calls:
|
||||||
|
@ -1,11 +1,12 @@
|
|||||||
import json
|
import json
|
||||||
from typing import Dict, List, Type
|
from typing import Any, Callable, Dict, Iterator, List, Type
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from langchain_core.language_models.fake_chat_models import FakeChatModel
|
from langchain_core.language_models.fake_chat_models import FakeChatModel
|
||||||
from langchain_core.messages import (
|
from langchain_core.messages import (
|
||||||
AIMessage,
|
AIMessage,
|
||||||
|
AIMessageChunk,
|
||||||
BaseMessage,
|
BaseMessage,
|
||||||
HumanMessage,
|
HumanMessage,
|
||||||
SystemMessage,
|
SystemMessage,
|
||||||
@ -20,6 +21,7 @@ from langchain_core.messages.utils import (
|
|||||||
merge_message_runs,
|
merge_message_runs,
|
||||||
trim_messages,
|
trim_messages,
|
||||||
)
|
)
|
||||||
|
from langchain_core.runnables import RunnableLambda
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("msg_cls", [HumanMessage, AIMessage, SystemMessage])
|
@pytest.mark.parametrize("msg_cls", [HumanMessage, AIMessage, SystemMessage])
|
||||||
@ -810,3 +812,580 @@ def test_format_messages_as_declarative() -> None:
|
|||||||
result = formatter.invoke(messages)
|
result = formatter.invoke(messages)
|
||||||
assert result[0].content[1]["type"] == "image_url"
|
assert result[0].content[1]["type"] == "image_url"
|
||||||
assert result[0].content[1]["image_url"]["url"] == base64_image
|
assert result[0].content[1]["image_url"]["url"] == base64_image
|
||||||
|
|
||||||
|
|
||||||
|
def _stream_oai(input_: Any) -> Iterator:
|
||||||
|
chunks = [
|
||||||
|
AIMessageChunk(content=""),
|
||||||
|
AIMessageChunk(content="Certainly"),
|
||||||
|
AIMessageChunk(content="!"),
|
||||||
|
AIMessageChunk(
|
||||||
|
content="",
|
||||||
|
tool_calls=[
|
||||||
|
{
|
||||||
|
"name": "multiply",
|
||||||
|
"args": {},
|
||||||
|
"id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8",
|
||||||
|
"type": "tool_call",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
tool_call_chunks=[
|
||||||
|
{
|
||||||
|
"name": "multiply",
|
||||||
|
"args": "",
|
||||||
|
"id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8",
|
||||||
|
"index": 0,
|
||||||
|
"type": "tool_call_chunk",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
),
|
||||||
|
AIMessageChunk(
|
||||||
|
content="",
|
||||||
|
tool_calls=[{"name": "", "args": {}, "id": None, "type": "tool_call"}],
|
||||||
|
tool_call_chunks=[
|
||||||
|
{
|
||||||
|
"name": None,
|
||||||
|
"args": '{"',
|
||||||
|
"id": None,
|
||||||
|
"index": 0,
|
||||||
|
"type": "tool_call_chunk",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
),
|
||||||
|
AIMessageChunk(
|
||||||
|
content="",
|
||||||
|
invalid_tool_calls=[
|
||||||
|
{
|
||||||
|
"name": None,
|
||||||
|
"args": 'a": 5, "b": 2',
|
||||||
|
"id": None,
|
||||||
|
"error": None,
|
||||||
|
"type": "invalid_tool_call",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
tool_call_chunks=[
|
||||||
|
{
|
||||||
|
"name": None,
|
||||||
|
"args": 'a": 5, "b": 2',
|
||||||
|
"id": None,
|
||||||
|
"index": 0,
|
||||||
|
"type": "tool_call_chunk",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
),
|
||||||
|
AIMessageChunk(
|
||||||
|
content="",
|
||||||
|
invalid_tool_calls=[
|
||||||
|
{
|
||||||
|
"name": None,
|
||||||
|
"args": "}",
|
||||||
|
"id": None,
|
||||||
|
"error": None,
|
||||||
|
"type": "invalid_tool_call",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
tool_call_chunks=[
|
||||||
|
{
|
||||||
|
"name": None,
|
||||||
|
"args": "}",
|
||||||
|
"id": None,
|
||||||
|
"index": 0,
|
||||||
|
"type": "tool_call_chunk",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
),
|
||||||
|
AIMessageChunk(
|
||||||
|
content="",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
yield from chunks
|
||||||
|
|
||||||
|
|
||||||
|
def _stream_anthropic(input_: Any) -> Iterator:
|
||||||
|
chunks = [
|
||||||
|
AIMessageChunk(content=[]),
|
||||||
|
AIMessageChunk(content=[{"text": "Certainly", "type": "text", "index": 0}]),
|
||||||
|
AIMessageChunk(content=[{"text": "!", "type": "text", "index": 0}]),
|
||||||
|
AIMessageChunk(
|
||||||
|
content=[
|
||||||
|
{
|
||||||
|
"id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8",
|
||||||
|
"input": {},
|
||||||
|
"name": "multiply",
|
||||||
|
"type": "tool_use",
|
||||||
|
"index": 1,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
tool_calls=[
|
||||||
|
{
|
||||||
|
"name": "multiply",
|
||||||
|
"args": {},
|
||||||
|
"id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8",
|
||||||
|
"type": "tool_call",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
tool_call_chunks=[
|
||||||
|
{
|
||||||
|
"name": "multiply",
|
||||||
|
"args": "",
|
||||||
|
"id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8",
|
||||||
|
"index": 1,
|
||||||
|
"type": "tool_call_chunk",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
),
|
||||||
|
AIMessageChunk(
|
||||||
|
content=[{"partial_json": '{"', "type": "tool_use", "index": 1}],
|
||||||
|
tool_calls=[{"name": "", "args": {}, "id": None, "type": "tool_call"}],
|
||||||
|
tool_call_chunks=[
|
||||||
|
{
|
||||||
|
"name": None,
|
||||||
|
"args": '{"',
|
||||||
|
"id": None,
|
||||||
|
"index": 1,
|
||||||
|
"type": "tool_call_chunk",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
),
|
||||||
|
AIMessageChunk(
|
||||||
|
content=[{"partial_json": 'a": 5, "b": 2', "type": "tool_use", "index": 1}],
|
||||||
|
invalid_tool_calls=[
|
||||||
|
{
|
||||||
|
"name": None,
|
||||||
|
"args": 'a": 5, "b": 2',
|
||||||
|
"id": None,
|
||||||
|
"error": None,
|
||||||
|
"type": "invalid_tool_call",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
tool_call_chunks=[
|
||||||
|
{
|
||||||
|
"name": None,
|
||||||
|
"args": 'a": 5, "b": 2',
|
||||||
|
"id": None,
|
||||||
|
"index": 1,
|
||||||
|
"type": "tool_call_chunk",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
),
|
||||||
|
AIMessageChunk(
|
||||||
|
content=[{"partial_json": "}", "type": "tool_use", "index": 1}],
|
||||||
|
invalid_tool_calls=[
|
||||||
|
{
|
||||||
|
"name": None,
|
||||||
|
"args": "}",
|
||||||
|
"id": None,
|
||||||
|
"error": None,
|
||||||
|
"type": "invalid_tool_call",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
tool_call_chunks=[
|
||||||
|
{
|
||||||
|
"name": None,
|
||||||
|
"args": "}",
|
||||||
|
"id": None,
|
||||||
|
"index": 1,
|
||||||
|
"type": "tool_call_chunk",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
),
|
||||||
|
AIMessageChunk(content=""),
|
||||||
|
]
|
||||||
|
yield from chunks
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("stream", [_stream_oai, _stream_anthropic])
|
||||||
|
def test_format_messages_openai_string_stream(stream: Callable) -> None:
|
||||||
|
formatter = format_messages_as(format="openai", text="string")
|
||||||
|
|
||||||
|
chain = RunnableLambda(stream) | formatter
|
||||||
|
tool_call_idx = 1 if stream == _stream_anthropic else 0
|
||||||
|
expected = [
|
||||||
|
AIMessageChunk(content=""),
|
||||||
|
AIMessageChunk(content="Certainly"),
|
||||||
|
AIMessageChunk(content="!"),
|
||||||
|
AIMessageChunk(
|
||||||
|
content="",
|
||||||
|
tool_calls=[
|
||||||
|
{
|
||||||
|
"name": "multiply",
|
||||||
|
"args": {},
|
||||||
|
"id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8",
|
||||||
|
"type": "tool_call",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
tool_call_chunks=[
|
||||||
|
{
|
||||||
|
"name": "multiply",
|
||||||
|
"args": "",
|
||||||
|
"id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8",
|
||||||
|
"index": tool_call_idx,
|
||||||
|
"type": "tool_call_chunk",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
),
|
||||||
|
AIMessageChunk(
|
||||||
|
content="",
|
||||||
|
tool_calls=[{"name": "", "args": {}, "id": None, "type": "tool_call"}],
|
||||||
|
tool_call_chunks=[
|
||||||
|
{
|
||||||
|
"name": None,
|
||||||
|
"args": '{"',
|
||||||
|
"id": None,
|
||||||
|
"index": tool_call_idx,
|
||||||
|
"type": "tool_call_chunk",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
),
|
||||||
|
AIMessageChunk(
|
||||||
|
content="",
|
||||||
|
invalid_tool_calls=[
|
||||||
|
{
|
||||||
|
"name": None,
|
||||||
|
"args": 'a": 5, "b": 2',
|
||||||
|
"id": None,
|
||||||
|
"error": None,
|
||||||
|
"type": "invalid_tool_call",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
tool_call_chunks=[
|
||||||
|
{
|
||||||
|
"name": None,
|
||||||
|
"args": 'a": 5, "b": 2',
|
||||||
|
"id": None,
|
||||||
|
"index": tool_call_idx,
|
||||||
|
"type": "tool_call_chunk",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
),
|
||||||
|
AIMessageChunk(
|
||||||
|
content="",
|
||||||
|
invalid_tool_calls=[
|
||||||
|
{
|
||||||
|
"name": None,
|
||||||
|
"args": "}",
|
||||||
|
"id": None,
|
||||||
|
"error": None,
|
||||||
|
"type": "invalid_tool_call",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
tool_call_chunks=[
|
||||||
|
{
|
||||||
|
"name": None,
|
||||||
|
"args": "}",
|
||||||
|
"id": None,
|
||||||
|
"index": tool_call_idx,
|
||||||
|
"type": "tool_call_chunk",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
),
|
||||||
|
AIMessageChunk(
|
||||||
|
content="",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
actual = list(chain.stream({}))
|
||||||
|
assert expected == actual
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("stream", [_stream_oai, _stream_anthropic])
|
||||||
|
def test_format_messages_openai_block_stream(stream: Callable) -> None:
|
||||||
|
formatter = format_messages_as(format="openai", text="block")
|
||||||
|
|
||||||
|
chain = RunnableLambda(stream) | formatter
|
||||||
|
tool_call_idx = 1 if stream == _stream_anthropic else 0
|
||||||
|
expected = [
|
||||||
|
AIMessageChunk(content=[]),
|
||||||
|
AIMessageChunk(content=[{"type": "text", "text": "Certainly"}]),
|
||||||
|
AIMessageChunk(content=[{"type": "text", "text": "!"}]),
|
||||||
|
AIMessageChunk(
|
||||||
|
content=[],
|
||||||
|
tool_calls=[
|
||||||
|
{
|
||||||
|
"name": "multiply",
|
||||||
|
"args": {},
|
||||||
|
"id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8",
|
||||||
|
"type": "tool_call",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
tool_call_chunks=[
|
||||||
|
{
|
||||||
|
"name": "multiply",
|
||||||
|
"args": "",
|
||||||
|
"id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8",
|
||||||
|
"index": tool_call_idx,
|
||||||
|
"type": "tool_call_chunk",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
),
|
||||||
|
AIMessageChunk(
|
||||||
|
content=[],
|
||||||
|
tool_calls=[{"name": "", "args": {}, "id": None, "type": "tool_call"}],
|
||||||
|
tool_call_chunks=[
|
||||||
|
{
|
||||||
|
"name": None,
|
||||||
|
"args": '{"',
|
||||||
|
"id": None,
|
||||||
|
"index": tool_call_idx,
|
||||||
|
"type": "tool_call_chunk",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
),
|
||||||
|
AIMessageChunk(
|
||||||
|
content=[],
|
||||||
|
invalid_tool_calls=[
|
||||||
|
{
|
||||||
|
"name": None,
|
||||||
|
"args": 'a": 5, "b": 2',
|
||||||
|
"id": None,
|
||||||
|
"error": None,
|
||||||
|
"type": "invalid_tool_call",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
tool_call_chunks=[
|
||||||
|
{
|
||||||
|
"name": None,
|
||||||
|
"args": 'a": 5, "b": 2',
|
||||||
|
"id": None,
|
||||||
|
"index": tool_call_idx,
|
||||||
|
"type": "tool_call_chunk",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
),
|
||||||
|
AIMessageChunk(
|
||||||
|
content=[],
|
||||||
|
invalid_tool_calls=[
|
||||||
|
{
|
||||||
|
"name": None,
|
||||||
|
"args": "}",
|
||||||
|
"id": None,
|
||||||
|
"error": None,
|
||||||
|
"type": "invalid_tool_call",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
tool_call_chunks=[
|
||||||
|
{
|
||||||
|
"name": None,
|
||||||
|
"args": "}",
|
||||||
|
"id": None,
|
||||||
|
"index": tool_call_idx,
|
||||||
|
"type": "tool_call_chunk",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
),
|
||||||
|
AIMessageChunk(
|
||||||
|
content=[],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
actual = list(chain.stream({}))
|
||||||
|
assert expected == actual
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("stream", [_stream_oai, _stream_anthropic])
|
||||||
|
def test_format_messages_anthropic_block_stream(stream: Callable) -> None:
|
||||||
|
formatter = format_messages_as(format="anthropic", text="block")
|
||||||
|
|
||||||
|
chain = RunnableLambda(stream) | formatter
|
||||||
|
expected = [
|
||||||
|
AIMessageChunk(content=[]),
|
||||||
|
AIMessageChunk(content=[{"text": "Certainly", "type": "text", "index": 0}]),
|
||||||
|
AIMessageChunk(content=[{"text": "!", "type": "text", "index": 0}]),
|
||||||
|
AIMessageChunk(
|
||||||
|
content=[
|
||||||
|
{
|
||||||
|
"id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8",
|
||||||
|
"name": "multiply",
|
||||||
|
"type": "tool_use",
|
||||||
|
"index": 1,
|
||||||
|
**(
|
||||||
|
{"input": {}}
|
||||||
|
if stream == _stream_anthropic
|
||||||
|
else {"partial_json": ""}
|
||||||
|
),
|
||||||
|
}
|
||||||
|
],
|
||||||
|
tool_calls=[
|
||||||
|
{
|
||||||
|
"name": "multiply",
|
||||||
|
"args": {},
|
||||||
|
"id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8",
|
||||||
|
"type": "tool_call",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
tool_call_chunks=[
|
||||||
|
{
|
||||||
|
"name": "multiply",
|
||||||
|
"args": "",
|
||||||
|
"id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8",
|
||||||
|
"index": 1,
|
||||||
|
"type": "tool_call_chunk",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
),
|
||||||
|
AIMessageChunk(
|
||||||
|
content=[{"partial_json": '{"', "type": "tool_use", "index": 1}],
|
||||||
|
tool_calls=[{"name": "", "args": {}, "id": None, "type": "tool_call"}],
|
||||||
|
tool_call_chunks=[
|
||||||
|
{
|
||||||
|
"name": None,
|
||||||
|
"args": '{"',
|
||||||
|
"id": None,
|
||||||
|
"index": 1,
|
||||||
|
"type": "tool_call_chunk",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
),
|
||||||
|
AIMessageChunk(
|
||||||
|
content=[{"partial_json": 'a": 5, "b": 2', "type": "tool_use", "index": 1}],
|
||||||
|
invalid_tool_calls=[
|
||||||
|
{
|
||||||
|
"name": None,
|
||||||
|
"args": 'a": 5, "b": 2',
|
||||||
|
"id": None,
|
||||||
|
"error": None,
|
||||||
|
"type": "invalid_tool_call",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
tool_call_chunks=[
|
||||||
|
{
|
||||||
|
"name": None,
|
||||||
|
"args": 'a": 5, "b": 2',
|
||||||
|
"id": None,
|
||||||
|
"index": 1,
|
||||||
|
"type": "tool_call_chunk",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
),
|
||||||
|
AIMessageChunk(
|
||||||
|
content=[{"partial_json": "}", "type": "tool_use", "index": 1}],
|
||||||
|
invalid_tool_calls=[
|
||||||
|
{
|
||||||
|
"name": None,
|
||||||
|
"args": "}",
|
||||||
|
"id": None,
|
||||||
|
"error": None,
|
||||||
|
"type": "invalid_tool_call",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
tool_call_chunks=[
|
||||||
|
{
|
||||||
|
"name": None,
|
||||||
|
"args": "}",
|
||||||
|
"id": None,
|
||||||
|
"index": 1,
|
||||||
|
"type": "tool_call_chunk",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
),
|
||||||
|
AIMessageChunk(content=[]),
|
||||||
|
]
|
||||||
|
actual = list(chain.stream({}))
|
||||||
|
assert expected == actual
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("stream", [_stream_oai, _stream_anthropic])
|
||||||
|
def test_format_messages_anthropic_string_stream(stream: Callable) -> None:
|
||||||
|
formatter = format_messages_as(format="anthropic", text="string")
|
||||||
|
|
||||||
|
chain = RunnableLambda(stream) | formatter
|
||||||
|
expected = [
|
||||||
|
AIMessageChunk(content=""),
|
||||||
|
AIMessageChunk(content="Certainly"),
|
||||||
|
AIMessageChunk(content="!"),
|
||||||
|
AIMessageChunk(
|
||||||
|
content=[
|
||||||
|
{
|
||||||
|
"type": "tool_use",
|
||||||
|
"id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8",
|
||||||
|
"name": "multiply",
|
||||||
|
"index": 1,
|
||||||
|
**(
|
||||||
|
{"input": {}}
|
||||||
|
if stream == _stream_anthropic
|
||||||
|
else {"partial_json": ""}
|
||||||
|
),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
tool_calls=[
|
||||||
|
{
|
||||||
|
"name": "multiply",
|
||||||
|
"args": {},
|
||||||
|
"id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8",
|
||||||
|
"type": "tool_call",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
tool_call_chunks=[
|
||||||
|
{
|
||||||
|
"name": "multiply",
|
||||||
|
"args": "",
|
||||||
|
"id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8",
|
||||||
|
"index": 1,
|
||||||
|
"type": "tool_call_chunk",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
),
|
||||||
|
AIMessageChunk(
|
||||||
|
content=[
|
||||||
|
{"type": "tool_use", "partial_json": '{"', "index": 1},
|
||||||
|
],
|
||||||
|
tool_calls=[{"name": "", "args": {}, "id": None, "type": "tool_call"}],
|
||||||
|
tool_call_chunks=[
|
||||||
|
{
|
||||||
|
"name": None,
|
||||||
|
"args": '{"',
|
||||||
|
"id": None,
|
||||||
|
"index": 1,
|
||||||
|
"type": "tool_call_chunk",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
),
|
||||||
|
AIMessageChunk(
|
||||||
|
content=[
|
||||||
|
{"type": "tool_use", "partial_json": 'a": 5, "b": 2', "index": 1},
|
||||||
|
],
|
||||||
|
invalid_tool_calls=[
|
||||||
|
{
|
||||||
|
"name": None,
|
||||||
|
"args": 'a": 5, "b": 2',
|
||||||
|
"id": None,
|
||||||
|
"error": None,
|
||||||
|
"type": "invalid_tool_call",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
tool_call_chunks=[
|
||||||
|
{
|
||||||
|
"name": None,
|
||||||
|
"args": 'a": 5, "b": 2',
|
||||||
|
"id": None,
|
||||||
|
"index": 1,
|
||||||
|
"type": "tool_call_chunk",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
),
|
||||||
|
AIMessageChunk(
|
||||||
|
content=[
|
||||||
|
{"type": "tool_use", "partial_json": "}", "index": 1},
|
||||||
|
],
|
||||||
|
invalid_tool_calls=[
|
||||||
|
{
|
||||||
|
"name": None,
|
||||||
|
"args": "}",
|
||||||
|
"id": None,
|
||||||
|
"error": None,
|
||||||
|
"type": "invalid_tool_call",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
tool_call_chunks=[
|
||||||
|
{
|
||||||
|
"name": None,
|
||||||
|
"args": "}",
|
||||||
|
"id": None,
|
||||||
|
"index": 1,
|
||||||
|
"type": "tool_call_chunk",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
),
|
||||||
|
AIMessageChunk(content=""),
|
||||||
|
]
|
||||||
|
actual = list(chain.stream({}))
|
||||||
|
assert expected == actual
|
||||||
|
@ -31,7 +31,7 @@ from langchain_core.language_models.chat_models import (
|
|||||||
)
|
)
|
||||||
from langchain_core.messages import AnyMessage, BaseMessage
|
from langchain_core.messages import AnyMessage, BaseMessage
|
||||||
from langchain_core.pydantic_v1 import BaseModel
|
from langchain_core.pydantic_v1 import BaseModel
|
||||||
from langchain_core.runnables import Runnable, RunnableConfig
|
from langchain_core.runnables import Runnable, RunnableConfig, ensure_config
|
||||||
from langchain_core.runnables.schema import StreamEvent
|
from langchain_core.runnables.schema import StreamEvent
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
from langchain_core.tracers import RunLog, RunLogPatch
|
from langchain_core.tracers import RunLog, RunLogPatch
|
||||||
@ -512,7 +512,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
def _model_params(self, config: Optional[RunnableConfig]) -> dict:
|
def _model_params(self, config: Optional[RunnableConfig]) -> dict:
|
||||||
config = config or {}
|
config = ensure_config(config)
|
||||||
model_params = {
|
model_params = {
|
||||||
_remove_prefix(k, self._config_prefix): v
|
_remove_prefix(k, self._config_prefix): v
|
||||||
for k, v in config.get("configurable", {}).items()
|
for k, v in config.get("configurable", {}).items()
|
||||||
|
Loading…
Reference in New Issue
Block a user