fmt

commit 992884b410
parent fb002faba4
@@ -20,6 +20,7 @@ from typing import (
    Callable,
    Dict,
    Iterable,
    Iterator,
    List,
    Literal,
    Optional,
@@ -28,7 +29,7 @@ from typing import (
    Type,
    Union,
    cast,
    overload, Iterator,
    overload,
)

from langchain_core.messages.ai import AIMessage, AIMessageChunk
@@ -41,7 +42,9 @@ from langchain_core.messages.system import SystemMessage, SystemMessageChunk
from langchain_core.messages.tool import ToolMessage, ToolMessageChunk
from langchain_core.messages.tool import (
    tool_call as create_tool_call,
    tool_call_chunk as create_tool_call_chunk
)
from langchain_core.messages.tool import (
    tool_call_chunk as create_tool_call_chunk,
)

if TYPE_CHECKING:
@@ -335,50 +338,6 @@ def convert_to_messages(
    return [_convert_to_message(m, copy=copy) for m in messages]


def _runnable_generator(func: Callable) -> Callable:
    @overload
    def wrapped(
        messages: Literal[None] = None, **kwargs: Any
    ) -> Runnable[
        Union[MessageLikeRepresentation, Sequence[MessageLikeRepresentation]],
        Union[BaseMessage, List[BaseMessage]],
    ]: ...

    @overload
    def wrapped(
        messages: Sequence[Union[BaseMessage, Dict, Tuple]], **kwargs: Any
    ) -> List[BaseMessage]: ...

    @overload
    def wrapped(messages: MessageLikeRepresentation, **kwargs: Any) -> BaseMessage: ...

    def wrapped(
        messages: Union[
            MessageLikeRepresentation, Sequence[MessageLikeRepresentation], None
        ] = None,
        **kwargs: Any,
    ) -> Union[
        BaseMessage,
        List[BaseMessage],
        Runnable[
            Union[MessageLikeRepresentation, Sequence[MessageLikeRepresentation]],
            Union[BaseMessage, List[BaseMessage]],
        ],
    ]:
        from langchain_core.runnables.base import RunnableGenerator

        if messages is not None:
            return func(messages, **kwargs)
        else:
            def transform(input_: Iterator, **kwargs: Any) -> Iterator:
                for x in input_:
                    yield func(x, **kwargs)
            return RunnableGenerator(partial(transform, **kwargs), name=func.__name__)

    wrapped.__doc__ = func.__doc__
    return wrapped


def _runnable_support(func: Callable) -> Callable:
    @overload
    def wrapped(
@@ -903,13 +862,84 @@ def trim_messages(
    )


def _runnable_generator(func: Callable) -> Callable:
    @overload
    def wrapped(
        messages: Literal[None] = None, **kwargs: Any
    ) -> Runnable[
        Union[MessageLikeRepresentation, Sequence[MessageLikeRepresentation]],
        Union[BaseMessage, List[BaseMessage]],
    ]: ...

    @overload
    def wrapped(
        messages: Sequence[Union[BaseMessage, Dict, Tuple]], **kwargs: Any
    ) -> List[BaseMessage]: ...

    @overload
    def wrapped(messages: MessageLikeRepresentation, **kwargs: Any) -> BaseMessage: ...

    def wrapped(
        messages: Union[
            MessageLikeRepresentation, Sequence[MessageLikeRepresentation], None
        ] = None,
        **kwargs: Any,
    ) -> Union[
        BaseMessage,
        List[BaseMessage],
        Runnable[
            Union[MessageLikeRepresentation, Sequence[MessageLikeRepresentation]],
            Union[BaseMessage, List[BaseMessage]],
        ],
    ]:
        from langchain_core.runnables.base import RunnableGenerator

        if messages is not None:
            return func(messages, **kwargs)
        else:

            def transform(input_: Iterator, **kwargs: Any) -> Iterator:
                block_indexes = set()
                for x in input_:
                    msg = func(x, **kwargs)
                    # Special handling for transforming an OpenAI stream to an
                    # Anthropic stream.
                    if isinstance(msg, AIMessageChunk) and isinstance(
                        msg.content, list
                    ):
                        tool_use_ct = 0
                        for block in msg.content:
                            if not isinstance(block, dict):
                                continue
                            if "index" in block:
                                block_indexes.add(block["index"])
                            elif block.get("type") == "tool_use":
                                block["index"] = max(len(block_indexes), *block_indexes)
                                msg.tool_call_chunks[tool_use_ct]["index"] = block[
                                    "index"
                                ]
                            else:
                                pass

                            if block.get("type") == "tool_use":
                                tool_use_ct += 1
                    else:
                        block_indexes.add(0)
                    yield msg

            return RunnableGenerator(partial(transform, **kwargs), name=func.__name__)

    wrapped.__doc__ = func.__doc__
    return wrapped

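A minimal sketch (not part of this commit) of the index bookkeeping in the
transform above: OpenAI-style tool_use blocks carry no "index", so the next
unused Anthropic-style block index is derived from the indexes seen so far,
assuming they are contiguous from 0:

    seen = {0}  # e.g. a text block at index 0 has already streamed
    block = {"type": "tool_use", "name": "multiply", "partial_json": ""}
    block["index"] = max(len(seen), *seen)  # -> 1, the next unused index
    print(block["index"])
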
@_runnable_generator
def format_messages_as(
    messages: Union[MessageLikeRepresentation, Iterator[MessageLikeRepresentation]],
    messages: Union[MessageLikeRepresentation, Sequence[MessageLikeRepresentation]],
    *,
    format: Literal["openai", "anthropic"],
    text: Literal["string", "block"],
) -> Union[BaseMessage, Iterator[BaseMessage]]:
) -> Union[BaseMessage, List[BaseMessage]]:
    """Convert message contents into a standard format.

    .. versionadded:: 0.2.36
@@ -975,61 +1005,63 @@ def format_messages_as(
        from langchain_core.messages import format_messages_as
        from langchain.chat_models import init_chat_model

        formatter = format_messages_as(format="openai", text="block")
        formatter = format_messages_as(format="openai", text="string")
        llm = init_chat_model() | formatter

        llm.invoke(
            [{"role": "user", "content": "how are you"}],
            config={"model": "gpt-4o"},
        )
        # -> AIMessage([{"type": "text", "text": ""}], ...)
        # -> AIMessage(["I am good..."], ...)

        llm.invoke(
            [{"role": "user", "content": "whats your name"}],
            config={"model": "claude-3-5-sonnet-20240620"})
        # -> AIMessage([{"type": "text", "text": ""}], ...)
            config={"model": "claude-3-5-sonnet-20240620"}
        )
        # -> AIMessage(["My name is..."], ...)

    .. note:: Doesn't support streaming

        This util does not support formatting streamed chunks on the fly (i.e.
        "transforming" chunks). This means if you pipe the outputs of a model to this
        formatter in a chain, the chain will not have token-level streaming when
        using ``chain.stream()/.astream()``. You'll still see the
        token stream when using ``chat.astream_events()`` but the message chunks will
        not yet be formatted.
        def multiply(a: int, b: int) -> int:
            '''Return product of a and b.'''
            return a * b

    .. code-block:: python
        llm_with_tools = init_chat_model().bind_tools([multiply]) | formatter

        from langchain_core.messages import format_messages_as
        from langchain.chat_models import init_chat_model

        formatter = format_messages_as(format="openai", text="block")
        llm = init_chat_model() | formatter

        # Will contain a single, completed chunk.
        list(llm.stream(
            [{"role": "user", "content": "how are you"}],
            config={"model": "gpt-4o"},
        ))

        # Will include token-level events, but the streamed chunks will not yet be
        # formatted.
        async for chunk in llm.astream_events(
            [{"role": "user", "content": "how are you"}],
            config={"model": "gpt-4o"},
            version="v2",
        for chunk in llm_with_tools.stream(
            "what's 5 times 2", config={"model": "claude-3-5-sonnet-20240620"}
        ):
            ...

            print(chunk)
        # -> AIMessageChunk(content='', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75', usage_metadata={'input_tokens': 370, 'output_tokens': 0, 'total_tokens': 370}),
        # AIMessageChunk(content='Certainly', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'),
        # AIMessageChunk(content='! To', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'),
        # AIMessageChunk(content=' calculate', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'),
        # AIMessageChunk(content=' 5 times ', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'),
        # AIMessageChunk(content='2, we can use', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'),
        # AIMessageChunk(content=' the "', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'),
        # AIMessageChunk(content='multiply" function that', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'),
        # AIMessageChunk(content="'s", id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'),
        # AIMessageChunk(content=' available to', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'),
        # AIMessageChunk(content=' us.', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'),
        # AIMessageChunk(content=' Let', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'),
        # AIMessageChunk(content="'s use", id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'),
        # AIMessageChunk(content=' this tool', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'),
        # AIMessageChunk(content=' to', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'),
        # AIMessageChunk(content=' get', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'),
        # AIMessageChunk(content=' the result.', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75'),
        # AIMessageChunk(content='', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75', tool_calls=[{'name': 'multiply', 'args': {}, 'id': 'toolu_01PW8o6BkATCecjsJX8QgG6z', 'type': 'tool_call'}], tool_call_chunks=[{'name': 'multiply', 'args': '', 'id': 'toolu_01PW8o6BkATCecjsJX8QgG6z', 'index': 1, 'type': 'tool_call_chunk'}]),
        # AIMessageChunk(content='', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75', tool_calls=[{'name': '', 'args': {}, 'id': None, 'type': 'tool_call'}], tool_call_chunks=[{'name': None, 'args': '', 'id': None, 'index': 1, 'type': 'tool_call_chunk'}]),
        # AIMessageChunk(content='', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75', tool_calls=[{'name': '', 'args': {'a': 5}, 'id': None, 'type': 'tool_call'}], tool_call_chunks=[{'name': None, 'args': '{"a": 5', 'id': None, 'index': 1, 'type': 'tool_call_chunk'}]),
        # AIMessageChunk(content='', id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75', invalid_tool_calls=[{'name': None, 'args': ', "b": 2}', 'id': None, 'error': None, 'type': 'invalid_tool_call'}], tool_call_chunks=[{'name': None, 'args': ', "b": 2}', 'id': None, 'index': 1, 'type': 'tool_call_chunk'}]),
        # AIMessageChunk(content='', response_metadata={'stop_reason': 'tool_use', 'stop_sequence': None}, id='run-64757cb2-b85f-4d51-8f34-5a6c1d40ad75', usage_metadata={'input_tokens': 0, 'output_tokens': 104, 'total_tokens': 104})

    """  # noqa: E501
    if is_single := isinstance(messages, (BaseMessage, dict)):
        messages = [messages]
    messages = convert_to_messages(messages, copy=True)
    if format.lower() == "openai":
        formatted = _format_contents_as_openai(messages, text=text)
        formatted = _format_messages_as_openai(messages, text=text)
    elif format.lower() == "anthropic":
        formatted = _format_contents_as_anthropic(messages, text=text)
        formatted = _format_messages_as_anthropic(messages, text=text)
    else:
        raise ValueError(
            f"Unrecognized {format=}. Expected one of ('openai', 'anthropic')."
@@ -1040,7 +1072,7 @@ def format_messages_as(
    return formatted
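
A minimal sketch (not part of this commit) of calling the util imperatively
on a single message, per the signature above; the import path is assumed
from this diff:

    from langchain_core.messages import AIMessage, format_messages_as

    msg = AIMessage(content=[{"type": "text", "text": "hello"}])
    formatted = format_messages_as(msg, format="openai", text="string")
    assert formatted.content == "hello"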


def _format_contents_as_openai(
def _format_messages_as_openai(
    messages: Sequence[BaseMessage], *, text: Literal["string", "block"]
) -> List[BaseMessage]:
    """Mutates messages so their contents match OpenAI messages API."""
@@ -1146,14 +1178,14 @@ def _format_contents_as_openai(
                        f"content block:\n\n{block}"
                    )
                elif block.get("type") == "tool_use":
                    if not isinstance(message, BaseMessageChunk):
                    if not isinstance(message, AIMessageChunk):
                        if missing := [
                            k for k in ("id", "name", "input") if k not in block
                        ]:
                            raise ValueError(
                                f"Unrecognized content block at "
                                f"messages[{i}].content[{j}] has 'type': 'tool_use', "
                                f"but is missing expected key(s) "
                                f"messages[{i}].content[{j}] has 'type': "
                                f"'tool_use', but is missing expected key(s) "
                                f"{missing}. Full content block:\n\n{block}"
                            )
                        if not any(
@@ -1169,7 +1201,14 @@ def _format_contents_as_openai(
                            )
                    else:
                        if not message.tool_call_chunks:
                            message.tool_call_chunks = [create_tool_call_chunk(id=block.get("id"), index=block.get("index"), args=block.get("partial_json"), name=block.get("name"))]
                            message.tool_call_chunks = [
                                create_tool_call_chunk(
                                    id=block.get("id"),
                                    index=block.get("index"),
                                    args=block.get("partial_json"),
                                    name=block.get("name"),
                                )
                            ]
                elif block.get("type") == "tool_result":
                    if missing := [
                        k for k in ("content", "tool_use_id") if k not in block
@@ -1187,7 +1226,7 @@ def _format_contents_as_openai(
                    )
                    # Recurse to make sure tool message contents are OpenAI format.
                    tool_messages.extend(
                        _format_contents_as_openai([tool_message], text=text)
                        _format_messages_as_openai([tool_message], text=text)
                    )
                elif (block.get("type") == "json") or "json" in block:
                    if "json" not in block:
@@ -1253,15 +1292,20 @@ def _format_contents_as_openai(
                        f"Anthropic, Bedrock Converse, or VertexAI format. Full "
                        f"content block:\n\n{block}"
                    )
            message.content = content  # type: ignore[assignment]
            if text == "string" and not any(
                block["type"] != "text" for block in content
            ):
                message.content = "\n".join(block["text"] for block in content)
            else:
                message.content = content  # type: ignore[assignment]
        updated_messages.extend([message, *tool_messages])
    return updated_messages
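
A minimal sketch (not part of this commit) of the text="string" collapse
added above: contents are joined with newlines only when every block is a
text block; otherwise the block list is kept as-is:

    content = [{"type": "text", "text": "a"}, {"type": "text", "text": "b"}]
    if not any(block["type"] != "text" for block in content):
        print("\n".join(block["text"] for block in content))  # prints "a\nb"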


_OPTIONAL_ANTHROPIC_KEYS = ("cache_control", "is_error")
_OPTIONAL_ANTHROPIC_KEYS = ("cache_control", "is_error", "index")


def _format_contents_as_anthropic(
def _format_messages_as_anthropic(
    messages: Sequence[BaseMessage], *, text: Literal["string", "block"]
) -> List[BaseMessage]:
    """Mutates messages so their contents match Anthropic messages API."""
@@ -1289,7 +1333,10 @@ def _format_contents_as_anthropic(
            if text == "string":
                pass
            else:
                message.content = [{"type": "text", "text": message.content}]
                text_block: dict = {"type": "text", "text": message.content}
                if isinstance(message, AIMessageChunk):
                    text_block["index"] = 0
                message.content = [text_block]
        else:
            if text == "string" and all(
                isinstance(block, str)
@@ -1303,13 +1350,20 @@ def _format_contents_as_anthropic(
            else:
                content = []
                for j, block in enumerate(message.content):
                    # OpenAI format
                    if isinstance(block, str):
                        content.append({"type": "text", "text": block})
                    elif block.get("type") == "text":
                    if isinstance(block, dict):
                        block_extra = {
                            k: block[k] for k in _OPTIONAL_ANTHROPIC_KEYS if k in block
                        }
                    else:
                        block_extra = {}

                    # OpenAI format
                    if isinstance(block, str):
                        text_block = {"type": "text", "text": block}
                        if isinstance(message, AIMessageChunk):
                            text_block["index"] = 0
                        content.append(text_block)
                    elif block.get("type") == "text":
                        if missing := [k for k in ("text",) if k not in block]:
                            raise ValueError(
                                f"Unrecognized content block at "
@@ -1380,35 +1434,72 @@ def _format_contents_as_anthropic(
                                f"content block:\n\n{block}"
                            )
                    elif block.get("type") == "tool_use":
                        if missing := [
                            k for k in ("id", "name", "input") if k not in block
                        ]:
                            raise ValueError(
                                f"Unrecognized content block at "
                                f"messages[{i}].content[{j}] has 'type': 'tool_use', "
                                f"but is missing expected key(s) "
                                f"{missing}. Full content block:\n\n{block}"
                            )
                        content.append(
                            {
                                "type": "tool_use",
                                "name": block["name"],
                                "id": block["id"],
                                "input": block["input"],
                                **block_extra,
                            }
                        )
                        if not any(
                            tool_call["id"] == block["id"]
                            for tool_call in cast(AIMessage, message).tool_calls
                        ):
                            cast(AIMessage, message).tool_calls.append(
                                create_tool_call(
                                    name=block["name"],
                                    id=block["id"],
                                    args=block["input"],
                        if not isinstance(message, AIMessageChunk):
                            if missing := [
                                k for k in ("id", "name", "input") if k not in block
                            ]:
                                raise ValueError(
                                    f"Unrecognized content block at "
                                    f"messages[{i}].content[{j}] has 'type': "
                                    f"'tool_use', "
                                    f"but is missing expected key(s) "
                                    f"{missing}. Full content block:\n\n{block}"
                                )
                            content.append(
                                {
                                    "type": "tool_use",
                                    "name": block["name"],
                                    "id": block["id"],
                                    "input": block["input"],
                                    **block_extra,
                                }
                            )
                            if not any(
                                tool_call["id"] == block["id"]
                                for tool_call in cast(AIMessage, message).tool_calls
                            ):
                                cast(AIMessage, message).tool_calls.append(
                                    create_tool_call(
                                        name=block["name"],
                                        id=block["id"],
                                        args=block["input"],
                                    )
                                )
                        else:
                            if (
                                not any(k in block for k in ("input", "partial_json"))
                                or "index" not in block
                            ):
                                raise ValueError(
                                    f"Unrecognized content block at "
                                    f"message_chunks[{i}].content[{j}] has "
                                    f"'type': 'tool_use', "
                                    f"but does not have either an 'input' or "
                                    f"'partial_json' and an 'index' key. Full content "
                                    f"block:\n\n{block}"
                                )
                            content.append(
                                {
                                    "type": "tool_use",
                                    "index": block["index"],
                                    **{
                                        k: block[k]
                                        for k in ("name", "input", "partial_json", "id")
                                        if k in block
                                    },
                                }
                            )
                            if not message.tool_call_chunks:
                                message.tool_call_chunks = [
                                    create_tool_call_chunk(
                                        name=block.get("name"),
                                        id=block.get("id"),
                                        index=block["index"],
                                        args=block["partial_json"]
                                        if "partial_json" in block
                                        else block["input"],
                                    )
                                ]
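
A minimal sketch (not part of this commit) of the argument selection used in
the chunk branch above, which accepts either a complete "input" dict or a
streamed "partial_json" string:

    complete = {"type": "tool_use", "id": "abc", "name": "multiply", "input": {"a": 5}, "index": 1}
    streamed = {"type": "tool_use", "partial_json": '{"a": 5', "index": 1}
    for block in (complete, streamed):
        args = block["partial_json"] if "partial_json" in block else block["input"]
        print(args)  # {'a': 5}, then the partial string '{"a": 5'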
                    elif block.get("type") == "tool_result":
                        if missing := [
                            k for k in ("content", "tool_use_id") if k not in block
@@ -1489,7 +1580,35 @@ def _format_contents_as_anthropic(
                            )
                message.content = content  # type: ignore[assignment]

        if isinstance(message, AIMessage) and message.tool_calls:
        if isinstance(message, AIMessageChunk) and message.tool_call_chunks:
            if isinstance(message.content, str):
                if message.content:
                    message.content = [
                        {"type": "text", "text": message.content, "index": 0}
                    ]
                else:
                    message.content = []
            if not any(
                cast(dict, block).get("type") == "tool_use" for block in message.content
            ):
                tool_use_blocks = [
                    # Note: we intentionally omit index so that it can be set by the
                    # stream handler, which can count how many blocks
                    # have been seen in preceding chunks
                    {
                        "type": "tool_use",
                        "partial_json": tc_chunk["args"],
                        "id": tc_chunk["id"],
                        "name": tc_chunk["name"],
                    }
                    for i, tc_chunk in enumerate(message.tool_call_chunks)
                ]
                tool_use_blocks = [
                    {k: v for k, v in tu_block.items() if v is not None}
                    for tu_block in tool_use_blocks
                ]
                message.content.extend(tool_use_blocks)
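
A minimal sketch (not part of this commit) of the None-stripping
comprehension above, which keeps reconstructed tool_use blocks minimal:

    tu_block = {"type": "tool_use", "partial_json": "", "id": None, "name": "multiply"}
    print({k: v for k, v in tu_block.items() if v is not None})
    # -> {'type': 'tool_use', 'partial_json': '', 'name': 'multiply'}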
        elif isinstance(message, AIMessage) and message.tool_calls:
            if isinstance(message.content, str):
                message.content = [{"type": "text", "text": message.content}]
            for tool_call in message.tool_calls:

@@ -1,11 +1,12 @@
import json
from typing import Dict, List, Type
from typing import Any, Callable, Dict, Iterator, List, Type

import pytest

from langchain_core.language_models.fake_chat_models import FakeChatModel
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    HumanMessage,
    SystemMessage,
@@ -20,6 +21,7 @@ from langchain_core.messages.utils import (
    merge_message_runs,
    trim_messages,
)
from langchain_core.runnables import RunnableLambda


@pytest.mark.parametrize("msg_cls", [HumanMessage, AIMessage, SystemMessage])
@@ -810,3 +812,580 @@ def test_format_messages_as_declarative() -> None:
    result = formatter.invoke(messages)
    assert result[0].content[1]["type"] == "image_url"
    assert result[0].content[1]["image_url"]["url"] == base64_image


def _stream_oai(input_: Any) -> Iterator:
    chunks = [
        AIMessageChunk(content=""),
        AIMessageChunk(content="Certainly"),
        AIMessageChunk(content="!"),
        AIMessageChunk(
            content="",
            tool_calls=[
                {
                    "name": "multiply",
                    "args": {},
                    "id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8",
                    "type": "tool_call",
                }
            ],
            tool_call_chunks=[
                {
                    "name": "multiply",
                    "args": "",
                    "id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8",
                    "index": 0,
                    "type": "tool_call_chunk",
                }
            ],
        ),
        AIMessageChunk(
            content="",
            tool_calls=[{"name": "", "args": {}, "id": None, "type": "tool_call"}],
            tool_call_chunks=[
                {
                    "name": None,
                    "args": '{"',
                    "id": None,
                    "index": 0,
                    "type": "tool_call_chunk",
                }
            ],
        ),
        AIMessageChunk(
            content="",
            invalid_tool_calls=[
                {
                    "name": None,
                    "args": 'a": 5, "b": 2',
                    "id": None,
                    "error": None,
                    "type": "invalid_tool_call",
                }
            ],
            tool_call_chunks=[
                {
                    "name": None,
                    "args": 'a": 5, "b": 2',
                    "id": None,
                    "index": 0,
                    "type": "tool_call_chunk",
                }
            ],
        ),
        AIMessageChunk(
            content="",
            invalid_tool_calls=[
                {
                    "name": None,
                    "args": "}",
                    "id": None,
                    "error": None,
                    "type": "invalid_tool_call",
                }
            ],
            tool_call_chunks=[
                {
                    "name": None,
                    "args": "}",
                    "id": None,
                    "index": 0,
                    "type": "tool_call_chunk",
                }
            ],
        ),
        AIMessageChunk(
            content="",
        ),
    ]
    yield from chunks


def _stream_anthropic(input_: Any) -> Iterator:
    chunks = [
        AIMessageChunk(content=[]),
        AIMessageChunk(content=[{"text": "Certainly", "type": "text", "index": 0}]),
        AIMessageChunk(content=[{"text": "!", "type": "text", "index": 0}]),
        AIMessageChunk(
            content=[
                {
                    "id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8",
                    "input": {},
                    "name": "multiply",
                    "type": "tool_use",
                    "index": 1,
                }
            ],
            tool_calls=[
                {
                    "name": "multiply",
                    "args": {},
                    "id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8",
                    "type": "tool_call",
                }
            ],
            tool_call_chunks=[
                {
                    "name": "multiply",
                    "args": "",
                    "id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8",
                    "index": 1,
                    "type": "tool_call_chunk",
                }
            ],
        ),
        AIMessageChunk(
            content=[{"partial_json": '{"', "type": "tool_use", "index": 1}],
            tool_calls=[{"name": "", "args": {}, "id": None, "type": "tool_call"}],
            tool_call_chunks=[
                {
                    "name": None,
                    "args": '{"',
                    "id": None,
                    "index": 1,
                    "type": "tool_call_chunk",
                }
            ],
        ),
        AIMessageChunk(
            content=[{"partial_json": 'a": 5, "b": 2', "type": "tool_use", "index": 1}],
            invalid_tool_calls=[
                {
                    "name": None,
                    "args": 'a": 5, "b": 2',
                    "id": None,
                    "error": None,
                    "type": "invalid_tool_call",
                }
            ],
            tool_call_chunks=[
                {
                    "name": None,
                    "args": 'a": 5, "b": 2',
                    "id": None,
                    "index": 1,
                    "type": "tool_call_chunk",
                }
            ],
        ),
        AIMessageChunk(
            content=[{"partial_json": "}", "type": "tool_use", "index": 1}],
            invalid_tool_calls=[
                {
                    "name": None,
                    "args": "}",
                    "id": None,
                    "error": None,
                    "type": "invalid_tool_call",
                }
            ],
            tool_call_chunks=[
                {
                    "name": None,
                    "args": "}",
                    "id": None,
                    "index": 1,
                    "type": "tool_call_chunk",
                }
            ],
        ),
        AIMessageChunk(content=""),
    ]
    yield from chunks


@pytest.mark.parametrize("stream", [_stream_oai, _stream_anthropic])
def test_format_messages_openai_string_stream(stream: Callable) -> None:
    formatter = format_messages_as(format="openai", text="string")

    chain = RunnableLambda(stream) | formatter
    tool_call_idx = 1 if stream == _stream_anthropic else 0
    expected = [
        AIMessageChunk(content=""),
        AIMessageChunk(content="Certainly"),
        AIMessageChunk(content="!"),
        AIMessageChunk(
            content="",
            tool_calls=[
                {
                    "name": "multiply",
                    "args": {},
                    "id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8",
                    "type": "tool_call",
                }
            ],
            tool_call_chunks=[
                {
                    "name": "multiply",
                    "args": "",
                    "id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8",
                    "index": tool_call_idx,
                    "type": "tool_call_chunk",
                }
            ],
        ),
        AIMessageChunk(
            content="",
            tool_calls=[{"name": "", "args": {}, "id": None, "type": "tool_call"}],
            tool_call_chunks=[
                {
                    "name": None,
                    "args": '{"',
                    "id": None,
                    "index": tool_call_idx,
                    "type": "tool_call_chunk",
                }
            ],
        ),
        AIMessageChunk(
            content="",
            invalid_tool_calls=[
                {
                    "name": None,
                    "args": 'a": 5, "b": 2',
                    "id": None,
                    "error": None,
                    "type": "invalid_tool_call",
                }
            ],
            tool_call_chunks=[
                {
                    "name": None,
                    "args": 'a": 5, "b": 2',
                    "id": None,
                    "index": tool_call_idx,
                    "type": "tool_call_chunk",
                }
            ],
        ),
        AIMessageChunk(
            content="",
            invalid_tool_calls=[
                {
                    "name": None,
                    "args": "}",
                    "id": None,
                    "error": None,
                    "type": "invalid_tool_call",
                }
            ],
            tool_call_chunks=[
                {
                    "name": None,
                    "args": "}",
                    "id": None,
                    "index": tool_call_idx,
                    "type": "tool_call_chunk",
                }
            ],
        ),
        AIMessageChunk(
            content="",
        ),
    ]

    actual = list(chain.stream({}))
    assert expected == actual


@pytest.mark.parametrize("stream", [_stream_oai, _stream_anthropic])
def test_format_messages_openai_block_stream(stream: Callable) -> None:
    formatter = format_messages_as(format="openai", text="block")

    chain = RunnableLambda(stream) | formatter
    tool_call_idx = 1 if stream == _stream_anthropic else 0
    expected = [
        AIMessageChunk(content=[]),
        AIMessageChunk(content=[{"type": "text", "text": "Certainly"}]),
        AIMessageChunk(content=[{"type": "text", "text": "!"}]),
        AIMessageChunk(
            content=[],
            tool_calls=[
                {
                    "name": "multiply",
                    "args": {},
                    "id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8",
                    "type": "tool_call",
                }
            ],
            tool_call_chunks=[
                {
                    "name": "multiply",
                    "args": "",
                    "id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8",
                    "index": tool_call_idx,
                    "type": "tool_call_chunk",
                }
            ],
        ),
        AIMessageChunk(
            content=[],
            tool_calls=[{"name": "", "args": {}, "id": None, "type": "tool_call"}],
            tool_call_chunks=[
                {
                    "name": None,
                    "args": '{"',
                    "id": None,
                    "index": tool_call_idx,
                    "type": "tool_call_chunk",
                }
            ],
        ),
        AIMessageChunk(
            content=[],
            invalid_tool_calls=[
                {
                    "name": None,
                    "args": 'a": 5, "b": 2',
                    "id": None,
                    "error": None,
                    "type": "invalid_tool_call",
                }
            ],
            tool_call_chunks=[
                {
                    "name": None,
                    "args": 'a": 5, "b": 2',
                    "id": None,
                    "index": tool_call_idx,
                    "type": "tool_call_chunk",
                }
            ],
        ),
        AIMessageChunk(
            content=[],
            invalid_tool_calls=[
                {
                    "name": None,
                    "args": "}",
                    "id": None,
                    "error": None,
                    "type": "invalid_tool_call",
                }
            ],
            tool_call_chunks=[
                {
                    "name": None,
                    "args": "}",
                    "id": None,
                    "index": tool_call_idx,
                    "type": "tool_call_chunk",
                }
            ],
        ),
        AIMessageChunk(
            content=[],
        ),
    ]
    actual = list(chain.stream({}))
    assert expected == actual


@pytest.mark.parametrize("stream", [_stream_oai, _stream_anthropic])
def test_format_messages_anthropic_block_stream(stream: Callable) -> None:
    formatter = format_messages_as(format="anthropic", text="block")

    chain = RunnableLambda(stream) | formatter
    expected = [
        AIMessageChunk(content=[]),
        AIMessageChunk(content=[{"text": "Certainly", "type": "text", "index": 0}]),
        AIMessageChunk(content=[{"text": "!", "type": "text", "index": 0}]),
        AIMessageChunk(
            content=[
                {
                    "id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8",
                    "name": "multiply",
                    "type": "tool_use",
                    "index": 1,
                    **(
                        {"input": {}}
                        if stream == _stream_anthropic
                        else {"partial_json": ""}
                    ),
                }
            ],
            tool_calls=[
                {
                    "name": "multiply",
                    "args": {},
                    "id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8",
                    "type": "tool_call",
                }
            ],
            tool_call_chunks=[
                {
                    "name": "multiply",
                    "args": "",
                    "id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8",
                    "index": 1,
                    "type": "tool_call_chunk",
                }
            ],
        ),
        AIMessageChunk(
            content=[{"partial_json": '{"', "type": "tool_use", "index": 1}],
            tool_calls=[{"name": "", "args": {}, "id": None, "type": "tool_call"}],
            tool_call_chunks=[
                {
                    "name": None,
                    "args": '{"',
                    "id": None,
                    "index": 1,
                    "type": "tool_call_chunk",
                }
            ],
        ),
        AIMessageChunk(
            content=[{"partial_json": 'a": 5, "b": 2', "type": "tool_use", "index": 1}],
            invalid_tool_calls=[
                {
                    "name": None,
                    "args": 'a": 5, "b": 2',
                    "id": None,
                    "error": None,
                    "type": "invalid_tool_call",
                }
            ],
            tool_call_chunks=[
                {
                    "name": None,
                    "args": 'a": 5, "b": 2',
                    "id": None,
                    "index": 1,
                    "type": "tool_call_chunk",
                }
            ],
        ),
        AIMessageChunk(
            content=[{"partial_json": "}", "type": "tool_use", "index": 1}],
            invalid_tool_calls=[
                {
                    "name": None,
                    "args": "}",
                    "id": None,
                    "error": None,
                    "type": "invalid_tool_call",
                }
            ],
            tool_call_chunks=[
                {
                    "name": None,
                    "args": "}",
                    "id": None,
                    "index": 1,
                    "type": "tool_call_chunk",
                }
            ],
        ),
        AIMessageChunk(content=[]),
    ]
    actual = list(chain.stream({}))
    assert expected == actual


@pytest.mark.parametrize("stream", [_stream_oai, _stream_anthropic])
def test_format_messages_anthropic_string_stream(stream: Callable) -> None:
    formatter = format_messages_as(format="anthropic", text="string")

    chain = RunnableLambda(stream) | formatter
    expected = [
        AIMessageChunk(content=""),
        AIMessageChunk(content="Certainly"),
        AIMessageChunk(content="!"),
        AIMessageChunk(
            content=[
                {
                    "type": "tool_use",
                    "id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8",
                    "name": "multiply",
                    "index": 1,
                    **(
                        {"input": {}}
                        if stream == _stream_anthropic
                        else {"partial_json": ""}
                    ),
                },
            ],
            tool_calls=[
                {
                    "name": "multiply",
                    "args": {},
                    "id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8",
                    "type": "tool_call",
                }
            ],
            tool_call_chunks=[
                {
                    "name": "multiply",
                    "args": "",
                    "id": "call_hr4yN4ZN7zv9Vmc5Cuahp4K8",
                    "index": 1,
                    "type": "tool_call_chunk",
                }
            ],
        ),
        AIMessageChunk(
            content=[
                {"type": "tool_use", "partial_json": '{"', "index": 1},
            ],
            tool_calls=[{"name": "", "args": {}, "id": None, "type": "tool_call"}],
            tool_call_chunks=[
                {
                    "name": None,
                    "args": '{"',
                    "id": None,
                    "index": 1,
                    "type": "tool_call_chunk",
                }
            ],
        ),
        AIMessageChunk(
            content=[
                {"type": "tool_use", "partial_json": 'a": 5, "b": 2', "index": 1},
            ],
            invalid_tool_calls=[
                {
                    "name": None,
                    "args": 'a": 5, "b": 2',
                    "id": None,
                    "error": None,
                    "type": "invalid_tool_call",
                }
            ],
            tool_call_chunks=[
                {
                    "name": None,
                    "args": 'a": 5, "b": 2',
                    "id": None,
                    "index": 1,
                    "type": "tool_call_chunk",
                }
            ],
        ),
        AIMessageChunk(
            content=[
                {"type": "tool_use", "partial_json": "}", "index": 1},
            ],
            invalid_tool_calls=[
                {
                    "name": None,
                    "args": "}",
                    "id": None,
                    "error": None,
                    "type": "invalid_tool_call",
                }
            ],
            tool_call_chunks=[
                {
                    "name": None,
                    "args": "}",
                    "id": None,
                    "index": 1,
                    "type": "tool_call_chunk",
                }
            ],
        ),
        AIMessageChunk(content=""),
    ]
    actual = list(chain.stream({}))
    assert expected == actual

@@ -31,7 +31,7 @@ from langchain_core.language_models.chat_models import (
)
from langchain_core.messages import AnyMessage, BaseMessage
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_core.runnables import Runnable, RunnableConfig, ensure_config
from langchain_core.runnables.schema import StreamEvent
from langchain_core.tools import BaseTool
from langchain_core.tracers import RunLog, RunLogPatch
@@ -512,7 +512,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
        return model

    def _model_params(self, config: Optional[RunnableConfig]) -> dict:
        config = config or {}
        config = ensure_config(config)
        model_params = {
            _remove_prefix(k, self._config_prefix): v
            for k, v in config.get("configurable", {}).items()
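
A minimal sketch (not part of this commit) of why ensure_config replaces
`config = config or {}` above: it normalizes an Optional[RunnableConfig]
into a fully populated config dict, so downstream lookups need no None
checks:

    from langchain_core.runnables import ensure_config

    config = ensure_config(None)
    print(config.get("configurable", {}))  # -> {}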