mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-19 21:33:51 +00:00
core[patch]: return ToolMessage from tool (#28605)
This commit is contained in:
parent
d0e95971f5
commit
e24f86e55f
@ -9,7 +9,16 @@ from langchain_core.messages.base import BaseMessage, BaseMessageChunk, merge_co
|
|||||||
from langchain_core.utils._merge import merge_dicts, merge_obj
|
from langchain_core.utils._merge import merge_dicts, merge_obj
|
||||||
|
|
||||||
|
|
||||||
class ToolMessage(BaseMessage):
|
class ToolOutputMixin:
|
||||||
|
"""Mixin for objects that tools can return directly.
|
||||||
|
|
||||||
|
If a custom BaseTool is invoked with a ToolCall and the output of custom code is
|
||||||
|
not an instance of ToolOutputMixin, the output will automatically be coerced to a
|
||||||
|
string and wrapped in a ToolMessage.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class ToolMessage(BaseMessage, ToolOutputMixin):
|
||||||
"""Message for passing the result of executing a tool back to a model.
|
"""Message for passing the result of executing a tool back to a model.
|
||||||
|
|
||||||
ToolMessages contain the result of a tool invocation. Typically, the result
|
ToolMessages contain the result of a tool invocation. Typically, the result
|
||||||
|
@ -45,7 +45,7 @@ from langchain_core.callbacks import (
|
|||||||
CallbackManager,
|
CallbackManager,
|
||||||
Callbacks,
|
Callbacks,
|
||||||
)
|
)
|
||||||
from langchain_core.messages.tool import ToolCall, ToolMessage
|
from langchain_core.messages.tool import ToolCall, ToolMessage, ToolOutputMixin
|
||||||
from langchain_core.runnables import (
|
from langchain_core.runnables import (
|
||||||
RunnableConfig,
|
RunnableConfig,
|
||||||
RunnableSerializable,
|
RunnableSerializable,
|
||||||
@ -494,7 +494,9 @@ class ChildTool(BaseTool):
|
|||||||
|
|
||||||
# --- Tool ---
|
# --- Tool ---
|
||||||
|
|
||||||
def _parse_input(self, tool_input: Union[str, dict]) -> Union[str, dict[str, Any]]:
|
def _parse_input(
|
||||||
|
self, tool_input: Union[str, dict], tool_call_id: Optional[str]
|
||||||
|
) -> Union[str, dict[str, Any]]:
|
||||||
"""Convert tool input to a pydantic model.
|
"""Convert tool input to a pydantic model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -512,9 +514,39 @@ class ChildTool(BaseTool):
|
|||||||
else:
|
else:
|
||||||
if input_args is not None:
|
if input_args is not None:
|
||||||
if issubclass(input_args, BaseModel):
|
if issubclass(input_args, BaseModel):
|
||||||
|
for k, v in get_all_basemodel_annotations(input_args).items():
|
||||||
|
if (
|
||||||
|
_is_injected_arg_type(v, injected_type=InjectedToolCallId)
|
||||||
|
and k not in tool_input
|
||||||
|
):
|
||||||
|
if tool_call_id is None:
|
||||||
|
msg = (
|
||||||
|
"When tool includes an InjectedToolCallId "
|
||||||
|
"argument, tool must always be invoked with a full "
|
||||||
|
"model ToolCall of the form: {'args': {...}, "
|
||||||
|
"'name': '...', 'type': 'tool_call', "
|
||||||
|
"'tool_call_id': '...'}"
|
||||||
|
)
|
||||||
|
raise ValueError(msg)
|
||||||
|
tool_input[k] = tool_call_id
|
||||||
result = input_args.model_validate(tool_input)
|
result = input_args.model_validate(tool_input)
|
||||||
result_dict = result.model_dump()
|
result_dict = result.model_dump()
|
||||||
elif issubclass(input_args, BaseModelV1):
|
elif issubclass(input_args, BaseModelV1):
|
||||||
|
for k, v in get_all_basemodel_annotations(input_args).items():
|
||||||
|
if (
|
||||||
|
_is_injected_arg_type(v, injected_type=InjectedToolCallId)
|
||||||
|
and k not in tool_input
|
||||||
|
):
|
||||||
|
if tool_call_id is None:
|
||||||
|
msg = (
|
||||||
|
"When tool includes an InjectedToolCallId "
|
||||||
|
"argument, tool must always be invoked with a full "
|
||||||
|
"model ToolCall of the form: {'args': {...}, "
|
||||||
|
"'name': '...', 'type': 'tool_call', "
|
||||||
|
"'tool_call_id': '...'}"
|
||||||
|
)
|
||||||
|
raise ValueError(msg)
|
||||||
|
tool_input[k] = tool_call_id
|
||||||
result = input_args.parse_obj(tool_input)
|
result = input_args.parse_obj(tool_input)
|
||||||
result_dict = result.dict()
|
result_dict = result.dict()
|
||||||
else:
|
else:
|
||||||
@ -570,8 +602,10 @@ class ChildTool(BaseTool):
|
|||||||
kwargs["run_manager"] = kwargs["run_manager"].get_sync()
|
kwargs["run_manager"] = kwargs["run_manager"].get_sync()
|
||||||
return await run_in_executor(None, self._run, *args, **kwargs)
|
return await run_in_executor(None, self._run, *args, **kwargs)
|
||||||
|
|
||||||
def _to_args_and_kwargs(self, tool_input: Union[str, dict]) -> tuple[tuple, dict]:
|
def _to_args_and_kwargs(
|
||||||
tool_input = self._parse_input(tool_input)
|
self, tool_input: Union[str, dict], tool_call_id: Optional[str]
|
||||||
|
) -> tuple[tuple, dict]:
|
||||||
|
tool_input = self._parse_input(tool_input, tool_call_id)
|
||||||
# For backwards compatibility, if run_input is a string,
|
# For backwards compatibility, if run_input is a string,
|
||||||
# pass as a positional argument.
|
# pass as a positional argument.
|
||||||
if isinstance(tool_input, str):
|
if isinstance(tool_input, str):
|
||||||
@ -648,10 +682,9 @@ class ChildTool(BaseTool):
|
|||||||
child_config = patch_config(config, callbacks=run_manager.get_child())
|
child_config = patch_config(config, callbacks=run_manager.get_child())
|
||||||
context = copy_context()
|
context = copy_context()
|
||||||
context.run(_set_config_context, child_config)
|
context.run(_set_config_context, child_config)
|
||||||
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
|
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input, tool_call_id)
|
||||||
if signature(self._run).parameters.get("run_manager"):
|
if signature(self._run).parameters.get("run_manager"):
|
||||||
tool_kwargs["run_manager"] = run_manager
|
tool_kwargs["run_manager"] = run_manager
|
||||||
|
|
||||||
if config_param := _get_runnable_config_param(self._run):
|
if config_param := _get_runnable_config_param(self._run):
|
||||||
tool_kwargs[config_param] = config
|
tool_kwargs[config_param] = config
|
||||||
response = context.run(self._run, *tool_args, **tool_kwargs)
|
response = context.run(self._run, *tool_args, **tool_kwargs)
|
||||||
@ -755,7 +788,7 @@ class ChildTool(BaseTool):
|
|||||||
artifact = None
|
artifact = None
|
||||||
error_to_raise: Optional[Union[Exception, KeyboardInterrupt]] = None
|
error_to_raise: Optional[Union[Exception, KeyboardInterrupt]] = None
|
||||||
try:
|
try:
|
||||||
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
|
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input, tool_call_id)
|
||||||
child_config = patch_config(config, callbacks=run_manager.get_child())
|
child_config = patch_config(config, callbacks=run_manager.get_child())
|
||||||
context = copy_context()
|
context = copy_context()
|
||||||
context.run(_set_config_context, child_config)
|
context.run(_set_config_context, child_config)
|
||||||
@ -889,20 +922,23 @@ def _prep_run_args(
|
|||||||
|
|
||||||
|
|
||||||
def _format_output(
|
def _format_output(
|
||||||
content: Any, artifact: Any, tool_call_id: Optional[str], name: str, status: str
|
content: Any,
|
||||||
) -> Union[ToolMessage, Any]:
|
artifact: Any,
|
||||||
if tool_call_id:
|
tool_call_id: Optional[str],
|
||||||
if not _is_message_content_type(content):
|
name: str,
|
||||||
content = _stringify(content)
|
status: str,
|
||||||
return ToolMessage(
|
) -> Union[ToolOutputMixin, Any]:
|
||||||
content,
|
if isinstance(content, ToolOutputMixin) or not tool_call_id:
|
||||||
artifact=artifact,
|
|
||||||
tool_call_id=tool_call_id,
|
|
||||||
name=name,
|
|
||||||
status=status,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return content
|
return content
|
||||||
|
if not _is_message_content_type(content):
|
||||||
|
content = _stringify(content)
|
||||||
|
return ToolMessage(
|
||||||
|
content,
|
||||||
|
artifact=artifact,
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
name=name,
|
||||||
|
status=status,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _is_message_content_type(obj: Any) -> bool:
|
def _is_message_content_type(obj: Any) -> bool:
|
||||||
@ -954,10 +990,31 @@ class InjectedToolArg:
|
|||||||
"""Annotation for a Tool arg that is **not** meant to be generated by a model."""
|
"""Annotation for a Tool arg that is **not** meant to be generated by a model."""
|
||||||
|
|
||||||
|
|
||||||
def _is_injected_arg_type(type_: type) -> bool:
|
class InjectedToolCallId(InjectedToolArg):
|
||||||
|
r'''Annotation for injecting the tool_call_id.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
..code-block:: python
|
||||||
|
|
||||||
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
|
from langchain_core.messages import ToolMessage
|
||||||
|
from langchain_core.tools import tool, InjectedToolCallID
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def foo(x: int, tool_call_id: Annotated[str, InjectedToolCallID]) -> ToolMessage:
|
||||||
|
"""Return x."""
|
||||||
|
return ToolMessage(str(x), artifact=x, name="foo", tool_call_id=tool_call_id)
|
||||||
|
''' # noqa: E501
|
||||||
|
|
||||||
|
|
||||||
|
def _is_injected_arg_type(
|
||||||
|
type_: type, injected_type: Optional[type[InjectedToolArg]] = None
|
||||||
|
) -> bool:
|
||||||
|
injected_type = injected_type or InjectedToolArg
|
||||||
return any(
|
return any(
|
||||||
isinstance(arg, InjectedToolArg)
|
isinstance(arg, injected_type)
|
||||||
or (isinstance(arg, type) and issubclass(arg, InjectedToolArg))
|
or (isinstance(arg, type) and issubclass(arg, injected_type))
|
||||||
for arg in get_args(type_)[1:]
|
for arg in get_args(type_)[1:]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -62,9 +62,11 @@ class Tool(BaseTool):
|
|||||||
# assume it takes a single string input.
|
# assume it takes a single string input.
|
||||||
return {"tool_input": {"type": "string"}}
|
return {"tool_input": {"type": "string"}}
|
||||||
|
|
||||||
def _to_args_and_kwargs(self, tool_input: Union[str, dict]) -> tuple[tuple, dict]:
|
def _to_args_and_kwargs(
|
||||||
|
self, tool_input: Union[str, dict], tool_call_id: Optional[str]
|
||||||
|
) -> tuple[tuple, dict]:
|
||||||
"""Convert tool input to pydantic model."""
|
"""Convert tool input to pydantic model."""
|
||||||
args, kwargs = super()._to_args_and_kwargs(tool_input)
|
args, kwargs = super()._to_args_and_kwargs(tool_input, tool_call_id)
|
||||||
# For backwards compatibility. The tool must be run with a single input
|
# For backwards compatibility. The tool must be run with a single input
|
||||||
all_args = list(args) + list(kwargs.values())
|
all_args = list(args) + list(kwargs.values())
|
||||||
if len(all_args) != 1:
|
if len(all_args) != 1:
|
||||||
|
@ -31,6 +31,7 @@ from langchain_core.callbacks import (
|
|||||||
CallbackManagerForToolRun,
|
CallbackManagerForToolRun,
|
||||||
)
|
)
|
||||||
from langchain_core.messages import ToolMessage
|
from langchain_core.messages import ToolMessage
|
||||||
|
from langchain_core.messages.tool import ToolOutputMixin
|
||||||
from langchain_core.runnables import (
|
from langchain_core.runnables import (
|
||||||
Runnable,
|
Runnable,
|
||||||
RunnableConfig,
|
RunnableConfig,
|
||||||
@ -46,6 +47,7 @@ from langchain_core.tools import (
|
|||||||
)
|
)
|
||||||
from langchain_core.tools.base import (
|
from langchain_core.tools.base import (
|
||||||
InjectedToolArg,
|
InjectedToolArg,
|
||||||
|
InjectedToolCallId,
|
||||||
SchemaAnnotationError,
|
SchemaAnnotationError,
|
||||||
_is_message_content_block,
|
_is_message_content_block,
|
||||||
_is_message_content_type,
|
_is_message_content_type,
|
||||||
@ -856,6 +858,7 @@ def test_validation_error_handling_non_validation_error(
|
|||||||
def _parse_input(
|
def _parse_input(
|
||||||
self,
|
self,
|
||||||
tool_input: Union[str, dict],
|
tool_input: Union[str, dict],
|
||||||
|
tool_call_id: Optional[str],
|
||||||
) -> Union[str, dict[str, Any]]:
|
) -> Union[str, dict[str, Any]]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@ -920,6 +923,7 @@ async def test_async_validation_error_handling_non_validation_error(
|
|||||||
def _parse_input(
|
def _parse_input(
|
||||||
self,
|
self,
|
||||||
tool_input: Union[str, dict],
|
tool_input: Union[str, dict],
|
||||||
|
tool_call_id: Optional[str],
|
||||||
) -> Union[str, dict[str, Any]]:
|
) -> Union[str, dict[str, Any]]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@ -2110,3 +2114,63 @@ def test_injected_arg_with_complex_type() -> None:
|
|||||||
return foo.value
|
return foo.value
|
||||||
|
|
||||||
assert injected_tool.invoke({"x": 5, "foo": Foo()}) == "bar" # type: ignore
|
assert injected_tool.invoke({"x": 5, "foo": Foo()}) == "bar" # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_injected_tool_call_id() -> None:
|
||||||
|
@tool
|
||||||
|
def foo(x: int, tool_call_id: Annotated[str, InjectedToolCallId]) -> ToolMessage:
|
||||||
|
"""foo"""
|
||||||
|
return ToolMessage(x, tool_call_id=tool_call_id) # type: ignore
|
||||||
|
|
||||||
|
assert foo.invoke(
|
||||||
|
{"type": "tool_call", "args": {"x": 0}, "name": "foo", "id": "bar"}
|
||||||
|
) == ToolMessage(0, tool_call_id="bar") # type: ignore
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
assert foo.invoke({"x": 0})
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def foo2(x: int, tool_call_id: Annotated[str, InjectedToolCallId()]) -> ToolMessage:
|
||||||
|
"""foo"""
|
||||||
|
return ToolMessage(x, tool_call_id=tool_call_id) # type: ignore
|
||||||
|
|
||||||
|
assert foo2.invoke(
|
||||||
|
{"type": "tool_call", "args": {"x": 0}, "name": "foo", "id": "bar"}
|
||||||
|
) == ToolMessage(0, tool_call_id="bar") # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_uninjected_tool_call_id() -> None:
|
||||||
|
@tool
|
||||||
|
def foo(x: int, tool_call_id: str) -> ToolMessage:
|
||||||
|
"""foo"""
|
||||||
|
return ToolMessage(x, tool_call_id=tool_call_id) # type: ignore
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
foo.invoke({"type": "tool_call", "args": {"x": 0}, "name": "foo", "id": "bar"})
|
||||||
|
|
||||||
|
assert foo.invoke(
|
||||||
|
{
|
||||||
|
"type": "tool_call",
|
||||||
|
"args": {"x": 0, "tool_call_id": "zap"},
|
||||||
|
"name": "foo",
|
||||||
|
"id": "bar",
|
||||||
|
}
|
||||||
|
) == ToolMessage(0, tool_call_id="zap") # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_return_output_mixin() -> None:
|
||||||
|
class Bar(ToolOutputMixin):
|
||||||
|
def __init__(self, x: int) -> None:
|
||||||
|
self.x = x
|
||||||
|
|
||||||
|
def __eq__(self, other: Any) -> bool:
|
||||||
|
return isinstance(other, self.__class__) and self.x == other.x
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def foo(x: int) -> Bar:
|
||||||
|
"""Foo."""
|
||||||
|
return Bar(x=x)
|
||||||
|
|
||||||
|
assert foo.invoke(
|
||||||
|
{"type": "tool_call", "args": {"x": 0}, "name": "foo", "id": "bar"}
|
||||||
|
) == Bar(x=0)
|
||||||
|
Loading…
Reference in New Issue
Block a user