core[patch]: return ToolMessage from tool (#28605)

This commit is contained in:
Bagatur 2024-12-10 01:59:38 -08:00 committed by GitHub
parent d0e95971f5
commit e24f86e55f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 158 additions and 26 deletions

View File

@ -9,7 +9,16 @@ from langchain_core.messages.base import BaseMessage, BaseMessageChunk, merge_co
from langchain_core.utils._merge import merge_dicts, merge_obj
class ToolMessage(BaseMessage):
class ToolOutputMixin:
"""Mixin for objects that tools can return directly.
If a custom BaseTool is invoked with a ToolCall and the output of custom code is
not an instance of ToolOutputMixin, the output will automatically be coerced to a
string and wrapped in a ToolMessage.
"""
class ToolMessage(BaseMessage, ToolOutputMixin):
"""Message for passing the result of executing a tool back to a model.
ToolMessages contain the result of a tool invocation. Typically, the result

View File

@ -45,7 +45,7 @@ from langchain_core.callbacks import (
CallbackManager,
Callbacks,
)
from langchain_core.messages.tool import ToolCall, ToolMessage
from langchain_core.messages.tool import ToolCall, ToolMessage, ToolOutputMixin
from langchain_core.runnables import (
RunnableConfig,
RunnableSerializable,
@ -494,7 +494,9 @@ class ChildTool(BaseTool):
# --- Tool ---
def _parse_input(self, tool_input: Union[str, dict]) -> Union[str, dict[str, Any]]:
def _parse_input(
self, tool_input: Union[str, dict], tool_call_id: Optional[str]
) -> Union[str, dict[str, Any]]:
"""Convert tool input to a pydantic model.
Args:
@ -512,9 +514,39 @@ class ChildTool(BaseTool):
else:
if input_args is not None:
if issubclass(input_args, BaseModel):
for k, v in get_all_basemodel_annotations(input_args).items():
if (
_is_injected_arg_type(v, injected_type=InjectedToolCallId)
and k not in tool_input
):
if tool_call_id is None:
msg = (
"When tool includes an InjectedToolCallId "
"argument, tool must always be invoked with a full "
"model ToolCall of the form: {'args': {...}, "
"'name': '...', 'type': 'tool_call', "
"'tool_call_id': '...'}"
)
raise ValueError(msg)
tool_input[k] = tool_call_id
result = input_args.model_validate(tool_input)
result_dict = result.model_dump()
elif issubclass(input_args, BaseModelV1):
for k, v in get_all_basemodel_annotations(input_args).items():
if (
_is_injected_arg_type(v, injected_type=InjectedToolCallId)
and k not in tool_input
):
if tool_call_id is None:
msg = (
"When tool includes an InjectedToolCallId "
"argument, tool must always be invoked with a full "
"model ToolCall of the form: {'args': {...}, "
"'name': '...', 'type': 'tool_call', "
"'tool_call_id': '...'}"
)
raise ValueError(msg)
tool_input[k] = tool_call_id
result = input_args.parse_obj(tool_input)
result_dict = result.dict()
else:
@ -570,8 +602,10 @@ class ChildTool(BaseTool):
kwargs["run_manager"] = kwargs["run_manager"].get_sync()
return await run_in_executor(None, self._run, *args, **kwargs)
def _to_args_and_kwargs(self, tool_input: Union[str, dict]) -> tuple[tuple, dict]:
tool_input = self._parse_input(tool_input)
def _to_args_and_kwargs(
self, tool_input: Union[str, dict], tool_call_id: Optional[str]
) -> tuple[tuple, dict]:
tool_input = self._parse_input(tool_input, tool_call_id)
# For backwards compatibility, if run_input is a string,
# pass as a positional argument.
if isinstance(tool_input, str):
@ -648,10 +682,9 @@ class ChildTool(BaseTool):
child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context()
context.run(_set_config_context, child_config)
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input, tool_call_id)
if signature(self._run).parameters.get("run_manager"):
tool_kwargs["run_manager"] = run_manager
if config_param := _get_runnable_config_param(self._run):
tool_kwargs[config_param] = config
response = context.run(self._run, *tool_args, **tool_kwargs)
@ -755,7 +788,7 @@ class ChildTool(BaseTool):
artifact = None
error_to_raise: Optional[Union[Exception, KeyboardInterrupt]] = None
try:
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input, tool_call_id)
child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context()
context.run(_set_config_context, child_config)
@ -889,20 +922,23 @@ def _prep_run_args(
def _format_output(
content: Any, artifact: Any, tool_call_id: Optional[str], name: str, status: str
) -> Union[ToolMessage, Any]:
if tool_call_id:
if not _is_message_content_type(content):
content = _stringify(content)
return ToolMessage(
content,
artifact=artifact,
tool_call_id=tool_call_id,
name=name,
status=status,
)
else:
content: Any,
artifact: Any,
tool_call_id: Optional[str],
name: str,
status: str,
) -> Union[ToolOutputMixin, Any]:
if isinstance(content, ToolOutputMixin) or not tool_call_id:
return content
if not _is_message_content_type(content):
content = _stringify(content)
return ToolMessage(
content,
artifact=artifact,
tool_call_id=tool_call_id,
name=name,
status=status,
)
def _is_message_content_type(obj: Any) -> bool:
@ -954,10 +990,31 @@ class InjectedToolArg:
"""Annotation for a Tool arg that is **not** meant to be generated by a model."""
def _is_injected_arg_type(type_: type) -> bool:
class InjectedToolCallId(InjectedToolArg):
r'''Annotation for injecting the tool_call_id.
Example:
..code-block:: python
from typing_extensions import Annotated
from langchain_core.messages import ToolMessage
from langchain_core.tools import tool, InjectedToolCallID
@tool
def foo(x: int, tool_call_id: Annotated[str, InjectedToolCallID]) -> ToolMessage:
"""Return x."""
return ToolMessage(str(x), artifact=x, name="foo", tool_call_id=tool_call_id)
''' # noqa: E501
def _is_injected_arg_type(
type_: type, injected_type: Optional[type[InjectedToolArg]] = None
) -> bool:
injected_type = injected_type or InjectedToolArg
return any(
isinstance(arg, InjectedToolArg)
or (isinstance(arg, type) and issubclass(arg, InjectedToolArg))
isinstance(arg, injected_type)
or (isinstance(arg, type) and issubclass(arg, injected_type))
for arg in get_args(type_)[1:]
)

