mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-19 21:33:51 +00:00
core[patch]: return ToolMessage from tool (#28605)
This commit is contained in:
parent
d0e95971f5
commit
e24f86e55f
@ -9,7 +9,16 @@ from langchain_core.messages.base import BaseMessage, BaseMessageChunk, merge_co
|
||||
from langchain_core.utils._merge import merge_dicts, merge_obj
|
||||
|
||||
|
||||
class ToolMessage(BaseMessage):
|
||||
class ToolOutputMixin:
|
||||
"""Mixin for objects that tools can return directly.
|
||||
|
||||
If a custom BaseTool is invoked with a ToolCall and the output of custom code is
|
||||
not an instance of ToolOutputMixin, the output will automatically be coerced to a
|
||||
string and wrapped in a ToolMessage.
|
||||
"""
|
||||
|
||||
|
||||
class ToolMessage(BaseMessage, ToolOutputMixin):
|
||||
"""Message for passing the result of executing a tool back to a model.
|
||||
|
||||
ToolMessages contain the result of a tool invocation. Typically, the result
|
||||
|
@ -45,7 +45,7 @@ from langchain_core.callbacks import (
|
||||
CallbackManager,
|
||||
Callbacks,
|
||||
)
|
||||
from langchain_core.messages.tool import ToolCall, ToolMessage
|
||||
from langchain_core.messages.tool import ToolCall, ToolMessage, ToolOutputMixin
|
||||
from langchain_core.runnables import (
|
||||
RunnableConfig,
|
||||
RunnableSerializable,
|
||||
@ -494,7 +494,9 @@ class ChildTool(BaseTool):
|
||||
|
||||
# --- Tool ---
|
||||
|
||||
def _parse_input(self, tool_input: Union[str, dict]) -> Union[str, dict[str, Any]]:
|
||||
def _parse_input(
|
||||
self, tool_input: Union[str, dict], tool_call_id: Optional[str]
|
||||
) -> Union[str, dict[str, Any]]:
|
||||
"""Convert tool input to a pydantic model.
|
||||
|
||||
Args:
|
||||
@ -512,9 +514,39 @@ class ChildTool(BaseTool):
|
||||
else:
|
||||
if input_args is not None:
|
||||
if issubclass(input_args, BaseModel):
|
||||
for k, v in get_all_basemodel_annotations(input_args).items():
|
||||
if (
|
||||
_is_injected_arg_type(v, injected_type=InjectedToolCallId)
|
||||
and k not in tool_input
|
||||
):
|
||||
if tool_call_id is None:
|
||||
msg = (
|
||||
"When tool includes an InjectedToolCallId "
|
||||
"argument, tool must always be invoked with a full "
|
||||
"model ToolCall of the form: {'args': {...}, "
|
||||
"'name': '...', 'type': 'tool_call', "
|
||||
"'tool_call_id': '...'}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
tool_input[k] = tool_call_id
|
||||
result = input_args.model_validate(tool_input)
|
||||
result_dict = result.model_dump()
|
||||
elif issubclass(input_args, BaseModelV1):
|
||||
for k, v in get_all_basemodel_annotations(input_args).items():
|
||||
if (
|
||||
_is_injected_arg_type(v, injected_type=InjectedToolCallId)
|
||||
and k not in tool_input
|
||||
):
|
||||
if tool_call_id is None:
|
||||
msg = (
|
||||
"When tool includes an InjectedToolCallId "
|
||||
"argument, tool must always be invoked with a full "
|
||||
"model ToolCall of the form: {'args': {...}, "
|
||||
"'name': '...', 'type': 'tool_call', "
|
||||
"'tool_call_id': '...'}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
tool_input[k] = tool_call_id
|
||||
result = input_args.parse_obj(tool_input)
|
||||
result_dict = result.dict()
|
||||
else:
|
||||
@ -570,8 +602,10 @@ class ChildTool(BaseTool):
|
||||
kwargs["run_manager"] = kwargs["run_manager"].get_sync()
|
||||
return await run_in_executor(None, self._run, *args, **kwargs)
|
||||
|
||||
def _to_args_and_kwargs(self, tool_input: Union[str, dict]) -> tuple[tuple, dict]:
|
||||
tool_input = self._parse_input(tool_input)
|
||||
def _to_args_and_kwargs(
|
||||
self, tool_input: Union[str, dict], tool_call_id: Optional[str]
|
||||
) -> tuple[tuple, dict]:
|
||||
tool_input = self._parse_input(tool_input, tool_call_id)
|
||||
# For backwards compatibility, if run_input is a string,
|
||||
# pass as a positional argument.
|
||||
if isinstance(tool_input, str):
|
||||
@ -648,10 +682,9 @@ class ChildTool(BaseTool):
|
||||
child_config = patch_config(config, callbacks=run_manager.get_child())
|
||||
context = copy_context()
|
||||
context.run(_set_config_context, child_config)
|
||||
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
|
||||
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input, tool_call_id)
|
||||
if signature(self._run).parameters.get("run_manager"):
|
||||
tool_kwargs["run_manager"] = run_manager
|
||||
|
||||
if config_param := _get_runnable_config_param(self._run):
|
||||
tool_kwargs[config_param] = config
|
||||
response = context.run(self._run, *tool_args, **tool_kwargs)
|
||||
@ -755,7 +788,7 @@ class ChildTool(BaseTool):
|
||||
artifact = None
|
||||
error_to_raise: Optional[Union[Exception, KeyboardInterrupt]] = None
|
||||
try:
|
||||
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
|
||||
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input, tool_call_id)
|
||||
child_config = patch_config(config, callbacks=run_manager.get_child())
|
||||
context = copy_context()
|
||||
context.run(_set_config_context, child_config)
|
||||
@ -889,9 +922,14 @@ def _prep_run_args(
|
||||
|
||||
|
||||
def _format_output(
|
||||
content: Any, artifact: Any, tool_call_id: Optional[str], name: str, status: str
|
||||
) -> Union[ToolMessage, Any]:
|
||||
if tool_call_id:
|
||||
content: Any,
|
||||
artifact: Any,
|
||||
tool_call_id: Optional[str],
|
||||
name: str,
|
||||
status: str,
|
||||
) -> Union[ToolOutputMixin, Any]:
|
||||
if isinstance(content, ToolOutputMixin) or not tool_call_id:
|
||||
return content
|
||||
if not _is_message_content_type(content):
|
||||
content = _stringify(content)
|
||||
return ToolMessage(
|
||||
@ -901,8 +939,6 @@ def _format_output(
|
||||
name=name,
|
||||
status=status,
|
||||
)
|
||||
else:
|
||||
return content
|
||||
|
||||
|
||||
def _is_message_content_type(obj: Any) -> bool:
|
||||
@ -954,10 +990,31 @@ class InjectedToolArg:
|
||||
"""Annotation for a Tool arg that is **not** meant to be generated by a model."""
|
||||
|
||||
|
||||
def _is_injected_arg_type(type_: type) -> bool:
|
||||
class InjectedToolCallId(InjectedToolArg):
|
||||
r'''Annotation for injecting the tool_call_id.
|
||||
|
||||
Example:
|
||||
..code-block:: python
|
||||
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.tools import tool, InjectedToolCallID
|
||||
|
||||
@tool
|
||||
def foo(x: int, tool_call_id: Annotated[str, InjectedToolCallID]) -> ToolMessage:
|
||||
"""Return x."""
|
||||
return ToolMessage(str(x), artifact=x, name="foo", tool_call_id=tool_call_id)
|
||||
''' # noqa: E501
|
||||
|
||||
|
||||
def _is_injected_arg_type(
|
||||
type_: type, injected_type: Optional[type[InjectedToolArg]] = None
|
||||
) -> bool:
|
||||
injected_type = injected_type or InjectedToolArg
|
||||
return any(
|
||||
isinstance(arg, InjectedToolArg)
|
||||
or (isinstance(arg, type) and issubclass(arg, InjectedToolArg))
|
||||
isinstance(arg, injected_type)
|
||||
or (isinstance(arg, type) and issubclass(arg, injected_type))
|
||||
for arg in get_args(type_)[1:]
|
||||
)
|
||||
|
||||
|
@ -62,9 +62,11 @@ class Tool(BaseTool):
|
||||
# assume it takes a single string input.
|
||||
return {"tool_input": {"type": "string"}}
|
||||
|
||||
def _to_args_and_kwargs(self, tool_input: Union[str, dict]) -> tuple[tuple, dict]:
|
||||
def _to_args_and_kwargs(
|
||||
self, tool_input: Union[str, dict], tool_call_id: Optional[str]
|
||||
) -> tuple[tuple, dict]:
|
||||
"""Convert tool input to pydantic model."""
|
||||
args, kwargs = super()._to_args_and_kwargs(tool_input)
|
||||
args, kwargs = super()._to_args_and_kwargs(tool_input, tool_call_id)
|
||||
# For backwards compatibility. The tool must be run with a single input
|
||||
all_args = list(args) + list(kwargs.values())
|
||||
if len(all_args) != 1:
|
||||
|
@ -31,6 +31,7 @@ from langchain_core.callbacks import (
|
||||
CallbackManagerForToolRun,
|
||||
)
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.messages.tool import ToolOutputMixin
|
||||
from langchain_core.runnables import (
|
||||
Runnable,
|
||||
RunnableConfig,
|
||||
@ -46,6 +47,7 @@ from langchain_core.tools import (
|
||||
)
|
||||
from langchain_core.tools.base import (
|
||||
InjectedToolArg,
|
||||
InjectedToolCallId,
|
||||
SchemaAnnotationError,
|
||||
_is_message_content_block,
|
||||
_is_message_content_type,
|
||||
@ -856,6 +858,7 @@ def test_validation_error_handling_non_validation_error(
|
||||
def _parse_input(
|
||||
self,
|
||||
tool_input: Union[str, dict],
|
||||
tool_call_id: Optional[str],
|
||||
) -> Union[str, dict[str, Any]]:
|
||||
raise NotImplementedError
|
||||
|
||||
@ -920,6 +923,7 @@ async def test_async_validation_error_handling_non_validation_error(
|
||||
def _parse_input(
|
||||
self,
|
||||
tool_input: Union[str, dict],
|
||||
tool_call_id: Optional[str],
|
||||
) -> Union[str, dict[str, Any]]:
|
||||
raise NotImplementedError
|
||||
|
||||
@ -2110,3 +2114,63 @@ def test_injected_arg_with_complex_type() -> None:
|
||||
return foo.value
|
||||
|
||||
assert injected_tool.invoke({"x": 5, "foo": Foo()}) == "bar" # type: ignore
|
||||
|
||||
|
||||
def test_tool_injected_tool_call_id() -> None:
|
||||
@tool
|
||||
def foo(x: int, tool_call_id: Annotated[str, InjectedToolCallId]) -> ToolMessage:
|
||||
"""foo"""
|
||||
return ToolMessage(x, tool_call_id=tool_call_id) # type: ignore
|
||||
|
||||
assert foo.invoke(
|
||||
{"type": "tool_call", "args": {"x": 0}, "name": "foo", "id": "bar"}
|
||||
) == ToolMessage(0, tool_call_id="bar") # type: ignore
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
assert foo.invoke({"x": 0})
|
||||
|
||||
@tool
|
||||
def foo2(x: int, tool_call_id: Annotated[str, InjectedToolCallId()]) -> ToolMessage:
|
||||
"""foo"""
|
||||
return ToolMessage(x, tool_call_id=tool_call_id) # type: ignore
|
||||
|
||||
assert foo2.invoke(
|
||||
{"type": "tool_call", "args": {"x": 0}, "name": "foo", "id": "bar"}
|
||||
) == ToolMessage(0, tool_call_id="bar") # type: ignore
|
||||
|
||||
|
||||
def test_tool_uninjected_tool_call_id() -> None:
|
||||
@tool
|
||||
def foo(x: int, tool_call_id: str) -> ToolMessage:
|
||||
"""foo"""
|
||||
return ToolMessage(x, tool_call_id=tool_call_id) # type: ignore
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
foo.invoke({"type": "tool_call", "args": {"x": 0}, "name": "foo", "id": "bar"})
|
||||
|
||||
assert foo.invoke(
|
||||
{
|
||||
"type": "tool_call",
|
||||
"args": {"x": 0, "tool_call_id": "zap"},
|
||||
"name": "foo",
|
||||
"id": "bar",
|
||||
}
|
||||
) == ToolMessage(0, tool_call_id="zap") # type: ignore
|
||||
|
||||
|
||||
def test_tool_return_output_mixin() -> None:
|
||||
class Bar(ToolOutputMixin):
|
||||
def __init__(self, x: int) -> None:
|
||||
self.x = x
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
return isinstance(other, self.__class__) and self.x == other.x
|
||||
|
||||
@tool
|
||||
def foo(x: int) -> Bar:
|
||||
"""Foo."""
|
||||
return Bar(x=x)
|
||||
|
||||
assert foo.invoke(
|
||||
{"type": "tool_call", "args": {"x": 0}, "name": "foo", "id": "bar"}
|
||||
) == Bar(x=0)
|
||||
|
Loading…
Reference in New Issue
Block a user