feat(core): support returning v1 ToolMessage in tools (#32397)

This commit is contained in:
ccurme 2025-08-05 09:50:02 -03:00 committed by GitHub
parent b06dc7954e
commit deae8cc164
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 373 additions and 75 deletions

View File

@ -19,6 +19,7 @@ from langchain_core.messages.ai import (
add_usage,
)
from langchain_core.messages.base import merge_content
from langchain_core.messages.tool import ToolOutputMixin
from langchain_core.messages.tool import invalid_tool_call as create_invalid_tool_call
from langchain_core.messages.tool import tool_call as create_tool_call
from langchain_core.utils._merge import merge_dicts
@ -645,7 +646,7 @@ class SystemMessage:
@dataclass
class ToolMessage:
class ToolMessage(ToolOutputMixin):
"""A message containing the result of a tool execution.
Represents the output from executing a tool or function call,

View File

@ -2361,6 +2361,7 @@ class Runnable(ABC, Generic[Input, Output]):
name: Optional[str] = None,
description: Optional[str] = None,
arg_types: Optional[dict[str, type]] = None,
output_version: Literal["v0", "v1"] = "v0",
) -> BaseTool:
"""Create a BaseTool from a Runnable.
@ -2376,6 +2377,11 @@ class Runnable(ABC, Generic[Input, Output]):
name: The name of the tool. Defaults to None.
description: The description of the tool. Defaults to None.
arg_types: A dictionary of argument names to types. Defaults to None.
output_version: Version of ToolMessage to return given
:class:`~langchain_core.messages.content_blocks.ToolCall` input.
If ``"v0"``, output will be a v0 :class:`~langchain_core.messages.tool.ToolMessage`.
If ``"v1"``, output will be a v1 :class:`~langchain_core.messages.v1.ToolMessage`.
Returns:
A BaseTool instance.
@ -2451,7 +2457,7 @@ class Runnable(ABC, Generic[Input, Output]):
.. versionadded:: 0.2.14
"""
""" # noqa: E501
# Avoid circular import
from langchain_core.tools import convert_runnable_to_tool
@ -2461,6 +2467,7 @@ class Runnable(ABC, Generic[Input, Output]):
name=name,
description=description,
arg_types=arg_types,
output_version=output_version,
)

View File