View File

@ -62,9 +62,11 @@ class Tool(BaseTool):
# assume it takes a single string input.
return {"tool_input": {"type": "string"}}
def _to_args_and_kwargs(self, tool_input: Union[str, dict]) -> tuple[tuple, dict]:
def _to_args_and_kwargs(
self, tool_input: Union[str, dict], tool_call_id: Optional[str]
) -> tuple[tuple, dict]:
"""Convert tool input to pydantic model."""
args, kwargs = super()._to_args_and_kwargs(tool_input)
args, kwargs = super()._to_args_and_kwargs(tool_input, tool_call_id)
# For backwards compatibility. The tool must be run with a single input
all_args = list(args) + list(kwargs.values())
if len(all_args) != 1:

View File

@ -31,6 +31,7 @@ from langchain_core.callbacks import (
CallbackManagerForToolRun,
)
from langchain_core.messages import ToolMessage
from langchain_core.messages.tool import ToolOutputMixin
from langchain_core.runnables import (
Runnable,
RunnableConfig,
@ -46,6 +47,7 @@ from langchain_core.tools import (
)
from langchain_core.tools.base import (
InjectedToolArg,
InjectedToolCallId,
SchemaAnnotationError,
_is_message_content_block,
_is_message_content_type,
@ -856,6 +858,7 @@ def test_validation_error_handling_non_validation_error(
def _parse_input(
self,
tool_input: Union[str, dict],
tool_call_id: Optional[str],
) -> Union[str, dict[str, Any]]:
raise NotImplementedError
@ -920,6 +923,7 @@ async def test_async_validation_error_handling_non_validation_error(
def _parse_input(
self,
tool_input: Union[str, dict],
tool_call_id: Optional[str],
) -> Union[str, dict[str, Any]]:
raise NotImplementedError
@ -2110,3 +2114,63 @@ def test_injected_arg_with_complex_type() -> None:
return foo.value
assert injected_tool.invoke({"x": 5, "foo": Foo()}) == "bar" # type: ignore
def test_tool_injected_tool_call_id() -> None:
@tool
def foo(x: int, tool_call_id: Annotated[str, InjectedToolCallId]) -> ToolMessage:
"""foo"""
return ToolMessage(x, tool_call_id=tool_call_id) # type: ignore
assert foo.invoke(
{"type": "tool_call", "args": {"x": 0}, "name": "foo", "id": "bar"}
) == ToolMessage(0, tool_call_id="bar") # type: ignore
with pytest.raises(ValueError):
assert foo.invoke({"x": 0})
@tool
def foo2(x: int, tool_call_id: Annotated[str, InjectedToolCallId()]) -> ToolMessage:
"""foo"""
return ToolMessage(x, tool_call_id=tool_call_id) # type: ignore
assert foo2.invoke(
{"type": "tool_call", "args": {"x": 0}, "name": "foo", "id": "bar"}
) == ToolMessage(0, tool_call_id="bar") # type: ignore
def test_tool_uninjected_tool_call_id() -> None:
@tool
def foo(x: int, tool_call_id: str) -> ToolMessage:
"""foo"""
return ToolMessage(x, tool_call_id=tool_call_id) # type: ignore
with pytest.raises(ValueError):
foo.invoke({"type": "tool_call", "args": {"x": 0}, "name": "foo", "id": "bar"})
assert foo.invoke(
{
"type": "tool_call",
"args": {"x": 0, "tool_call_id": "zap"},
"name": "foo",
"id": "bar",
}
) == ToolMessage(0, tool_call_id="zap") # type: ignore
def test_tool_return_output_mixin() -> None:
class Bar(ToolOutputMixin):
def __init__(self, x: int) -> None:
self.x = x
def __eq__(self, other: Any) -> bool:
return isinstance(other, self.__class__) and self.x == other.x
@tool
def foo(x: int) -> Bar:
"""Foo."""
return Bar(x=x)
assert foo.invoke(
{"type": "tool_call", "args": {"x": 0}, "name": "foo", "id": "bar"}
) == Bar(x=0)