diff --git a/libs/core/langchain_core/messages/v1.py b/libs/core/langchain_core/messages/v1.py index 1b779e4afd4..159e2faacaf 100644 --- a/libs/core/langchain_core/messages/v1.py +++ b/libs/core/langchain_core/messages/v1.py @@ -19,6 +19,7 @@ from langchain_core.messages.ai import ( add_usage, ) from langchain_core.messages.base import merge_content +from langchain_core.messages.tool import ToolOutputMixin from langchain_core.messages.tool import invalid_tool_call as create_invalid_tool_call from langchain_core.messages.tool import tool_call as create_tool_call from langchain_core.utils._merge import merge_dicts @@ -645,7 +646,7 @@ class SystemMessage: @dataclass -class ToolMessage: +class ToolMessage(ToolOutputMixin): """A message containing the result of a tool execution. Represents the output from executing a tool or function call, diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index 6c3938631ea..30f11c9ab7e 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -2361,6 +2361,7 @@ class Runnable(ABC, Generic[Input, Output]): name: Optional[str] = None, description: Optional[str] = None, arg_types: Optional[dict[str, type]] = None, + output_version: Literal["v0", "v1"] = "v0", ) -> BaseTool: """Create a BaseTool from a Runnable. @@ -2376,6 +2377,11 @@ class Runnable(ABC, Generic[Input, Output]): name: The name of the tool. Defaults to None. description: The description of the tool. Defaults to None. arg_types: A dictionary of argument names to types. Defaults to None. + output_version: Version of ToolMessage to return given + :class:`~langchain_core.messages.content_blocks.ToolCall` input. + + If ``"v0"``, output will be a v0 :class:`~langchain_core.messages.tool.ToolMessage`. + If ``"v1"``, output will be a v1 :class:`~langchain_core.messages.v1.ToolMessage`. Returns: A BaseTool instance. @@ -2451,7 +2457,7 @@ class Runnable(ABC, Generic[Input, Output]): .. versionadded:: 0.2.14 - """ + """ # noqa: E501 # Avoid circular import from langchain_core.tools import convert_runnable_to_tool @@ -2461,6 +2467,7 @@ class Runnable(ABC, Generic[Input, Output]): name=name, description=description, arg_types=arg_types, + output_version=output_version, ) diff --git a/libs/core/langchain_core/tools/base.py b/libs/core/langchain_core/tools/base.py index e54a09709d6..5573a08a1d5 100644 --- a/libs/core/langchain_core/tools/base.py +++ b/libs/core/langchain_core/tools/base.py @@ -47,6 +47,7 @@ from langchain_core.callbacks import ( Callbacks, ) from langchain_core.messages.tool import ToolCall, ToolMessage, ToolOutputMixin +from langchain_core.messages.v1 import ToolMessage as ToolMessageV1 from langchain_core.runnables import ( RunnableConfig, RunnableSerializable, @@ -498,6 +499,14 @@ class ChildTool(BaseTool): two-tuple corresponding to the (content, artifact) of a ToolMessage. """ + output_version: Literal["v0", "v1"] = "v0" + """Version of ToolMessage to return given + :class:`~langchain_core.messages.content_blocks.ToolCall` input. + + If ``"v0"``, output will be a v0 :class:`~langchain_core.messages.tool.ToolMessage`. + If ``"v1"``, output will be a v1 :class:`~langchain_core.messages.v1.ToolMessage`. + """ + def __init__(self, **kwargs: Any) -> None: """Initialize the tool.""" if ( @@ -835,7 +844,7 @@ class ChildTool(BaseTool): content = None artifact = None - status = "success" + status: Literal["success", "error"] = "success" error_to_raise: Union[Exception, KeyboardInterrupt, None] = None try: child_config = patch_config(config, callbacks=run_manager.get_child()) @@ -879,7 +888,14 @@ class ChildTool(BaseTool): if error_to_raise: run_manager.on_tool_error(error_to_raise) raise error_to_raise - output = _format_output(content, artifact, tool_call_id, self.name, status) + output = _format_output( + content, + artifact, + tool_call_id, + self.name, + status, + output_version=self.output_version, + ) run_manager.on_tool_end(output, color=color, name=self.name, **kwargs) return output @@ -945,7 +961,7 @@ class ChildTool(BaseTool): ) content = None artifact = None - status = "success" + status: Literal["success", "error"] = "success" error_to_raise: Optional[Union[Exception, KeyboardInterrupt]] = None try: tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input, tool_call_id) @@ -993,7 +1009,14 @@ class ChildTool(BaseTool): await run_manager.on_tool_error(error_to_raise) raise error_to_raise - output = _format_output(content, artifact, tool_call_id, self.name, status) + output = _format_output( + content, + artifact, + tool_call_id, + self.name, + status, + output_version=self.output_version, + ) await run_manager.on_tool_end(output, color=color, name=self.name, **kwargs) return output @@ -1131,7 +1154,9 @@ def _format_output( artifact: Any, tool_call_id: Optional[str], name: str, - status: str, + status: Literal["success", "error"], + *, + output_version: Literal["v0", "v1"] = "v0", ) -> Union[ToolOutputMixin, Any]: """Format tool output as a ToolMessage if appropriate. @@ -1141,6 +1166,7 @@ def _format_output( tool_call_id: The ID of the tool call. name: The name of the tool. status: The execution status. + output_version: The version of the ToolMessage to return. Returns: The formatted output, either as a ToolMessage or the original content. @@ -1149,7 +1175,15 @@ def _format_output( return content if not _is_message_content_type(content): content = _stringify(content) - return ToolMessage( + if output_version == "v0": + return ToolMessage( + content, + artifact=artifact, + tool_call_id=tool_call_id, + name=name, + status=status, + ) + return ToolMessageV1( content, artifact=artifact, tool_call_id=tool_call_id, diff --git a/libs/core/langchain_core/tools/convert.py b/libs/core/langchain_core/tools/convert.py index 8b103fd54d6..dcbfae56225 100644 --- a/libs/core/langchain_core/tools/convert.py +++ b/libs/core/langchain_core/tools/convert.py @@ -22,6 +22,7 @@ def tool( response_format: Literal["content", "content_and_artifact"] = "content", parse_docstring: bool = False, error_on_invalid_docstring: bool = True, + output_version: Literal["v0", "v1"] = "v0", ) -> Callable[[Union[Callable, Runnable]], BaseTool]: ... @@ -37,6 +38,7 @@ def tool( response_format: Literal["content", "content_and_artifact"] = "content", parse_docstring: bool = False, error_on_invalid_docstring: bool = True, + output_version: Literal["v0", "v1"] = "v0", ) -> BaseTool: ... @@ -51,6 +53,7 @@ def tool( response_format: Literal["content", "content_and_artifact"] = "content", parse_docstring: bool = False, error_on_invalid_docstring: bool = True, + output_version: Literal["v0", "v1"] = "v0", ) -> BaseTool: ... @@ -65,6 +68,7 @@ def tool( response_format: Literal["content", "content_and_artifact"] = "content", parse_docstring: bool = False, error_on_invalid_docstring: bool = True, + output_version: Literal["v0", "v1"] = "v0", ) -> Callable[[Union[Callable, Runnable]], BaseTool]: ... @@ -79,6 +83,7 @@ def tool( response_format: Literal["content", "content_and_artifact"] = "content", parse_docstring: bool = False, error_on_invalid_docstring: bool = True, + output_version: Literal["v0", "v1"] = "v0", ) -> Union[ BaseTool, Callable[[Union[Callable, Runnable]], BaseTool], @@ -118,6 +123,11 @@ def tool( error_on_invalid_docstring: if ``parse_docstring`` is provided, configure whether to raise ValueError on invalid Google Style docstrings. Defaults to True. + output_version: Version of ToolMessage to return given + :class:`~langchain_core.messages.content_blocks.ToolCall` input. + + If ``"v0"``, output will be a v0 :class:`~langchain_core.messages.tool.ToolMessage`. + If ``"v1"``, output will be a v1 :class:`~langchain_core.messages.v1.ToolMessage`. Returns: The tool. @@ -216,7 +226,7 @@ def tool( \"\"\" return bar - """ # noqa: D214, D410, D411 + """ # noqa: D214, D410, D411, E501 def _create_tool_factory( tool_name: str, @@ -274,6 +284,7 @@ def tool( response_format=response_format, parse_docstring=parse_docstring, error_on_invalid_docstring=error_on_invalid_docstring, + output_version=output_version, ) # If someone doesn't want a schema applied, we must treat it as # a simple string->string function @@ -290,6 +301,7 @@ def tool( return_direct=return_direct, coroutine=coroutine, response_format=response_format, + output_version=output_version, ) return _tool_factory @@ -383,6 +395,7 @@ def convert_runnable_to_tool( name: Optional[str] = None, description: Optional[str] = None, arg_types: Optional[dict[str, type]] = None, + output_version: Literal["v0", "v1"] = "v0", ) -> BaseTool: """Convert a Runnable into a BaseTool. @@ -392,10 +405,15 @@ def convert_runnable_to_tool( name: The name of the tool. Defaults to None. description: The description of the tool. Defaults to None. arg_types: The types of the arguments. Defaults to None. + output_version: Version of ToolMessage to return given + :class:`~langchain_core.messages.content_blocks.ToolCall` input. + + If ``"v0"``, output will be a v0 :class:`~langchain_core.messages.tool.ToolMessage`. + If ``"v1"``, output will be a v1 :class:`~langchain_core.messages.v1.ToolMessage`. Returns: The tool. - """ + """ # noqa: E501 if args_schema: runnable = runnable.with_types(input_type=args_schema) description = description or _get_description_from_runnable(runnable) @@ -408,6 +426,7 @@ def convert_runnable_to_tool( func=runnable.invoke, coroutine=runnable.ainvoke, description=description, + output_version=output_version, ) async def ainvoke_wrapper( @@ -435,4 +454,5 @@ def convert_runnable_to_tool( coroutine=ainvoke_wrapper, description=description, args_schema=args_schema, + output_version=output_version, ) diff --git a/libs/core/langchain_core/tools/retriever.py b/libs/core/langchain_core/tools/retriever.py index 002fa5e80d6..b3f3f4be10e 100644 --- a/libs/core/langchain_core/tools/retriever.py +++ b/libs/core/langchain_core/tools/retriever.py @@ -72,6 +72,7 @@ def create_retriever_tool( document_prompt: Optional[BasePromptTemplate] = None, document_separator: str = "\n\n", response_format: Literal["content", "content_and_artifact"] = "content", + output_version: Literal["v0", "v1"] = "v1", ) -> Tool: r"""Create a tool to do retrieval of documents. @@ -88,10 +89,15 @@ def create_retriever_tool( "content_and_artifact" then the output is expected to be a two-tuple corresponding to the (content, artifact) of a ToolMessage (artifact being a list of documents in this case). Defaults to "content". + output_version: Version of ToolMessage to return given + :class:`~langchain_core.messages.content_blocks.ToolCall` input. + + If ``"v0"``, output will be a v0 :class:`~langchain_core.messages.tool.ToolMessage`. + If ``"v1"``, output will be a v1 :class:`~langchain_core.messages.v1.ToolMessage`. Returns: Tool class to pass to an agent. - """ + """ # noqa: E501 document_prompt = document_prompt or PromptTemplate.from_template("{page_content}") func = partial( _get_relevant_documents, @@ -114,4 +120,5 @@ def create_retriever_tool( coroutine=afunc, args_schema=RetrieverInput, response_format=response_format, + output_version=output_version, ) diff --git a/libs/core/langchain_core/tools/structured.py b/libs/core/langchain_core/tools/structured.py index a419a1ede62..c1326512ee2 100644 --- a/libs/core/langchain_core/tools/structured.py +++ b/libs/core/langchain_core/tools/structured.py @@ -129,6 +129,7 @@ class StructuredTool(BaseTool): response_format: Literal["content", "content_and_artifact"] = "content", parse_docstring: bool = False, error_on_invalid_docstring: bool = False, + output_version: Literal["v0", "v1"] = "v0", **kwargs: Any, ) -> StructuredTool: """Create tool from a given function. @@ -157,6 +158,12 @@ class StructuredTool(BaseTool): error_on_invalid_docstring: if ``parse_docstring`` is provided, configure whether to raise ValueError on invalid Google Style docstrings. Defaults to False. + output_version: Version of ToolMessage to return given + :class:`~langchain_core.messages.content_blocks.ToolCall` input. + + If ``"v0"``, output will be a v0 :class:`~langchain_core.messages.tool.ToolMessage`. + If ``"v1"``, output will be a v1 :class:`~langchain_core.messages.v1.ToolMessage`. + kwargs: Additional arguments to pass to the tool Returns: @@ -175,7 +182,7 @@ class StructuredTool(BaseTool): tool = StructuredTool.from_function(add) tool.run(1, 2) # 3 - """ + """ # noqa: E501 if func is not None: source_function = func elif coroutine is not None: @@ -232,6 +239,7 @@ class StructuredTool(BaseTool): description=description_, return_direct=return_direct, response_format=response_format, + output_version=output_version, **kwargs, ) diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index 57b4573d70d..cd52d37c9fa 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -37,6 +37,7 @@ from langchain_core.callbacks.manager import ( from langchain_core.documents import Document from langchain_core.messages import ToolCall, ToolMessage from langchain_core.messages.tool import ToolOutputMixin +from langchain_core.messages.v1 import ToolMessage as ToolMessageV1 from langchain_core.retrievers import BaseRetriever from langchain_core.runnables import ( Runnable, @@ -70,6 +71,7 @@ from langchain_core.utils.pydantic import ( ) from tests.unit_tests.fake.callbacks import FakeCallbackHandler from tests.unit_tests.pydantic_utils import _schema +from tests.unit_tests.stubs import AnyStr def _get_tool_call_json_schema(tool: BaseTool) -> dict: @@ -1379,17 +1381,28 @@ def test_tool_annotated_descriptions() -> None: } -def test_tool_call_input_tool_message_output() -> None: +@pytest.mark.parametrize("output_version", ["v0", "v1"]) +def test_tool_call_input_tool_message(output_version: Literal["v0", "v1"]) -> None: tool_call = { "name": "structured_api", "args": {"arg1": 1, "arg2": True, "arg3": {"img": "base64string..."}}, "id": "123", "type": "tool_call", } - tool = _MockStructuredTool() - expected = ToolMessage( - "1 True {'img': 'base64string...'}", tool_call_id="123", name="structured_api" - ) + tool = _MockStructuredTool(output_version=output_version) + if output_version == "v0": + expected: Union[ToolMessage, ToolMessageV1] = ToolMessage( + "1 True {'img': 'base64string...'}", + tool_call_id="123", + name="structured_api", + ) + else: + expected = ToolMessageV1( + "1 True {'img': 'base64string...'}", + tool_call_id="123", + name="structured_api", + id=AnyStr("lc_abc123"), + ) actual = tool.invoke(tool_call) assert actual == expected @@ -1421,6 +1434,14 @@ def _mock_structured_tool_with_artifact( return f"{arg1} {arg2}", {"arg1": arg1, "arg2": arg2, "arg3": arg3} +@tool("structured_api", response_format="content_and_artifact", output_version="v1") +def _mock_structured_tool_with_artifact_v1( + *, arg1: int, arg2: bool, arg3: Optional[dict] = None +) -> tuple[str, dict]: + """A Structured Tool.""" + return f"{arg1} {arg2}", {"arg1": arg1, "arg2": arg2, "arg3": arg3} + + @pytest.mark.parametrize( "tool", [_MockStructuredToolWithRawOutput(), _mock_structured_tool_with_artifact] ) @@ -1445,6 +1466,38 @@ def test_tool_call_input_tool_message_with_artifact(tool: BaseTool) -> None: assert actual_content == expected.content +@pytest.mark.parametrize( + "tool", + [ + _MockStructuredToolWithRawOutput(output_version="v1"), + _mock_structured_tool_with_artifact_v1, + ], +) +def test_tool_call_input_tool_message_with_artifact_v1(tool: BaseTool) -> None: + tool_call: dict = { + "name": "structured_api", + "args": {"arg1": 1, "arg2": True, "arg3": {"img": "base64string..."}}, + "id": "123", + "type": "tool_call", + } + expected = ToolMessageV1( + "1 True", + artifact=tool_call["args"], + tool_call_id="123", + name="structured_api", + id=AnyStr("lc_abc123"), + ) + actual = tool.invoke(tool_call) + assert actual == expected + + tool_call.pop("type") + with pytest.raises(ValidationError): + tool.invoke(tool_call) + + actual_content = tool.invoke(tool_call["args"]) + assert actual_content == expected.text + + def test_convert_from_runnable_dict() -> None: # Test with typed dict input class Args(TypedDict): @@ -1550,6 +1603,17 @@ def injected_tool(x: int, y: Annotated[str, InjectedToolArg]) -> str: return y +@tool("foo", parse_docstring=True, output_version="v1") +def injected_tool_v1(x: int, y: Annotated[str, InjectedToolArg]) -> str: + """Foo. + + Args: + x: abc + y: 123 + """ + return y + + class InjectedTool(BaseTool): name: str = "foo" description: str = "foo." @@ -1587,7 +1651,12 @@ def injected_tool_with_schema(x: int, y: str) -> str: return y -@pytest.mark.parametrize("tool_", [InjectedTool()]) +@tool("foo", args_schema=fooSchema, output_version="v1") +def injected_tool_with_schema_v1(x: int, y: str) -> str: + return y + + +@pytest.mark.parametrize("tool_", [InjectedTool(), InjectedTool(output_version="v1")]) def test_tool_injected_arg_without_schema(tool_: BaseTool) -> None: assert _schema(tool_.get_input_schema()) == { "title": "foo", @@ -1607,14 +1676,25 @@ def test_tool_injected_arg_without_schema(tool_: BaseTool) -> None: "required": ["x"], } assert tool_.invoke({"x": 5, "y": "bar"}) == "bar" - assert tool_.invoke( - { - "name": "foo", - "args": {"x": 5, "y": "bar"}, - "id": "123", - "type": "tool_call", - } - ) == ToolMessage("bar", tool_call_id="123", name="foo") + if tool_.output_version == "v0": + expected: Union[ToolMessage, ToolMessageV1] = ToolMessage( + "bar", tool_call_id="123", name="foo" + ) + else: + expected = ToolMessageV1( + "bar", tool_call_id="123", name="foo", id=AnyStr("lc_abc123") + ) + assert ( + tool_.invoke( + { + "name": "foo", + "args": {"x": 5, "y": "bar"}, + "id": "123", + "type": "tool_call", + } + ) + == expected + ) expected_error = ( ValidationError if not isinstance(tool_, InjectedTool) else TypeError ) @@ -1634,7 +1714,12 @@ def test_tool_injected_arg_without_schema(tool_: BaseTool) -> None: @pytest.mark.parametrize( "tool_", - [injected_tool_with_schema, InjectedToolWithSchema()], + [ + injected_tool_with_schema, + InjectedToolWithSchema(), + injected_tool_with_schema_v1, + InjectedToolWithSchema(output_version="v1"), + ], ) def test_tool_injected_arg_with_schema(tool_: BaseTool) -> None: assert _schema(tool_.get_input_schema()) == { @@ -1655,14 +1740,25 @@ def test_tool_injected_arg_with_schema(tool_: BaseTool) -> None: "required": ["x"], } assert tool_.invoke({"x": 5, "y": "bar"}) == "bar" - assert tool_.invoke( - { - "name": "foo", - "args": {"x": 5, "y": "bar"}, - "id": "123", - "type": "tool_call", - } - ) == ToolMessage("bar", tool_call_id="123", name="foo") + if tool_.output_version == "v0": + expected: Union[ToolMessage, ToolMessageV1] = ToolMessage( + "bar", tool_call_id="123", name="foo" + ) + else: + expected = ToolMessageV1( + "bar", tool_call_id="123", name="foo", id=AnyStr("lc_abc123") + ) + assert ( + tool_.invoke( + { + "name": "foo", + "args": {"x": 5, "y": "bar"}, + "id": "123", + "type": "tool_call", + } + ) + == expected + ) expected_error = ( ValidationError if not isinstance(tool_, InjectedTool) else TypeError ) @@ -1680,8 +1776,9 @@ def test_tool_injected_arg_with_schema(tool_: BaseTool) -> None: } -def test_tool_injected_arg() -> None: - tool_ = injected_tool +@pytest.mark.parametrize("output_version", ["v0", "v1"]) +def test_tool_injected_arg(output_version: Literal["v0", "v1"]) -> None: + tool_ = injected_tool if output_version == "v0" else injected_tool_v1 assert _schema(tool_.get_input_schema()) == { "title": "foo", "description": "Foo.", @@ -1700,14 +1797,25 @@ def test_tool_injected_arg() -> None: "required": ["x"], } assert tool_.invoke({"x": 5, "y": "bar"}) == "bar" - assert tool_.invoke( - { - "name": "foo", - "args": {"x": 5, "y": "bar"}, - "id": "123", - "type": "tool_call", - } - ) == ToolMessage("bar", tool_call_id="123", name="foo") + if output_version == "v0": + expected: Union[ToolMessage, ToolMessageV1] = ToolMessage( + "bar", tool_call_id="123", name="foo" + ) + else: + expected = ToolMessageV1( + "bar", tool_call_id="123", name="foo", id=AnyStr("lc_abc123") + ) + assert ( + tool_.invoke( + { + "name": "foo", + "args": {"x": 5, "y": "bar"}, + "id": "123", + "type": "tool_call", + } + ) + == expected + ) expected_error = ( ValidationError if not isinstance(tool_, InjectedTool) else TypeError ) @@ -1725,7 +1833,8 @@ def test_tool_injected_arg() -> None: } -def test_tool_inherited_injected_arg() -> None: +@pytest.mark.parametrize("output_version", ["v0", "v1"]) +def test_tool_inherited_injected_arg(output_version: Literal["v0", "v1"]) -> None: class BarSchema(BaseModel): """bar.""" @@ -1746,7 +1855,7 @@ def test_tool_inherited_injected_arg() -> None: def _run(self, x: int, y: str) -> Any: return y - tool_ = InheritedInjectedArgTool() + tool_ = InheritedInjectedArgTool(output_version=output_version) assert tool_.get_input_schema().model_json_schema() == { "title": "FooSchema", # Matches the title from the provided schema "description": "foo.", @@ -1766,14 +1875,25 @@ def test_tool_inherited_injected_arg() -> None: "required": ["x"], } assert tool_.invoke({"x": 5, "y": "bar"}) == "bar" - assert tool_.invoke( - { - "name": "foo", - "args": {"x": 5, "y": "bar"}, - "id": "123", - "type": "tool_call", - } - ) == ToolMessage("bar", tool_call_id="123", name="foo") + if output_version == "v0": + expected: Union[ToolMessage, ToolMessageV1] = ToolMessage( + "bar", tool_call_id="123", name="foo" + ) + else: + expected = ToolMessageV1( + "bar", tool_call_id="123", name="foo", id=AnyStr("lc_abc123") + ) + assert ( + tool_.invoke( + { + "name": "foo", + "args": {"x": 5, "y": "bar"}, + "id": "123", + "type": "tool_call", + } + ) + == expected + ) expected_error = ( ValidationError if not isinstance(tool_, InjectedTool) else TypeError ) @@ -2133,7 +2253,8 @@ def test_tool_annotations_preserved() -> None: assert schema.__annotations__ == expected_type_hints -def test_create_retriever_tool() -> None: +@pytest.mark.parametrize("output_version", ["v0", "v1"]) +def test_create_retriever_tool(output_version: Literal["v0", "v1"]) -> None: class MyRetriever(BaseRetriever): def _get_relevant_documents( self, query: str, *, run_manager: CallbackManagerForRetrieverRun @@ -2142,21 +2263,36 @@ def test_create_retriever_tool() -> None: retriever = MyRetriever() retriever_tool = tools.create_retriever_tool( - retriever, "retriever_tool_content", "Retriever Tool Content" + retriever, + "retriever_tool_content", + "Retriever Tool Content", + output_version=output_version, ) assert isinstance(retriever_tool, BaseTool) assert retriever_tool.name == "retriever_tool_content" assert retriever_tool.description == "Retriever Tool Content" assert retriever_tool.invoke("bar") == "foo bar\n\nbar" - assert retriever_tool.invoke( - ToolCall( - name="retriever_tool_content", - args={"query": "bar"}, - id="123", - type="tool_call", + if output_version == "v0": + expected: Union[ToolMessage, ToolMessageV1] = ToolMessage( + "foo bar\n\nbar", tool_call_id="123", name="retriever_tool_content" ) - ) == ToolMessage( - "foo bar\n\nbar", tool_call_id="123", name="retriever_tool_content" + else: + expected = ToolMessageV1( + "foo bar\n\nbar", + tool_call_id="123", + name="retriever_tool_content", + id=AnyStr("lc_abc123"), + ) + assert ( + retriever_tool.invoke( + ToolCall( + name="retriever_tool_content", + args={"query": "bar"}, + id="123", + type="tool_call", + ) + ) + == expected ) retriever_tool_artifact = tools.create_retriever_tool( @@ -2164,23 +2300,37 @@ def test_create_retriever_tool() -> None: "retriever_tool_artifact", "Retriever Tool Artifact", response_format="content_and_artifact", + output_version=output_version, ) assert isinstance(retriever_tool_artifact, BaseTool) assert retriever_tool_artifact.name == "retriever_tool_artifact" assert retriever_tool_artifact.description == "Retriever Tool Artifact" assert retriever_tool_artifact.invoke("bar") == "foo bar\n\nbar" - assert retriever_tool_artifact.invoke( - ToolCall( + if output_version == "v0": + expected = ToolMessage( + "foo bar\n\nbar", + artifact=[Document(page_content="foo bar"), Document(page_content="bar")], + tool_call_id="123", name="retriever_tool_artifact", - args={"query": "bar"}, - id="123", - type="tool_call", ) - ) == ToolMessage( - "foo bar\n\nbar", - artifact=[Document(page_content="foo bar"), Document(page_content="bar")], - tool_call_id="123", - name="retriever_tool_artifact", + else: + expected = ToolMessageV1( + "foo bar\n\nbar", + artifact=[Document(page_content="foo bar"), Document(page_content="bar")], + tool_call_id="123", + name="retriever_tool_artifact", + id=AnyStr("lc_abc123"), + ) + assert ( + retriever_tool_artifact.invoke( + ToolCall( + name="retriever_tool_artifact", + args={"query": "bar"}, + id="123", + type="tool_call", + ) + ) + == expected ) @@ -2313,6 +2463,45 @@ def test_tool_injected_tool_call_id() -> None: ) == ToolMessage(0, tool_call_id="bar") # type: ignore[arg-type] +def test_tool_injected_tool_call_id_v1() -> None: + @tool + def foo(x: int, tool_call_id: Annotated[str, InjectedToolCallId]) -> ToolMessageV1: + """Foo.""" + return ToolMessageV1(str(x), tool_call_id=tool_call_id) + + assert foo.invoke( + { + "type": "tool_call", + "args": {"x": 0}, + "name": "foo", + "id": "bar", + } + ) == ToolMessageV1("0", tool_call_id="bar", id=AnyStr("lc_abc123")) + + with pytest.raises( + ValueError, + match="When tool includes an InjectedToolCallId argument, " + "tool must always be invoked with a full model ToolCall", + ): + assert foo.invoke({"x": 0}) + + @tool + def foo2( + x: int, tool_call_id: Annotated[str, InjectedToolCallId()] + ) -> ToolMessageV1: + """Foo.""" + return ToolMessageV1(str(x), tool_call_id=tool_call_id) + + assert foo2.invoke( + { + "type": "tool_call", + "args": {"x": 0}, + "name": "foo", + "id": "bar", + } + ) == ToolMessageV1("0", tool_call_id="bar", id=AnyStr("lc_abc123")) + + def test_tool_uninjected_tool_call_id() -> None: @tool def foo(x: int, tool_call_id: str) -> ToolMessage: @@ -2332,6 +2521,25 @@ def test_tool_uninjected_tool_call_id() -> None: ) == ToolMessage(0, tool_call_id="zap") # type: ignore[arg-type] +def test_tool_uninjected_tool_call_id_v1() -> None: + @tool + def foo(x: int, tool_call_id: str) -> ToolMessageV1: + """Foo.""" + return ToolMessageV1(str(x), tool_call_id=tool_call_id) + + with pytest.raises(ValueError, match="1 validation error for foo"): + foo.invoke({"type": "tool_call", "args": {"x": 0}, "name": "foo", "id": "bar"}) + + assert foo.invoke( + { + "type": "tool_call", + "args": {"x": 0, "tool_call_id": "zap"}, + "name": "foo", + "id": "bar", + } + ) == ToolMessageV1("0", tool_call_id="zap", id=AnyStr("lc_abc123")) + + def test_tool_return_output_mixin() -> None: class Bar(ToolOutputMixin): def __init__(self, x: int) -> None: @@ -2457,6 +2665,19 @@ def test_empty_string_tool_call_id() -> None: ) +def test_empty_string_tool_call_id_v1() -> None: + @tool(output_version="v1") + def foo(x: int) -> str: + """Foo.""" + return "hi" + + assert foo.invoke( + {"type": "tool_call", "args": {"x": 0}, "id": ""} + ) == ToolMessageV1( + content="hi", name="foo", tool_call_id="", id=AnyStr("lc_abc123") + ) + + def test_tool_decorator_description() -> None: # test basic tool @tool