@ -47,6 +47,7 @@ from langchain_core.callbacks import (
Callbacks,
)
from langchain_core.messages.tool import ToolCall, ToolMessage, ToolOutputMixin
from langchain_core.messages.v1 import ToolMessage as ToolMessageV1
from langchain_core.runnables import (
RunnableConfig,
RunnableSerializable,
@ -498,6 +499,14 @@ class ChildTool(BaseTool):
two-tuple corresponding to the (content, artifact) of a ToolMessage.
"""
output_version: Literal["v0", "v1"] = "v0"
"""Version of ToolMessage to return given
:class:`~langchain_core.messages.content_blocks.ToolCall` input.
If ``"v0"``, output will be a v0 :class:`~langchain_core.messages.tool.ToolMessage`.
If ``"v1"``, output will be a v1 :class:`~langchain_core.messages.v1.ToolMessage`.
"""
def __init__(self, **kwargs: Any) -> None:
"""Initialize the tool."""
if (
@ -835,7 +844,7 @@ class ChildTool(BaseTool):
content = None
artifact = None
status = "success"
status: Literal["success", "error"] = "success"
error_to_raise: Union[Exception, KeyboardInterrupt, None] = None
try:
child_config = patch_config(config, callbacks=run_manager.get_child())
@ -879,7 +888,14 @@ class ChildTool(BaseTool):
if error_to_raise:
run_manager.on_tool_error(error_to_raise)
raise error_to_raise
output = _format_output(content, artifact, tool_call_id, self.name, status)
output = _format_output(
content,
artifact,
tool_call_id,
self.name,
status,
output_version=self.output_version,
)
run_manager.on_tool_end(output, color=color, name=self.name, **kwargs)
return output
@ -945,7 +961,7 @@ class ChildTool(BaseTool):
)
content = None
artifact = None
status = "success"
status: Literal["success", "error"] = "success"
error_to_raise: Optional[Union[Exception, KeyboardInterrupt]] = None
try:
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input, tool_call_id)
@ -993,7 +1009,14 @@ class ChildTool(BaseTool):
await run_manager.on_tool_error(error_to_raise)
raise error_to_raise
output = _format_output(content, artifact, tool_call_id, self.name, status)
output = _format_output(
content,
artifact,
tool_call_id,
self.name,
status,
output_version=self.output_version,
)
await run_manager.on_tool_end(output, color=color, name=self.name, **kwargs)
return output
@ -1131,7 +1154,9 @@ def _format_output(
artifact: Any,
tool_call_id: Optional[str],
name: str,
status: str,
status: Literal["success", "error"],
*,
output_version: Literal["v0", "v1"] = "v0",
) -> Union[ToolOutputMixin, Any]:
"""Format tool output as a ToolMessage if appropriate.
@ -1141,6 +1166,7 @@ def _format_output(
tool_call_id: The ID of the tool call.
name: The name of the tool.
status: The execution status.
output_version: The version of the ToolMessage to return.
Returns:
The formatted output, either as a ToolMessage or the original content.
@ -1149,7 +1175,15 @@ def _format_output(
return content
if not _is_message_content_type(content):
content = _stringify(content)
return ToolMessage(
if output_version == "v0":
return ToolMessage(
content,
artifact=artifact,
tool_call_id=tool_call_id,
name=name,
status=status,
)
return ToolMessageV1(
content,
artifact=artifact,
tool_call_id=tool_call_id,

View File

@ -22,6 +22,7 @@ def tool(
response_format: Literal["content", "content_and_artifact"] = "content",
parse_docstring: bool = False,
error_on_invalid_docstring: bool = True,
output_version: Literal["v0", "v1"] = "v0",
) -> Callable[[Union[Callable, Runnable]], BaseTool]: ...
@ -37,6 +38,7 @@ def tool(
response_format: Literal["content", "content_and_artifact"] = "content",
parse_docstring: bool = False,
error_on_invalid_docstring: bool = True,
output_version: Literal["v0", "v1"] = "v0",
) -> BaseTool: ...
@ -51,6 +53,7 @@ def tool(
response_format: Literal["content", "content_and_artifact"] = "content",
parse_docstring: bool = False,
error_on_invalid_docstring: bool = True,
output_version: Literal["v0", "v1"] = "v0",
) -> BaseTool: ...
@ -65,6 +68,7 @@ def tool(
response_format: Literal["content", "content_and_artifact"] = "content",
parse_docstring: bool = False,
error_on_invalid_docstring: bool = True,
output_version: Literal["v0", "v1"] = "v0",
) -> Callable[[Union[Callable, Runnable]], BaseTool]: ...
@ -79,6 +83,7 @@ def tool(
response_format: Literal["content", "content_and_artifact"] = "content",
parse_docstring: bool = False,
error_on_invalid_docstring: bool = True,
output_version: Literal["v0", "v1"] = "v0",
) -> Union[
BaseTool,
Callable[[Union[Callable, Runnable]], BaseTool],
@ -118,6 +123,11 @@ def tool(
error_on_invalid_docstring: if ``parse_docstring`` is provided, configure
whether to raise ValueError on invalid Google Style docstrings.
Defaults to True.
output_version: Version of ToolMessage to return given
:class:`~langchain_core.messages.content_blocks.ToolCall` input.
If ``"v0"``, output will be a v0 :class:`~langchain_core.messages.tool.ToolMessage`.
If ``"v1"``, output will be a v1 :class:`~langchain_core.messages.v1.ToolMessage`.
Returns:
The tool.
@ -216,7 +226,7 @@ def tool(
\"\"\"
return bar
""" # noqa: D214, D410, D411
""" # noqa: D214, D410, D411, E501
def _create_tool_factory(
tool_name: str,
@ -274,6 +284,7 @@ def tool(
response_format=response_format,
parse_docstring=parse_docstring,
error_on_invalid_docstring=error_on_invalid_docstring,
output_version=output_version,
)
# If someone doesn't want a schema applied, we must treat it as
# a simple string->string function
@ -290,6 +301,7 @@ def tool(
return_direct=return_direct,
coroutine=coroutine,
response_format=response_format,
output_version=output_version,
)
return _tool_factory
@ -383,6 +395,7 @@ def convert_runnable_to_tool(
name: Optional[str] = None,
description: Optional[str] = None,
arg_types: Optional[dict[str, type]] = None,
output_version: Literal["v0", "v1"] = "v0",
) -> BaseTool:
"""Convert a Runnable into a BaseTool.
@ -392,10 +405,15 @@ def convert_runnable_to_tool(
name: The name of the tool. Defaults to None.
description: The description of the tool. Defaults to None.
arg_types: The types of the arguments. Defaults to None.
output_version: Version of ToolMessage to return given
:class:`~langchain_core.messages.content_blocks.ToolCall` input.
If ``"v0"``, output will be a v0 :class:`~langchain_core.messages.tool.ToolMessage`.
If ``"v1"``, output will be a v1 :class:`~langchain_core.messages.v1.ToolMessage`.
Returns:
The tool.
"""
""" # noqa: E501
if args_schema:
runnable = runnable.with_types(input_type=args_schema)
description = description or _get_description_from_runnable(runnable)
@ -408,6 +426,7 @@ def convert_runnable_to_tool(
func=runnable.invoke,
coroutine=runnable.ainvoke,
description=description,
output_version=output_version,
)
async def ainvoke_wrapper(
@ -435,4 +454,5 @@ def convert_runnable_to_tool(
coroutine=ainvoke_wrapper,
description=description,
args_schema=args_schema,
output_version=output_version,
)

View File

@ -72,6 +72,7 @@ def create_retriever_tool(
document_prompt: Optional[BasePromptTemplate] = None,
document_separator: str = "\n\n",
response_format: Literal["content", "content_and_artifact"] = "content",
output_version: Literal["v0", "v1"] = "v1",
) -> Tool:
r"""Create a tool to do retrieval of documents.
@ -88,10 +89,15 @@ def create_retriever_tool(
"content_and_artifact" then the output is expected to be a two-tuple
corresponding to the (content, artifact) of a ToolMessage (artifact
being a list of documents in this case). Defaults to "content".
output_version: Version of ToolMessage to return given
:class:`~langchain_core.messages.content_blocks.ToolCall` input.
If ``"v0"``, output will be a v0 :class:`~langchain_core.messages.tool.ToolMessage`.
If ``"v1"``, output will be a v1 :class:`~langchain_core.messages.v1.ToolMessage`.
Returns:
Tool class to pass to an agent.
"""
""" # noqa: E501
document_prompt = document_prompt or PromptTemplate.from_template("{page_content}")
func = partial(
_get_relevant_documents,
@ -114,4 +120,5 @@ def create_retriever_tool(
coroutine=afunc,
args_schema=RetrieverInput,
response_format=response_format,
output_version=output_version,
)

View File

@ -129,6 +129,7 @@ class StructuredTool(BaseTool):
response_format: Literal["content", "content_and_artifact"] = "content",
parse_docstring: bool = False,
error_on_invalid_docstring: bool = False,
output_version: Literal["v0", "v1"] = "v0",
**kwargs: Any,
) -> StructuredTool:
"""Create tool from a given function.
@ -157,6 +158,12 @@ class StructuredTool(BaseTool):
error_on_invalid_docstring: if ``parse_docstring`` is provided, configure
whether to raise ValueError on invalid Google Style docstrings.
Defaults to False.
output_version: Version of ToolMessage to return given
:class:`~langchain_core.messages.content_blocks.ToolCall` input.
If ``"v0"``, output will be a v0 :class:`~langchain_core.messages.tool.ToolMessage`.
If ``"v1"``, output will be a v1 :class:`~langchain_core.messages.v1.ToolMessage`.
kwargs: Additional arguments to pass to the tool
Returns:
@ -175,7 +182,7 @@ class StructuredTool(BaseTool):
tool = StructuredTool.from_function(add)
tool.run(1, 2) # 3
"""
""" # noqa: E501
if func is not None:
source_function = func
elif coroutine is not None:
@ -232,6 +239,7 @@ class StructuredTool(BaseTool):
description=description_,
return_direct=return_direct,
response_format=response_format,
output_version=output_version,
**kwargs,
)

View File

@ -37,6 +37,7 @@ from langchain_core.callbacks.manager import (
from langchain_core.documents import Document
from langchain_core.messages import ToolCall, ToolMessage
from langchain_core.messages.tool import ToolOutputMixin
from langchain_core.messages.v1 import ToolMessage as ToolMessageV1
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import (
Runnable,
@ -70,6 +71,7 @@ from langchain_core.utils.pydantic import (
)
from tests.unit_tests.fake.callbacks import FakeCallbackHandler
from tests.unit_tests.pydantic_utils import _schema
from tests.unit_tests.stubs import AnyStr
def _get_tool_call_json_schema(tool: BaseTool) -> dict:
@ -1379,17 +1381,28 @@ def test_tool_annotated_descriptions() -> None:
}
def test_tool_call_input_tool_message_output() -> None:
@pytest.mark.parametrize("output_version", ["v0", "v1"])
def test_tool_call_input_tool_message(output_version: Literal["v0", "v1"]) -> None:
tool_call = {
"name": "structured_api",
"args": {"arg1": 1, "arg2": True, "arg3": {"img": "base64string..."}},
"id": "123",
"type": "tool_call",
}
tool = _MockStructuredTool()
expected = ToolMessage(
"1 True {'img': 'base64string...'}", tool_call_id="123", name="structured_api"
)
tool = _MockStructuredTool(output_version=output_version)
if output_version == "v0":
expected: Union[ToolMessage, ToolMessageV1] = ToolMessage(
"1 True {'img': 'base64string...'}",
tool_call_id="123",
name="structured_api",
)
else:
expected = ToolMessageV1(
"1 True {'img': 'base64string...'}",
tool_call_id="123",
name="structured_api",
id=AnyStr("lc_abc123"),
)
actual = tool.invoke(tool_call)
assert actual == expected
@ -1421,6 +1434,14 @@ def _mock_structured_tool_with_artifact(
return f"{arg1} {arg2}", {"arg1": arg1, "arg2": arg2, "arg3": arg3}
@tool("structured_api", response_format="content_and_artifact", output_version="v1")
def _mock_structured_tool_with_artifact_v1(
*, arg1: int, arg2: bool, arg3: Optional[dict] = None
) -> tuple[str, dict]:
"""A Structured Tool."""
return f"{arg1} {arg2}", {"arg1": arg1, "arg2": arg2, "arg3": arg3}
@pytest.mark.parametrize(
"tool", [_MockStructuredToolWithRawOutput(), _mock_structured_tool_with_artifact]
)
@ -1445,6 +1466,38 @@ def test_tool_call_input_tool_message_with_artifact(tool: BaseTool) -> None:
assert actual_content == expected.content
@pytest.mark.parametrize(
"tool",
[
_MockStructuredToolWithRawOutput(output_version="v1"),
_mock_structured_tool_with_artifact_v1,
],
)
def test_tool_call_input_tool_message_with_artifact_v1(tool: BaseTool) -> None:
tool_call: dict = {
"name": "structured_api",
"args": {"arg1": 1, "arg2": True, "arg3": {"img": "base64string..."}},
"id": "123",
"type": "tool_call",
}
expected = ToolMessageV1(
"1 True",
artifact=tool_call["args"],
tool_call_id="123",
name="structured_api",
id=AnyStr("lc_abc123"),
)
actual = tool.invoke(tool_call)
assert actual == expected
tool_call.pop("type")
with pytest.raises(ValidationError):
tool.invoke(tool_call)
actual_content = tool.invoke(tool_call["args"])
assert actual_content == expected.text
def test_convert_from_runnable_dict() -> None:
# Test with typed dict input
class Args(TypedDict):
@ -1550,6 +1603,17 @@ def injected_tool(x: int, y: Annotated[str, InjectedToolArg]) -> str:
return y
@tool("foo", parse_docstring=True, output_version="v1")
def injected_tool_v1(x: int, y: Annotated[str, InjectedToolArg]) -> str:
"""Foo.
Args:
x: abc
y: 123
"""
return y
class InjectedTool(BaseTool):
name: str = "foo"
description: str = "foo."
@ -1587,7 +1651,12 @@ def injected_tool_with_schema(x: int, y: str) -> str:
return y
@pytest.mark.parametrize("tool_", [InjectedTool()])
@tool("foo", args_schema=fooSchema, output_version="v1")
def injected_tool_with_schema_v1(x: int, y: str) -> str:
return y
@pytest.mark.parametrize("tool_", [InjectedTool(), InjectedTool(output_version="v1")])
def test_tool_injected_arg_without_schema(tool_: BaseTool) -> None:
assert _schema(tool_.get_input_schema()) == {
"title": "foo",
@ -1607,14 +1676,25 @@ def test_tool_injected_arg_without_schema(tool_: BaseTool) -> None:
"required": ["x"],
}
assert tool_.invoke({"x": 5, "y": "bar"}) == "bar"
assert tool_.invoke(
{
"name": "foo",
"args": {"x": 5, "y": "bar"},
"id": "123",
"type": "tool_call",
}
) == ToolMessage("bar", tool_call_id="123", name="foo")
if tool_.output_version == "v0":
expected: Union[ToolMessage, ToolMessageV1] = ToolMessage(
"bar", tool_call_id="123", name="foo"
)
else:
expected = ToolMessageV1(
"bar", tool_call_id="123", name="foo", id=AnyStr("lc_abc123")
)
assert (
tool_.invoke(
{
"name": "foo",
"args": {"x": 5, "y": "bar"},
"id": "123",
"type": "tool_call",
}
)
== expected
)
expected_error = (
ValidationError if not isinstance(tool_, InjectedTool) else TypeError
)
@ -1634,7 +1714,12 @@ def test_tool_injected_arg_without_schema(tool_: BaseTool) -> None:
@pytest.mark.parametrize(
"tool_",
[injected_tool_with_schema, InjectedToolWithSchema()],
[
injected_tool_with_schema,
InjectedToolWithSchema(),
injected_tool_with_schema_v1,
InjectedToolWithSchema(output_version="v1"),
],
)
def test_tool_injected_arg_with_schema(tool_: BaseTool) -> None:
assert _schema(tool_.get_input_schema()) == {
@ -1655,14 +1740,25 @@ def test_tool_injected_arg_with_schema(tool_: BaseTool) -> None:
"required": ["x"],
}
assert tool_.invoke({"x": 5, "y": "bar"}) == "bar"
assert tool_.invoke(
{
"name": "foo",
"args": {"x": 5, "y": "bar"},
"id": "123",
"type": "tool_call",
}
) == ToolMessage("bar", tool_call_id="123", name="foo")
if tool_.output_version == "v0":
expected: Union[ToolMessage, ToolMessageV1] = ToolMessage(
"bar", tool_call_id="123", name="foo"
)
else:
expected = ToolMessageV1(
"bar", tool_call_id="123", name="foo", id=AnyStr("lc_abc123")
)
assert (
tool_.invoke(
{
"name": "foo",
"args": {"x": 5, "y": "bar"},
"id": "123",
"type": "tool_call",
}
)
== expected
)
expected_error = (
ValidationError if not isinstance(tool_, InjectedTool) else TypeError
)
@ -1680,8 +1776,9 @@ def test_tool_injected_arg_with_schema(tool_: BaseTool) -> None:
}
def test_tool_injected_arg() -> None:
tool_ = injected_tool
@pytest.mark.parametrize("output_version", ["v0", "v1"])
def test_tool_injected_arg(output_version: Literal["v0", "v1"]) -> None:
tool_ = injected_tool if output_version == "v0" else injected_tool_v1
assert _schema(tool_.get_input_schema()) == {
"title": "foo",
"description": "Foo.",
@ -1700,14 +1797,25 @@ def test_tool_injected_arg() -> None:
"required": ["x"],
}
assert tool_.invoke({"x": 5, "y": "bar"}) == "bar"
assert tool_.invoke(
{
"name": "foo",
"args": {"x": 5, "y": "bar"},
"id": "123",
"type": "tool_call",
}
) == ToolMessage("bar", tool_call_id="123", name="foo")
if output_version == "v0":
expected: Union[ToolMessage, ToolMessageV1] = ToolMessage(
"bar", tool_call_id="123", name="foo"
)
else:
expected = ToolMessageV1(
"bar", tool_call_id="123", name="foo", id=AnyStr("lc_abc123")
)
assert (
tool_.invoke(
{
"name": "foo",
"args": {"x": 5, "y": "bar"},
"id": "123",
"type": "tool_call",
}
)
== expected
)
expected_error = (
ValidationError if not isinstance(tool_, InjectedTool) else TypeError
)
@ -1725,7 +1833,8 @@ def test_tool_injected_arg() -> None:
}
def test_tool_inherited_injected_arg() -> None:
@pytest.mark.parametrize("output_version", ["v0", "v1"])
def test_tool_inherited_injected_arg(output_version: Literal["v0", "v1"]) -> None:
class BarSchema(BaseModel):
"""bar."""
@ -1746,7 +1855,7 @@ def test_tool_inherited_injected_arg() -> None:
def _run(self, x: int, y: str) -> Any:
return y
tool_ = InheritedInjectedArgTool()
tool_ = InheritedInjectedArgTool(output_version=output_version)
assert tool_.get_input_schema().model_json_schema() == {
"title": "FooSchema", # Matches the title from the provided schema
"description": "foo.",
@ -1766,14 +1875,25 @@ def test_tool_inherited_injected_arg() -> None:
"required": ["x"],
}
assert tool_.invoke({"x": 5, "y": "bar"}) == "bar"
assert tool_.invoke(
{
"name": "foo",
"args": {"x": 5, "y": "bar"},
"id": "123",
"type": "tool_call",
}
) == ToolMessage("bar", tool_call_id="123", name="foo")
if output_version == "v0":
expected: Union[ToolMessage, ToolMessageV1] = ToolMessage(
"bar", tool_call_id="123", name="foo"
)
else:
expected = ToolMessageV1(
"bar", tool_call_id="123", name="foo", id=AnyStr("lc_abc123")
)
assert (
tool_.invoke(
{
"name": "foo",
"args": {"x": 5, "y": "bar"},
"id": "123",
"type": "tool_call",
}
)
== expected
)
expected_error = (
ValidationError if not isinstance(tool_, InjectedTool) else TypeError
)
@ -2133,7 +2253,8 @@ def test_tool_annotations_preserved() -> None:
assert schema.__annotations__ == expected_type_hints
def test_create_retriever_tool() -> None:
@pytest.mark.parametrize("output_version", ["v0", "v1"])
def test_create_retriever_tool(output_version: Literal["v0", "v1"]) -> None:
class MyRetriever(BaseRetriever):
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
@ -2142,21 +2263,36 @@ def test_create_retriever_tool() -> None:
retriever = MyRetriever()
retriever_tool = tools.create_retriever_tool(
retriever, "retriever_tool_content", "Retriever Tool Content"
retriever,
"retriever_tool_content",
"Retriever Tool Content",
output_version=output_version,
)
assert isinstance(retriever_tool, BaseTool)
assert retriever_tool.name == "retriever_tool_content"
assert retriever_tool.description == "Retriever Tool Content"
assert retriever_tool.invoke("bar") == "foo bar\n\nbar"
assert retriever_tool.invoke(
ToolCall(
name="retriever_tool_content",
args={"query": "bar"},
id="123",
type="tool_call",
if output_version == "v0":
expected: Union[ToolMessage, ToolMessageV1] = ToolMessage(
"foo bar\n\nbar", tool_call_id="123", name="retriever_tool_content"
)
) == ToolMessage(
"foo bar\n\nbar", tool_call_id="123", name="retriever_tool_content"
else:
expected = ToolMessageV1(
"foo bar\n\nbar",
tool_call_id="123",
name="retriever_tool_content",
id=AnyStr("lc_abc123"),
)
assert (
retriever_tool.invoke(
ToolCall(
name="retriever_tool_content",
args={"query": "bar"},
id="123",
type="tool_call",
)
)
== expected
)
retriever_tool_artifact = tools.create_retriever_tool(
@ -2164,23 +2300,37 @@ def test_create_retriever_tool() -> None:
"retriever_tool_artifact",
"Retriever Tool Artifact",
response_format="content_and_artifact",
output_version=output_version,
)
assert isinstance(retriever_tool_artifact, BaseTool)
assert retriever_tool_artifact.name == "retriever_tool_artifact"
assert retriever_tool_artifact.description == "Retriever Tool Artifact"
assert retriever_tool_artifact.invoke("bar") == "foo bar\n\nbar"
assert retriever_tool_artifact.invoke(
ToolCall(
if output_version == "v0":
expected = ToolMessage(
"foo bar\n\nbar",
artifact=[Document(page_content="foo bar"), Document(page_content="bar")],
tool_call_id="123",
name="retriever_tool_artifact",
args={"query": "bar"},
id="123",
type="tool_call",
)
) == ToolMessage(
"foo bar\n\nbar",
artifact=[Document(page_content="foo bar"), Document(page_content="bar")],
tool_call_id="123",
name="retriever_tool_artifact",
else:
expected = ToolMessageV1(
"foo bar\n\nbar",
artifact=[Document(page_content="foo bar"), Document(page_content="bar")],
tool_call_id="123",
name="retriever_tool_artifact",
id=AnyStr("lc_abc123"),
)
assert (
retriever_tool_artifact.invoke(
ToolCall(
name="retriever_tool_artifact",
args={"query": "bar"},
id="123",
type="tool_call",
)
)
== expected
)
@ -2313,6 +2463,45 @@ def test_tool_injected_tool_call_id() -> None:
) == ToolMessage(0, tool_call_id="bar") # type: ignore[arg-type]
def test_tool_injected_tool_call_id_v1() -> None:
@tool
def foo(x: int, tool_call_id: Annotated[str, InjectedToolCallId]) -> ToolMessageV1:
"""Foo."""
return ToolMessageV1(str(x), tool_call_id=tool_call_id)
assert foo.invoke(
{
"type": "tool_call",
"args": {"x": 0},
"name": "foo",
"id": "bar",
}
) == ToolMessageV1("0", tool_call_id="bar", id=AnyStr("lc_abc123"))
with pytest.raises(
ValueError,
match="When tool includes an InjectedToolCallId argument, "
"tool must always be invoked with a full model ToolCall",
):
assert foo.invoke({"x": 0})
@tool
def foo2(
x: int, tool_call_id: Annotated[str, InjectedToolCallId()]
) -> ToolMessageV1:
"""Foo."""
return ToolMessageV1(str(x), tool_call_id=tool_call_id)
assert foo2.invoke(
{
"type": "tool_call",
"args": {"x": 0},
"name": "foo",
"id": "bar",
}
) == ToolMessageV1("0", tool_call_id="bar", id=AnyStr("lc_abc123"))
def test_tool_uninjected_tool_call_id() -> None:
@tool
def foo(x: int, tool_call_id: str) -> ToolMessage:
@ -2332,6 +2521,25 @@ def test_tool_uninjected_tool_call_id() -> None:
) == ToolMessage(0, tool_call_id="zap") # type: ignore[arg-type]
def test_tool_uninjected_tool_call_id_v1() -> None:
@tool
def foo(x: int, tool_call_id: str) -> ToolMessageV1:
"""Foo."""
return ToolMessageV1(str(x), tool_call_id=tool_call_id)
with pytest.raises(ValueError, match="1 validation error for foo"):
foo.invoke({"type": "tool_call", "args": {"x": 0}, "name": "foo", "id": "bar"})
assert foo.invoke(
{
"type": "tool_call",
"args": {"x": 0, "tool_call_id": "zap"},
"name": "foo",
"id": "bar",
}
) == ToolMessageV1("0", tool_call_id="zap", id=AnyStr("lc_abc123"))
def test_tool_return_output_mixin() -> None:
class Bar(ToolOutputMixin):
def __init__(self, x: int) -> None:
@ -2457,6 +2665,19 @@ def test_empty_string_tool_call_id() -> None:
)
def test_empty_string_tool_call_id_v1() -> None:
@tool(output_version="v1")
def foo(x: int) -> str:
"""Foo."""
return "hi"
assert foo.invoke(
{"type": "tool_call", "args": {"x": 0}, "id": ""}
) == ToolMessageV1(
content="hi", name="foo", tool_call_id="", id=AnyStr("lc_abc123")
)
def test_tool_decorator_description() -> None:
# test basic tool
@tool