core[patch]: return ToolMessage from tool (#28605)

This commit is contained in:
Bagatur
2024-12-10 01:59:38 -08:00
committed by GitHub
parent d0e95971f5
commit e24f86e55f
4 changed files with 158 additions and 26 deletions

View File

@@ -31,6 +31,7 @@ from langchain_core.callbacks import (
CallbackManagerForToolRun,
)
from langchain_core.messages import ToolMessage
from langchain_core.messages.tool import ToolOutputMixin
from langchain_core.runnables import (
Runnable,
RunnableConfig,
@@ -46,6 +47,7 @@ from langchain_core.tools import (
)
from langchain_core.tools.base import (
InjectedToolArg,
InjectedToolCallId,
SchemaAnnotationError,
_is_message_content_block,
_is_message_content_type,
@@ -856,6 +858,7 @@ def test_validation_error_handling_non_validation_error(
def _parse_input(
self,
tool_input: Union[str, dict],
tool_call_id: Optional[str],
) -> Union[str, dict[str, Any]]:
raise NotImplementedError
@@ -920,6 +923,7 @@ async def test_async_validation_error_handling_non_validation_error(
def _parse_input(
self,
tool_input: Union[str, dict],
tool_call_id: Optional[str],
) -> Union[str, dict[str, Any]]:
raise NotImplementedError
@@ -2110,3 +2114,63 @@ def test_injected_arg_with_complex_type() -> None:
return foo.value
assert injected_tool.invoke({"x": 5, "foo": Foo()}) == "bar" # type: ignore
def test_tool_injected_tool_call_id() -> None:
@tool
def foo(x: int, tool_call_id: Annotated[str, InjectedToolCallId]) -> ToolMessage:
"""foo"""
return ToolMessage(x, tool_call_id=tool_call_id) # type: ignore
assert foo.invoke(
{"type": "tool_call", "args": {"x": 0}, "name": "foo", "id": "bar"}
) == ToolMessage(0, tool_call_id="bar") # type: ignore
with pytest.raises(ValueError):
assert foo.invoke({"x": 0})
@tool
def foo2(x: int, tool_call_id: Annotated[str, InjectedToolCallId()]) -> ToolMessage:
"""foo"""
return ToolMessage(x, tool_call_id=tool_call_id) # type: ignore
assert foo2.invoke(
{"type": "tool_call", "args": {"x": 0}, "name": "foo", "id": "bar"}
) == ToolMessage(0, tool_call_id="bar") # type: ignore
def test_tool_uninjected_tool_call_id() -> None:
@tool
def foo(x: int, tool_call_id: str) -> ToolMessage:
"""foo"""
return ToolMessage(x, tool_call_id=tool_call_id) # type: ignore
with pytest.raises(ValueError):
foo.invoke({"type": "tool_call", "args": {"x": 0}, "name": "foo", "id": "bar"})
assert foo.invoke(
{
"type": "tool_call",
"args": {"x": 0, "tool_call_id": "zap"},
"name": "foo",
"id": "bar",
}
) == ToolMessage(0, tool_call_id="zap") # type: ignore
def test_tool_return_output_mixin() -> None:
class Bar(ToolOutputMixin):
def __init__(self, x: int) -> None:
self.x = x
def __eq__(self, other: Any) -> bool:
return isinstance(other, self.__class__) and self.x == other.x
@tool
def foo(x: int) -> Bar:
"""Foo."""
return Bar(x=x)
assert foo.invoke(
{"type": "tool_call", "args": {"x": 0}, "name": "foo", "id": "bar"}
) == Bar(x=0)