mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-05 21:12:48 +00:00
core[patch]: return ToolMessage from tool (#28605)
This commit is contained in:
@@ -31,6 +31,7 @@ from langchain_core.callbacks import (
|
||||
CallbackManagerForToolRun,
|
||||
)
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.messages.tool import ToolOutputMixin
|
||||
from langchain_core.runnables import (
|
||||
Runnable,
|
||||
RunnableConfig,
|
||||
@@ -46,6 +47,7 @@ from langchain_core.tools import (
|
||||
)
|
||||
from langchain_core.tools.base import (
|
||||
InjectedToolArg,
|
||||
InjectedToolCallId,
|
||||
SchemaAnnotationError,
|
||||
_is_message_content_block,
|
||||
_is_message_content_type,
|
||||
@@ -856,6 +858,7 @@ def test_validation_error_handling_non_validation_error(
|
||||
def _parse_input(
|
||||
self,
|
||||
tool_input: Union[str, dict],
|
||||
tool_call_id: Optional[str],
|
||||
) -> Union[str, dict[str, Any]]:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -920,6 +923,7 @@ async def test_async_validation_error_handling_non_validation_error(
|
||||
def _parse_input(
|
||||
self,
|
||||
tool_input: Union[str, dict],
|
||||
tool_call_id: Optional[str],
|
||||
) -> Union[str, dict[str, Any]]:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -2110,3 +2114,63 @@ def test_injected_arg_with_complex_type() -> None:
|
||||
return foo.value
|
||||
|
||||
assert injected_tool.invoke({"x": 5, "foo": Foo()}) == "bar" # type: ignore
|
||||
|
||||
|
||||
def test_tool_injected_tool_call_id() -> None:
|
||||
@tool
|
||||
def foo(x: int, tool_call_id: Annotated[str, InjectedToolCallId]) -> ToolMessage:
|
||||
"""foo"""
|
||||
return ToolMessage(x, tool_call_id=tool_call_id) # type: ignore
|
||||
|
||||
assert foo.invoke(
|
||||
{"type": "tool_call", "args": {"x": 0}, "name": "foo", "id": "bar"}
|
||||
) == ToolMessage(0, tool_call_id="bar") # type: ignore
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
assert foo.invoke({"x": 0})
|
||||
|
||||
@tool
|
||||
def foo2(x: int, tool_call_id: Annotated[str, InjectedToolCallId()]) -> ToolMessage:
|
||||
"""foo"""
|
||||
return ToolMessage(x, tool_call_id=tool_call_id) # type: ignore
|
||||
|
||||
assert foo2.invoke(
|
||||
{"type": "tool_call", "args": {"x": 0}, "name": "foo", "id": "bar"}
|
||||
) == ToolMessage(0, tool_call_id="bar") # type: ignore
|
||||
|
||||
|
||||
def test_tool_uninjected_tool_call_id() -> None:
|
||||
@tool
|
||||
def foo(x: int, tool_call_id: str) -> ToolMessage:
|
||||
"""foo"""
|
||||
return ToolMessage(x, tool_call_id=tool_call_id) # type: ignore
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
foo.invoke({"type": "tool_call", "args": {"x": 0}, "name": "foo", "id": "bar"})
|
||||
|
||||
assert foo.invoke(
|
||||
{
|
||||
"type": "tool_call",
|
||||
"args": {"x": 0, "tool_call_id": "zap"},
|
||||
"name": "foo",
|
||||
"id": "bar",
|
||||
}
|
||||
) == ToolMessage(0, tool_call_id="zap") # type: ignore
|
||||
|
||||
|
||||
def test_tool_return_output_mixin() -> None:
|
||||
class Bar(ToolOutputMixin):
|
||||
def __init__(self, x: int) -> None:
|
||||
self.x = x
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
return isinstance(other, self.__class__) and self.x == other.x
|
||||
|
||||
@tool
|
||||
def foo(x: int) -> Bar:
|
||||
"""Foo."""
|
||||
return Bar(x=x)
|
||||
|
||||
assert foo.invoke(
|
||||
{"type": "tool_call", "args": {"x": 0}, "name": "foo", "id": "bar"}
|
||||
) == Bar(x=0)
|
||||
|
Reference in New Issue
Block a user