core[minor], integrations...[patch]: Support ToolCall as Tool input and ToolMessage as Tool output (#24038)

Changes:
- ToolCall, InvalidToolCall and ToolCallChunk can all accept a "type"
parameter now
- LLM integration packages add "type" to all the above
- Tool supports ToolCall inputs that have "type" specified
- Tool outputs ToolMessage when a ToolCall is passed as input
- Tools can separately specify ToolMessage.content and
ToolMessage.raw_output
- Tools emit events for validation errors (using on_tool_error and
on_tool_end)

Example:
```python
@tool("structured_api", response_format="content_and_raw_output")
def _mock_structured_tool_with_raw_output(
    arg1: int, arg2: bool, arg3: Optional[dict] = None
) -> Tuple[str, dict]:
    """A Structured Tool"""
    return f"{arg1} {arg2}", {"arg1": arg1, "arg2": arg2, "arg3": arg3}


def test_tool_call_input_tool_message_with_raw_output() -> None:
    tool_call: Dict = {
        "name": "structured_api",
        "args": {"arg1": 1, "arg2": True, "arg3": {"img": "base64string..."}},
        "id": "123",
        "type": "tool_call",
    }
    expected = ToolMessage("1 True", raw_output=tool_call["args"], tool_call_id="123")
    tool = _mock_structured_tool_with_raw_output
    actual = tool.invoke(tool_call)
    assert actual == expected

    tool_call.pop("type")
    with pytest.raises(ValidationError):
        tool.invoke(tool_call)

    actual_content = tool.invoke(tool_call["args"])
    assert actual_content == expected.content
```

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
Bagatur
2024-07-11 14:54:02 -07:00
committed by GitHub
parent eeb996034b
commit 5fd1e67808
22 changed files with 647 additions and 327 deletions

View File

@@ -317,6 +317,13 @@
'title': 'Name',
'type': 'string',
}),
'type': dict({
'enum': list([
'invalid_tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
'name',
@@ -419,6 +426,13 @@
'title': 'Name',
'type': 'string',
}),
'type': dict({
'enum': list([
'tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
'name',
@@ -908,6 +922,13 @@
'title': 'Name',
'type': 'string',
}),
'type': dict({
'enum': list([
'invalid_tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
'name',
@@ -1010,6 +1031,13 @@
'title': 'Name',
'type': 'string',
}),
'type': dict({
'enum': list([
'tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
'name',

View File

@@ -674,6 +674,13 @@
'title': 'Name',
'type': 'string',
}),
'type': dict({
'enum': list([
'invalid_tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
'name',
@@ -776,6 +783,13 @@
'title': 'Name',
'type': 'string',
}),
'type': dict({
'enum': list([
'tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
'name',

View File

@@ -5577,6 +5577,13 @@
'title': 'Name',
'type': 'string',
}),
'type': dict({
'enum': list([
'invalid_tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
'name',
@@ -5701,6 +5708,13 @@
'title': 'Name',
'type': 'string',
}),
'type': dict({
'enum': list([
'tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
'name',
@@ -6237,6 +6251,13 @@
'title': 'Name',
'type': 'string',
}),
'type': dict({
'enum': list([
'invalid_tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
'name',
@@ -6361,6 +6382,13 @@
'title': 'Name',
'type': 'string',
}),
'type': dict({
'enum': list([
'tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
'name',
@@ -6834,6 +6862,13 @@
'title': 'Name',
'type': 'string',
}),
'type': dict({
'enum': list([
'invalid_tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
'name',
@@ -6936,6 +6971,13 @@
'title': 'Name',
'type': 'string',
}),
'type': dict({
'enum': list([
'tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
'name',
@@ -7444,6 +7486,13 @@
'title': 'Name',
'type': 'string',
}),
'type': dict({
'enum': list([
'invalid_tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
'name',
@@ -7568,6 +7617,13 @@
'title': 'Name',
'type': 'string',
}),
'type': dict({
'enum': list([
'tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
'name',
@@ -8068,6 +8124,13 @@
'title': 'Name',
'type': 'string',
}),
'type': dict({
'enum': list([
'invalid_tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
'name',
@@ -8203,6 +8266,13 @@
'title': 'Name',
'type': 'string',
}),
'type': dict({
'enum': list([
'tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
'name',
@@ -8683,6 +8753,13 @@
'title': 'Name',
'type': 'string',
}),
'type': dict({
'enum': list([
'invalid_tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
'name',
@@ -8785,6 +8862,13 @@
'title': 'Name',
'type': 'string',
}),
'type': dict({
'enum': list([
'tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
'name',
@@ -9238,6 +9322,13 @@
'title': 'Name',
'type': 'string',
}),
'type': dict({
'enum': list([
'invalid_tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
'name',
@@ -9340,6 +9431,13 @@
'title': 'Name',
'type': 'string',
}),
'type': dict({
'enum': list([
'tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
'name',
@@ -9880,6 +9978,13 @@
'title': 'Name',
'type': 'string',
}),
'type': dict({
'enum': list([
'invalid_tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
'name',
@@ -10004,6 +10109,13 @@
'title': 'Name',
'type': 'string',
}),
'type': dict({
'enum': list([
'tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
'name',

View File

@@ -8,7 +8,7 @@ import textwrap
from datetime import datetime
from enum import Enum
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Type, Union
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union
import pytest
from typing_extensions import Annotated, TypedDict
@@ -17,6 +17,7 @@ from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain_core.messages import ToolMessage
from langchain_core.pydantic_v1 import BaseModel, ValidationError
from langchain_core.runnables import Runnable, RunnableLambda, ensure_config
from langchain_core.tools import (
@@ -1067,6 +1068,65 @@ def test_tool_annotated_descriptions() -> None:
}
def test_tool_call_input_tool_message_output() -> None:
tool_call = {
"name": "structured_api",
"args": {"arg1": 1, "arg2": True, "arg3": {"img": "base64string..."}},
"id": "123",
"type": "tool_call",
}
tool = _MockStructuredTool()
expected = ToolMessage("1 True {'img': 'base64string...'}", tool_call_id="123")
actual = tool.invoke(tool_call)
assert actual == expected
tool_call.pop("type")
with pytest.raises(ValidationError):
tool.invoke(tool_call)
class _MockStructuredToolWithRawOutput(BaseTool):
name: str = "structured_api"
args_schema: Type[BaseModel] = _MockSchema
description: str = "A Structured Tool"
response_format: Literal["content_and_raw_output"] = "content_and_raw_output"
def _run(
self, arg1: int, arg2: bool, arg3: Optional[dict] = None
) -> Tuple[str, dict]:
return f"{arg1} {arg2}", {"arg1": arg1, "arg2": arg2, "arg3": arg3}
@tool("structured_api", response_format="content_and_raw_output")
def _mock_structured_tool_with_raw_output(
arg1: int, arg2: bool, arg3: Optional[dict] = None
) -> Tuple[str, dict]:
"""A Structured Tool"""
return f"{arg1} {arg2}", {"arg1": arg1, "arg2": arg2, "arg3": arg3}
@pytest.mark.parametrize(
"tool", [_MockStructuredToolWithRawOutput(), _mock_structured_tool_with_raw_output]
)
def test_tool_call_input_tool_message_with_raw_output(tool: BaseTool) -> None:
tool_call: Dict = {
"name": "structured_api",
"args": {"arg1": 1, "arg2": True, "arg3": {"img": "base64string..."}},
"id": "123",
"type": "tool_call",
}
expected = ToolMessage("1 True", raw_output=tool_call["args"], tool_call_id="123")
actual = tool.invoke(tool_call)
assert actual == expected
tool_call.pop("type")
with pytest.raises(ValidationError):
tool.invoke(tool_call)
actual_content = tool.invoke(tool_call["args"])
assert actual_content == expected.content
def test_convert_from_runnable_dict() -> None:
# Test with typed dict input
class Args(TypedDict):