core[minor], integrations...[patch]: Support ToolCall as Tool input and ToolMessage as Tool output (#24038)

Changes:
- ToolCall, InvalidToolCall and ToolCallChunk can all accept a "type"
parameter now
- LLM integration packages add "type" to all the above
- Tool supports ToolCall inputs that have "type" specified
- Tool outputs ToolMessage when a ToolCall is passed as input
- Tools can separately specify ToolMessage.content and
ToolMessage.raw_output
- Tools emit events for validation errors (using on_tool_error and
on_tool_end)

Example:
```python
@tool("structured_api", response_format="content_and_raw_output")
def _mock_structured_tool_with_raw_output(
    arg1: int, arg2: bool, arg3: Optional[dict] = None
) -> Tuple[str, dict]:
    """A Structured Tool"""
    return f"{arg1} {arg2}", {"arg1": arg1, "arg2": arg2, "arg3": arg3}


def test_tool_call_input_tool_message_with_raw_output() -> None:
    tool_call: Dict = {
        "name": "structured_api",
        "args": {"arg1": 1, "arg2": True, "arg3": {"img": "base64string..."}},
        "id": "123",
        "type": "tool_call",
    }
    expected = ToolMessage("1 True", raw_output=tool_call["args"], tool_call_id="123")
    tool = _mock_structured_tool_with_raw_output
    actual = tool.invoke(tool_call)
    assert actual == expected

    tool_call.pop("type")
    with pytest.raises(ValidationError):
        tool.invoke(tool_call)

    actual_content = tool.invoke(tool_call["args"])
    assert actual_content == expected.content
```

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
Bagatur 2024-07-11 14:54:02 -07:00 committed by GitHub
parent eeb996034b
commit 5fd1e67808
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 647 additions and 327 deletions

View File

@ -39,7 +39,7 @@ def get_non_abstract_subclasses(cls: Type[BaseTool]) -> List[Type[BaseTool]]:
def test_all_subclasses_accept_run_manager(cls: Type[BaseTool]) -> None: def test_all_subclasses_accept_run_manager(cls: Type[BaseTool]) -> None:
"""Test that tools defined in this repo accept a run manager argument.""" """Test that tools defined in this repo accept a run manager argument."""
# This wouldn't be necessary if the BaseTool had a strict API. # This wouldn't be necessary if the BaseTool had a strict API.
if cls._run is not BaseTool._arun: if cls._run is not BaseTool._run:
run_func = cls._run run_func = cls._run
params = inspect.signature(run_func).parameters params = inspect.signature(run_func).parameters
assert "run_manager" in params assert "run_manager" in params

View File

@ -1,7 +1,7 @@
import json import json
from typing import Any, Dict, List, Literal, Optional, Tuple, Union from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from typing_extensions import TypedDict from typing_extensions import NotRequired, TypedDict
from langchain_core.messages.base import BaseMessage, BaseMessageChunk, merge_content from langchain_core.messages.base import BaseMessage, BaseMessageChunk, merge_content
from langchain_core.utils._merge import merge_dicts, merge_obj from langchain_core.utils._merge import merge_dicts, merge_obj
@ -146,6 +146,11 @@ class ToolCall(TypedDict):
An identifier is needed to associate a tool call request with a tool An identifier is needed to associate a tool call request with a tool
call result in events when multiple concurrent tool calls are made. call result in events when multiple concurrent tool calls are made.
""" """
type: NotRequired[Literal["tool_call"]]
def tool_call(*, name: str, args: Dict[str, Any], id: Optional[str]) -> ToolCall:
return ToolCall(name=name, args=args, id=id, type="tool_call")
class ToolCallChunk(TypedDict): class ToolCallChunk(TypedDict):
@ -176,6 +181,19 @@ class ToolCallChunk(TypedDict):
"""An identifier associated with the tool call.""" """An identifier associated with the tool call."""
index: Optional[int] index: Optional[int]
"""The index of the tool call in a sequence.""" """The index of the tool call in a sequence."""
type: NotRequired[Literal["tool_call_chunk"]]
def tool_call_chunk(
*,
name: Optional[str] = None,
args: Optional[str] = None,
id: Optional[str] = None,
index: Optional[int] = None,
) -> ToolCallChunk:
return ToolCallChunk(
name=name, args=args, id=id, index=index, type="tool_call_chunk"
)
class InvalidToolCall(TypedDict): class InvalidToolCall(TypedDict):
@ -193,6 +211,19 @@ class InvalidToolCall(TypedDict):
"""An identifier associated with the tool call.""" """An identifier associated with the tool call."""
error: Optional[str] error: Optional[str]
"""An error message associated with the tool call.""" """An error message associated with the tool call."""
type: NotRequired[Literal["invalid_tool_call"]]
def invalid_tool_call(
*,
name: Optional[str] = None,
args: Optional[str] = None,
id: Optional[str] = None,
error: Optional[str] = None,
) -> InvalidToolCall:
return InvalidToolCall(
name=name, args=args, id=id, error=error, type="invalid_tool_call"
)
def default_tool_parser( def default_tool_parser(

View File

@ -5,6 +5,12 @@ from typing import Any, Dict, List, Optional, Type
from langchain_core.exceptions import OutputParserException from langchain_core.exceptions import OutputParserException
from langchain_core.messages import AIMessage, InvalidToolCall from langchain_core.messages import AIMessage, InvalidToolCall
from langchain_core.messages.tool import (
invalid_tool_call,
)
from langchain_core.messages.tool import (
tool_call as create_tool_call,
)
from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser
from langchain_core.outputs import ChatGeneration, Generation from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.pydantic_v1 import BaseModel, ValidationError from langchain_core.pydantic_v1 import BaseModel, ValidationError
@ -59,6 +65,7 @@ def parse_tool_call(
} }
if return_id: if return_id:
parsed["id"] = raw_tool_call.get("id") parsed["id"] = raw_tool_call.get("id")
parsed = create_tool_call(**parsed) # type: ignore
return parsed return parsed
@ -75,7 +82,7 @@ def make_invalid_tool_call(
Returns: Returns:
An InvalidToolCall instance with the error message. An InvalidToolCall instance with the error message.
""" """
return InvalidToolCall( return invalid_tool_call(
name=raw_tool_call["function"]["name"], name=raw_tool_call["function"]["name"],
args=raw_tool_call["function"]["arguments"], args=raw_tool_call["function"]["arguments"],
id=raw_tool_call.get("id"), id=raw_tool_call.get("id"),

View File

@ -21,6 +21,7 @@ from __future__ import annotations
import asyncio import asyncio
import inspect import inspect
import json
import textwrap import textwrap
import uuid import uuid
import warnings import warnings
@ -34,6 +35,7 @@ from typing import (
Callable, Callable,
Dict, Dict,
List, List,
Literal,
Optional, Optional,
Sequence, Sequence,
Tuple, Tuple,
@ -42,7 +44,7 @@ from typing import (
get_type_hints, get_type_hints,
) )
from typing_extensions import Annotated, get_args, get_origin from typing_extensions import Annotated, cast, get_args, get_origin
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.callbacks import ( from langchain_core.callbacks import (
@ -56,6 +58,7 @@ from langchain_core.callbacks.manager import (
Callbacks, Callbacks,
) )
from langchain_core.load.serializable import Serializable from langchain_core.load.serializable import Serializable
from langchain_core.messages.tool import ToolCall, ToolMessage
from langchain_core.prompts import ( from langchain_core.prompts import (
BasePromptTemplate, BasePromptTemplate,
PromptTemplate, PromptTemplate,
@ -306,7 +309,7 @@ class ToolException(Exception):
pass pass
class BaseTool(RunnableSerializable[Union[str, Dict], Any]): class BaseTool(RunnableSerializable[Union[str, Dict, ToolCall], Any]):
"""Interface LangChain tools must implement.""" """Interface LangChain tools must implement."""
def __init_subclass__(cls, **kwargs: Any) -> None: def __init_subclass__(cls, **kwargs: Any) -> None:
@ -378,6 +381,14 @@ class ChildTool(BaseTool):
] = False ] = False
"""Handle the content of the ValidationError thrown.""" """Handle the content of the ValidationError thrown."""
response_format: Literal["content", "content_and_raw_output"] = "content"
"""The tool response format.
If "content" then the output of the tool is interpreted as the contents of a
ToolMessage. If "content_and_raw_output" then the output is expected to be a
two-tuple corresponding to the (content, raw_output) of a ToolMessage.
"""
class Config(Serializable.Config): class Config(Serializable.Config):
"""Configuration for this pydantic object.""" """Configuration for this pydantic object."""
@ -410,46 +421,25 @@ class ChildTool(BaseTool):
def invoke( def invoke(
self, self,
input: Union[str, Dict], input: Union[str, Dict, ToolCall],
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
config = ensure_config(config) tool_input, kwargs = _prep_run_args(input, config, **kwargs)
return self.run( return self.run(tool_input, **kwargs)
input,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
run_name=config.get("run_name"),
run_id=config.pop("run_id", None),
config=config,
**kwargs,
)
async def ainvoke( async def ainvoke(
self, self,
input: Union[str, Dict], input: Union[str, Dict, ToolCall],
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
config = ensure_config(config) tool_input, kwargs = _prep_run_args(input, config, **kwargs)
return await self.arun( return await self.arun(tool_input, **kwargs)
input,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
run_name=config.get("run_name"),
run_id=config.pop("run_id", None),
config=config,
**kwargs,
)
# --- Tool --- # --- Tool ---
def _parse_input( def _parse_input(self, tool_input: Union[str, Dict]) -> Union[str, Dict[str, Any]]:
self,
tool_input: Union[str, Dict],
) -> Union[str, Dict[str, Any]]:
"""Convert tool input to pydantic model.""" """Convert tool input to pydantic model."""
input_args = self.args_schema input_args = self.args_schema
if isinstance(tool_input, str): if isinstance(tool_input, str):
@ -465,7 +455,7 @@ class ChildTool(BaseTool):
for k, v in result.dict().items() for k, v in result.dict().items()
if k in tool_input if k in tool_input
} }
return tool_input return tool_input
@root_validator(pre=True) @root_validator(pre=True)
def raise_deprecation(cls, values: Dict) -> Dict: def raise_deprecation(cls, values: Dict) -> Dict:
@ -479,30 +469,27 @@ class ChildTool(BaseTool):
return values return values
@abstractmethod @abstractmethod
def _run( def _run(self, *args: Any, **kwargs: Any) -> Any:
self,
*args: Any,
**kwargs: Any,
) -> Any:
"""Use the tool. """Use the tool.
Add run_manager: Optional[CallbackManagerForToolRun] = None Add run_manager: Optional[CallbackManagerForToolRun] = None
to child implementations to enable tracing, to child implementations to enable tracing.
""" """
async def _arun( async def _arun(self, *args: Any, **kwargs: Any) -> Any:
self,
*args: Any,
**kwargs: Any,
) -> Any:
"""Use the tool asynchronously. """Use the tool asynchronously.
Add run_manager: Optional[AsyncCallbackManagerForToolRun] = None Add run_manager: Optional[AsyncCallbackManagerForToolRun] = None
to child implementations to enable tracing, to child implementations to enable tracing.
""" """
if kwargs.get("run_manager") and signature(self._run).parameters.get(
"run_manager"
):
kwargs["run_manager"] = kwargs["run_manager"].get_sync()
return await run_in_executor(None, self._run, *args, **kwargs) return await run_in_executor(None, self._run, *args, **kwargs)
def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]: def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
tool_input = self._parse_input(tool_input)
# For backwards compatibility, if run_input is a string, # For backwards compatibility, if run_input is a string,
# pass as a positional argument. # pass as a positional argument.
if isinstance(tool_input, str): if isinstance(tool_input, str):
@ -523,24 +510,20 @@ class ChildTool(BaseTool):
run_name: Optional[str] = None, run_name: Optional[str] = None,
run_id: Optional[uuid.UUID] = None, run_id: Optional[uuid.UUID] = None,
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
tool_call_id: Optional[str] = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
"""Run the tool.""" """Run the tool."""
if not self.verbose and verbose is not None:
verbose_ = verbose
else:
verbose_ = self.verbose
callback_manager = CallbackManager.configure( callback_manager = CallbackManager.configure(
callbacks, callbacks,
self.callbacks, self.callbacks,
verbose_, self.verbose or bool(verbose),
tags, tags,
self.tags, self.tags,
metadata, metadata,
self.metadata, self.metadata,
) )
# TODO: maybe also pass through run_manager is _run supports kwargs
new_arg_supported = signature(self._run).parameters.get("run_manager")
run_manager = callback_manager.on_tool_start( run_manager = callback_manager.on_tool_start(
{"name": self.name, "description": self.description}, {"name": self.name, "description": self.description},
tool_input if isinstance(tool_input, str) else str(tool_input), tool_input if isinstance(tool_input, str) else str(tool_input),
@ -550,67 +533,52 @@ class ChildTool(BaseTool):
# Inputs by definition should always be dicts. # Inputs by definition should always be dicts.
# For now, it's unclear whether this assumption is ever violated, # For now, it's unclear whether this assumption is ever violated,
# but if it is we will send a `None` value to the callback instead # but if it is we will send a `None` value to the callback instead
# And will need to address issue via a patch. # TODO: will need to address issue via a patch.
inputs=None if isinstance(tool_input, str) else tool_input, inputs=tool_input if isinstance(tool_input, dict) else None,
**kwargs, **kwargs,
) )
content = None
raw_output = None
error_to_raise: Union[Exception, KeyboardInterrupt, None] = None
try: try:
child_config = patch_config( child_config = patch_config(config, callbacks=run_manager.get_child())
config,
callbacks=run_manager.get_child(),
)
context = copy_context() context = copy_context()
context.run(_set_config_context, child_config) context.run(_set_config_context, child_config)
parsed_input = self._parse_input(tool_input) tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input) if signature(self._run).parameters.get("run_manager"):
observation = ( tool_kwargs["run_manager"] = run_manager
context.run( response = context.run(self._run, *tool_args, **tool_kwargs)
self._run, *tool_args, run_manager=run_manager, **tool_kwargs if self.response_format == "content_and_raw_output":
) if not isinstance(response, tuple) or len(response) != 2:
if new_arg_supported raise ValueError(
else context.run(self._run, *tool_args, **tool_kwargs) "Since response_format='content_and_raw_output' "
) "a two-tuple of the message content and raw tool output is "
f"expected. Instead generated response of type: "
f"{type(response)}."
)
content, raw_output = response
else:
content = response
except ValidationError as e: except ValidationError as e:
if not self.handle_validation_error: if not self.handle_validation_error:
raise e error_to_raise = e
elif isinstance(self.handle_validation_error, bool):
observation = "Tool input validation error"
elif isinstance(self.handle_validation_error, str):
observation = self.handle_validation_error
elif callable(self.handle_validation_error):
observation = self.handle_validation_error(e)
else: else:
raise ValueError( content = _handle_validation_error(e, flag=self.handle_validation_error)
f"Got unexpected type of `handle_validation_error`. Expected bool, "
f"str or callable. Received: {self.handle_validation_error}"
)
return observation
except ToolException as e: except ToolException as e:
if not self.handle_tool_error: if not self.handle_tool_error:
run_manager.on_tool_error(e) error_to_raise = e
raise e
elif isinstance(self.handle_tool_error, bool):
if e.args:
observation = e.args[0]
else:
observation = "Tool execution error"
elif isinstance(self.handle_tool_error, str):
observation = self.handle_tool_error
elif callable(self.handle_tool_error):
observation = self.handle_tool_error(e)
else: else:
raise ValueError( content = _handle_tool_error(e, flag=self.handle_tool_error)
f"Got unexpected type of `handle_tool_error`. Expected bool, str "
f"or callable. Received: {self.handle_tool_error}"
)
run_manager.on_tool_end(observation, color="red", name=self.name, **kwargs)
return observation
except (Exception, KeyboardInterrupt) as e: except (Exception, KeyboardInterrupt) as e:
run_manager.on_tool_error(e) error_to_raise = e
raise e
else: if error_to_raise:
run_manager.on_tool_end(observation, color=color, name=self.name, **kwargs) run_manager.on_tool_error(error_to_raise)
return observation raise error_to_raise
output = _format_output(content, raw_output, tool_call_id)
run_manager.on_tool_end(output, color=color, name=self.name, **kwargs)
return output
async def arun( async def arun(
self, self,
@ -625,99 +593,80 @@ class ChildTool(BaseTool):
run_name: Optional[str] = None, run_name: Optional[str] = None,
run_id: Optional[uuid.UUID] = None, run_id: Optional[uuid.UUID] = None,
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
tool_call_id: Optional[str] = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
"""Run the tool asynchronously.""" """Run the tool asynchronously."""
if not self.verbose and verbose is not None:
verbose_ = verbose
else:
verbose_ = self.verbose
callback_manager = AsyncCallbackManager.configure( callback_manager = AsyncCallbackManager.configure(
callbacks, callbacks,
self.callbacks, self.callbacks,
verbose_, self.verbose or bool(verbose),
tags, tags,
self.tags, self.tags,
metadata, metadata,
self.metadata, self.metadata,
) )
new_arg_supported = signature(self._arun).parameters.get("run_manager")
run_manager = await callback_manager.on_tool_start( run_manager = await callback_manager.on_tool_start(
{"name": self.name, "description": self.description}, {"name": self.name, "description": self.description},
tool_input if isinstance(tool_input, str) else str(tool_input), tool_input if isinstance(tool_input, str) else str(tool_input),
color=start_color, color=start_color,
name=run_name, name=run_name,
inputs=tool_input,
run_id=run_id, run_id=run_id,
# Inputs by definition should always be dicts.
# For now, it's unclear whether this assumption is ever violated,
# but if it is we will send a `None` value to the callback instead
# TODO: will need to address issue via a patch.
inputs=tool_input if isinstance(tool_input, dict) else None,
**kwargs, **kwargs,
) )
content = None
raw_output = None
error_to_raise: Optional[Union[Exception, KeyboardInterrupt]] = None
try: try:
parsed_input = self._parse_input(tool_input) tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
# We then call the tool on the tool input to get an observation child_config = patch_config(config, callbacks=run_manager.get_child())
tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input)
child_config = patch_config(
config,
callbacks=run_manager.get_child(),
)
context = copy_context() context = copy_context()
context.run(_set_config_context, child_config) context.run(_set_config_context, child_config)
coro = ( if self.__class__._arun is BaseTool._arun or signature(
context.run( self._arun
self._arun, *tool_args, run_manager=run_manager, **tool_kwargs ).parameters.get("run_manager"):
) tool_kwargs["run_manager"] = run_manager
if new_arg_supported coro = context.run(self._arun, *tool_args, **tool_kwargs)
else context.run(self._arun, *tool_args, **tool_kwargs)
)
if accepts_context(asyncio.create_task): if accepts_context(asyncio.create_task):
observation = await asyncio.create_task(coro, context=context) # type: ignore response = await asyncio.create_task(coro, context=context) # type: ignore
else: else:
observation = await coro response = await coro
if self.response_format == "content_and_raw_output":
if not isinstance(response, tuple) or len(response) != 2:
raise ValueError(
"Since response_format='content_and_raw_output' "
"a two-tuple of the message content and raw tool output is "
f"expected. Instead generated response of type: "
f"{type(response)}."
)
content, raw_output = response
else:
content = response
except ValidationError as e: except ValidationError as e:
if not self.handle_validation_error: if not self.handle_validation_error:
raise e error_to_raise = e
elif isinstance(self.handle_validation_error, bool):
observation = "Tool input validation error"
elif isinstance(self.handle_validation_error, str):
observation = self.handle_validation_error
elif callable(self.handle_validation_error):
observation = self.handle_validation_error(e)
else: else:
raise ValueError( content = _handle_validation_error(e, flag=self.handle_validation_error)
f"Got unexpected type of `handle_validation_error`. Expected bool, "
f"str or callable. Received: {self.handle_validation_error}"
)
return observation
except ToolException as e: except ToolException as e:
if not self.handle_tool_error: if not self.handle_tool_error:
await run_manager.on_tool_error(e) error_to_raise = e
raise e
elif isinstance(self.handle_tool_error, bool):
if e.args:
observation = e.args[0]
else:
observation = "Tool execution error"
elif isinstance(self.handle_tool_error, str):
observation = self.handle_tool_error
elif callable(self.handle_tool_error):
observation = self.handle_tool_error(e)
else: else:
raise ValueError( content = _handle_tool_error(e, flag=self.handle_tool_error)
f"Got unexpected type of `handle_tool_error`. Expected bool, str "
f"or callable. Received: {self.handle_tool_error}"
)
await run_manager.on_tool_end(
observation, color="red", name=self.name, **kwargs
)
return observation
except (Exception, KeyboardInterrupt) as e: except (Exception, KeyboardInterrupt) as e:
await run_manager.on_tool_error(e) error_to_raise = e
raise e
else: if error_to_raise:
await run_manager.on_tool_end( await run_manager.on_tool_error(error_to_raise)
observation, color=color, name=self.name, **kwargs raise error_to_raise
)
return observation output = _format_output(content, raw_output, tool_call_id)
await run_manager.on_tool_end(output, color=color, name=self.name, **kwargs)
return output
@deprecated("0.1.47", alternative="invoke", removal="0.3.0") @deprecated("0.1.47", alternative="invoke", removal="0.3.0")
def __call__(self, tool_input: str, callbacks: Callbacks = None) -> str: def __call__(self, tool_input: str, callbacks: Callbacks = None) -> str:
@ -738,7 +687,7 @@ class Tool(BaseTool):
async def ainvoke( async def ainvoke(
self, self,
input: Union[str, Dict], input: Union[str, Dict, ToolCall],
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
@ -780,17 +729,10 @@ class Tool(BaseTool):
) -> Any: ) -> Any:
"""Use the tool.""" """Use the tool."""
if self.func: if self.func:
new_argument_supported = signature(self.func).parameters.get("callbacks") if run_manager and signature(self.func).parameters.get("callbacks"):
return ( kwargs["callbacks"] = run_manager.get_child()
self.func( return self.func(*args, **kwargs)
*args, raise NotImplementedError("Tool does not support sync invocation.")
callbacks=run_manager.get_child() if run_manager else None,
**kwargs,
)
if new_argument_supported
else self.func(*args, **kwargs)
)
raise NotImplementedError("Tool does not support sync")
async def _arun( async def _arun(
self, self,
@ -800,26 +742,13 @@ class Tool(BaseTool):
) -> Any: ) -> Any:
"""Use the tool asynchronously.""" """Use the tool asynchronously."""
if self.coroutine: if self.coroutine:
new_argument_supported = signature(self.coroutine).parameters.get( if run_manager and signature(self.coroutine).parameters.get("callbacks"):
"callbacks" kwargs["callbacks"] = run_manager.get_child()
) return await self.coroutine(*args, **kwargs)
return (
await self.coroutine( # NOTE: this code is unreachable since _arun is only called if coroutine is not
*args, # None.
callbacks=run_manager.get_child() if run_manager else None, return await super()._arun(*args, run_manager=run_manager, **kwargs)
**kwargs,
)
if new_argument_supported
else await self.coroutine(*args, **kwargs)
)
else:
return await run_in_executor(
None,
self._run,
run_manager=run_manager.get_sync() if run_manager else None,
*args,
**kwargs,
)
# TODO: this is for backwards compatibility, remove in future # TODO: this is for backwards compatibility, remove in future
def __init__( def __init__(
@ -870,9 +799,10 @@ class StructuredTool(BaseTool):
# --- Runnable --- # --- Runnable ---
# TODO: Is this needed?
async def ainvoke( async def ainvoke(
self, self,
input: Union[str, Dict], input: Union[str, Dict, ToolCall],
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
@ -897,45 +827,26 @@ class StructuredTool(BaseTool):
) -> Any: ) -> Any:
"""Use the tool.""" """Use the tool."""
if self.func: if self.func:
new_argument_supported = signature(self.func).parameters.get("callbacks") if run_manager and signature(self.func).parameters.get("callbacks"):
return ( kwargs["callbacks"] = run_manager.get_child()
self.func( return self.func(*args, **kwargs)
*args, raise NotImplementedError("StructuredTool does not support sync invocation.")
callbacks=run_manager.get_child() if run_manager else None,
**kwargs,
)
if new_argument_supported
else self.func(*args, **kwargs)
)
raise NotImplementedError("Tool does not support sync")
async def _arun( async def _arun(
self, self,
*args: Any, *args: Any,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None, run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
**kwargs: Any, **kwargs: Any,
) -> str: ) -> Any:
"""Use the tool asynchronously.""" """Use the tool asynchronously."""
if self.coroutine: if self.coroutine:
new_argument_supported = signature(self.coroutine).parameters.get( if run_manager and signature(self.coroutine).parameters.get("callbacks"):
"callbacks" kwargs["callbacks"] = run_manager.get_child()
) return await self.coroutine(*args, **kwargs)
return (
await self.coroutine( # NOTE: this code is unreachable since _arun is only called if coroutine is not
*args, # None.
callbacks=run_manager.get_child() if run_manager else None, return await super()._arun(*args, run_manager=run_manager, **kwargs)
**kwargs,
)
if new_argument_supported
else await self.coroutine(*args, **kwargs)
)
return await run_in_executor(
None,
self._run,
run_manager=run_manager.get_sync() if run_manager else None,
*args,
**kwargs,
)
@classmethod @classmethod
def from_function( def from_function(
@ -947,6 +858,8 @@ class StructuredTool(BaseTool):
return_direct: bool = False, return_direct: bool = False,
args_schema: Optional[Type[BaseModel]] = None, args_schema: Optional[Type[BaseModel]] = None,
infer_schema: bool = True, infer_schema: bool = True,
*,
response_format: Literal["content", "content_and_raw_output"] = "content",
parse_docstring: bool = False, parse_docstring: bool = False,
error_on_invalid_docstring: bool = False, error_on_invalid_docstring: bool = False,
**kwargs: Any, **kwargs: Any,
@ -963,6 +876,10 @@ class StructuredTool(BaseTool):
return_direct: Whether to return the result directly or as a callback return_direct: Whether to return the result directly or as a callback
args_schema: The schema of the tool's input arguments args_schema: The schema of the tool's input arguments
infer_schema: Whether to infer the schema from the function's signature infer_schema: Whether to infer the schema from the function's signature
response_format: The tool response format. If "content" then the output of
the tool is interpreted as the contents of a ToolMessage. If
"content_and_raw_output" then the output is expected to be a two-tuple
corresponding to the (content, raw_output) of a ToolMessage.
parse_docstring: if ``infer_schema`` and ``parse_docstring``, will attempt parse_docstring: if ``infer_schema`` and ``parse_docstring``, will attempt
to parse parameter descriptions from Google Style function docstrings. to parse parameter descriptions from Google Style function docstrings.
error_on_invalid_docstring: if ``parse_docstring`` is provided, configures error_on_invalid_docstring: if ``parse_docstring`` is provided, configures
@ -1020,6 +937,7 @@ class StructuredTool(BaseTool):
args_schema=_args_schema, # type: ignore[arg-type] args_schema=_args_schema, # type: ignore[arg-type]
description=description_, description=description_,
return_direct=return_direct, return_direct=return_direct,
response_format=response_format,
**kwargs, **kwargs,
) )
@ -1029,6 +947,7 @@ def tool(
return_direct: bool = False, return_direct: bool = False,
args_schema: Optional[Type[BaseModel]] = None, args_schema: Optional[Type[BaseModel]] = None,
infer_schema: bool = True, infer_schema: bool = True,
response_format: Literal["content", "content_and_raw_output"] = "content",
parse_docstring: bool = False, parse_docstring: bool = False,
error_on_invalid_docstring: bool = True, error_on_invalid_docstring: bool = True,
) -> Callable: ) -> Callable:
@ -1042,6 +961,10 @@ def tool(
infer_schema: Whether to infer the schema of the arguments from infer_schema: Whether to infer the schema of the arguments from
the function's signature. This also makes the resultant tool the function's signature. This also makes the resultant tool
accept a dictionary input to its `run()` function. accept a dictionary input to its `run()` function.
response_format: The tool response format. If "content" then the output of
the tool is interpreted as the contents of a ToolMessage. If
"content_and_raw_output" then the output is expected to be a two-tuple
corresponding to the (content, raw_output) of a ToolMessage.
parse_docstring: if ``infer_schema`` and ``parse_docstring``, will attempt to parse_docstring: if ``infer_schema`` and ``parse_docstring``, will attempt to
parse parameter descriptions from Google Style function docstrings. parse parameter descriptions from Google Style function docstrings.
error_on_invalid_docstring: if ``parse_docstring`` is provided, configures error_on_invalid_docstring: if ``parse_docstring`` is provided, configures
@ -1064,8 +987,12 @@ def tool(
# Searches the API for the query. # Searches the API for the query.
return return
.. versionadded:: 0.2.14 @tool(response_format="content_and_raw_output")
Parse Google-style docstrings: def search_api(query: str) -> Tuple[str, dict]:
return "partial json of results", {"full": "object of results"}
.. versionadded:: 0.2.14
Parse Google-style docstrings:
.. code-block:: python .. code-block:: python
@ -1179,6 +1106,7 @@ def tool(
return_direct=return_direct, return_direct=return_direct,
args_schema=schema, args_schema=schema,
infer_schema=infer_schema, infer_schema=infer_schema,
response_format=response_format,
parse_docstring=parse_docstring, parse_docstring=parse_docstring,
error_on_invalid_docstring=error_on_invalid_docstring, error_on_invalid_docstring=error_on_invalid_docstring,
) )
@ -1195,6 +1123,7 @@ def tool(
description=f"{tool_name} tool", description=f"{tool_name} tool",
return_direct=return_direct, return_direct=return_direct,
coroutine=coroutine, coroutine=coroutine,
response_format=response_format,
) )
return _make_tool return _make_tool
@ -1350,6 +1279,103 @@ class BaseToolkit(BaseModel, ABC):
"""Get the tools in the toolkit.""" """Get the tools in the toolkit."""
def _is_tool_call(x: Any) -> bool:
return isinstance(x, dict) and x.get("type") == "tool_call"
def _handle_validation_error(
e: ValidationError,
*,
flag: Union[Literal[True], str, Callable[[ValidationError], str]],
) -> str:
if isinstance(flag, bool):
content = "Tool input validation error"
elif isinstance(flag, str):
content = flag
elif callable(flag):
content = flag(e)
else:
raise ValueError(
f"Got unexpected type of `handle_validation_error`. Expected bool, "
f"str or callable. Received: {flag}"
)
return content
def _handle_tool_error(
e: ToolException,
*,
flag: Optional[Union[Literal[True], str, Callable[[ToolException], str]]],
) -> str:
if isinstance(flag, bool):
if e.args:
content = e.args[0]
else:
content = "Tool execution error"
elif isinstance(flag, str):
content = flag
elif callable(flag):
content = flag(e)
else:
raise ValueError(
f"Got unexpected type of `handle_tool_error`. Expected bool, str "
f"or callable. Received: {flag}"
)
return content
def _prep_run_args(
    input: Union[str, dict, ToolCall],
    config: Optional[RunnableConfig],
    **kwargs: Any,
) -> Tuple[Union[str, Dict], Dict]:
    """Normalize a tool invocation into ``(tool_input, run kwargs)``.

    A ToolCall-shaped dict contributes its ``args`` as the tool input and
    its ``id`` as ``tool_call_id``; any other input is passed through with
    ``tool_call_id=None``.
    """
    config = ensure_config(config)
    if _is_tool_call(input):
        call = cast(ToolCall, input)
        tool_call_id: Optional[str] = call["id"]
        tool_input: Union[str, dict] = call["args"]
    else:
        tool_call_id = None
        tool_input = cast(Union[str, dict], input)
    # ``run_id`` is popped (not read) so it is consumed exactly once and not
    # also forwarded inside ``config``.
    run_kwargs = dict(
        callbacks=config.get("callbacks"),
        tags=config.get("tags"),
        metadata=config.get("metadata"),
        run_name=config.get("run_name"),
        run_id=config.pop("run_id", None),
        config=config,
        tool_call_id=tool_call_id,
        **kwargs,
    )
    return tool_input, run_kwargs
def _format_output(
content: Any, raw_output: Any, tool_call_id: Optional[str]
) -> Union[ToolMessage, Any]:
if tool_call_id:
# NOTE: This will fail to stringify lists which aren't actually content blocks
# but whose first element happens to be a string or dict. Tools should avoid
# returning such contents.
if not isinstance(content, str) and not (
isinstance(content, list)
and content
and isinstance(content[0], (str, dict))
):
content = _stringify(content)
return ToolMessage(content, raw_output=raw_output, tool_call_id=tool_call_id)
else:
return content
def _stringify(content: Any) -> str:
try:
return json.dumps(content)
except Exception:
return str(content)
def _get_description_from_runnable(runnable: Runnable) -> str: def _get_description_from_runnable(runnable: Runnable) -> str:
"""Generate a placeholder description of a runnable.""" """Generate a placeholder description of a runnable."""
input_schema = runnable.input_schema.schema() input_schema = runnable.input_schema.schema()

View File

@ -317,6 +317,13 @@
'title': 'Name', 'title': 'Name',
'type': 'string', 'type': 'string',
}), }),
'type': dict({
'enum': list([
'invalid_tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}), }),
'required': list([ 'required': list([
'name', 'name',
@ -419,6 +426,13 @@
'title': 'Name', 'title': 'Name',
'type': 'string', 'type': 'string',
}), }),
'type': dict({
'enum': list([
'tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}), }),
'required': list([ 'required': list([
'name', 'name',
@ -908,6 +922,13 @@
'title': 'Name', 'title': 'Name',
'type': 'string', 'type': 'string',
}), }),
'type': dict({
'enum': list([
'invalid_tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}), }),
'required': list([ 'required': list([
'name', 'name',
@ -1010,6 +1031,13 @@
'title': 'Name', 'title': 'Name',
'type': 'string', 'type': 'string',
}), }),
'type': dict({
'enum': list([
'tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}), }),
'required': list([ 'required': list([
'name', 'name',

View File

@ -674,6 +674,13 @@
'title': 'Name', 'title': 'Name',
'type': 'string', 'type': 'string',
}), }),
'type': dict({
'enum': list([
'invalid_tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}), }),
'required': list([ 'required': list([
'name', 'name',
@ -776,6 +783,13 @@
'title': 'Name', 'title': 'Name',
'type': 'string', 'type': 'string',
}), }),
'type': dict({
'enum': list([
'tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}), }),
'required': list([ 'required': list([
'name', 'name',

View File

@ -5577,6 +5577,13 @@
'title': 'Name', 'title': 'Name',
'type': 'string', 'type': 'string',
}), }),
'type': dict({
'enum': list([
'invalid_tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}), }),
'required': list([ 'required': list([
'name', 'name',
@ -5701,6 +5708,13 @@
'title': 'Name', 'title': 'Name',
'type': 'string', 'type': 'string',
}), }),
'type': dict({
'enum': list([
'tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}), }),
'required': list([ 'required': list([
'name', 'name',
@ -6237,6 +6251,13 @@
'title': 'Name', 'title': 'Name',
'type': 'string', 'type': 'string',
}), }),
'type': dict({
'enum': list([
'invalid_tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}), }),
'required': list([ 'required': list([
'name', 'name',
@ -6361,6 +6382,13 @@
'title': 'Name', 'title': 'Name',
'type': 'string', 'type': 'string',
}), }),
'type': dict({
'enum': list([
'tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}), }),
'required': list([ 'required': list([
'name', 'name',
@ -6834,6 +6862,13 @@
'title': 'Name', 'title': 'Name',
'type': 'string', 'type': 'string',
}), }),
'type': dict({
'enum': list([
'invalid_tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}), }),
'required': list([ 'required': list([
'name', 'name',
@ -6936,6 +6971,13 @@
'title': 'Name', 'title': 'Name',
'type': 'string', 'type': 'string',
}), }),
'type': dict({
'enum': list([
'tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}), }),
'required': list([ 'required': list([
'name', 'name',
@ -7444,6 +7486,13 @@
'title': 'Name', 'title': 'Name',
'type': 'string', 'type': 'string',
}), }),
'type': dict({
'enum': list([
'invalid_tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}), }),
'required': list([ 'required': list([
'name', 'name',
@ -7568,6 +7617,13 @@
'title': 'Name', 'title': 'Name',
'type': 'string', 'type': 'string',
}), }),
'type': dict({
'enum': list([
'tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}), }),
'required': list([ 'required': list([
'name', 'name',
@ -8068,6 +8124,13 @@
'title': 'Name', 'title': 'Name',
'type': 'string', 'type': 'string',
}), }),
'type': dict({
'enum': list([
'invalid_tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}), }),
'required': list([ 'required': list([
'name', 'name',
@ -8203,6 +8266,13 @@
'title': 'Name', 'title': 'Name',
'type': 'string', 'type': 'string',
}), }),
'type': dict({
'enum': list([
'tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}), }),
'required': list([ 'required': list([
'name', 'name',
@ -8683,6 +8753,13 @@
'title': 'Name', 'title': 'Name',
'type': 'string', 'type': 'string',
}), }),
'type': dict({
'enum': list([
'invalid_tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}), }),
'required': list([ 'required': list([
'name', 'name',
@ -8785,6 +8862,13 @@
'title': 'Name', 'title': 'Name',
'type': 'string', 'type': 'string',
}), }),
'type': dict({
'enum': list([
'tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}), }),
'required': list([ 'required': list([
'name', 'name',
@ -9238,6 +9322,13 @@
'title': 'Name', 'title': 'Name',
'type': 'string', 'type': 'string',
}), }),
'type': dict({
'enum': list([
'invalid_tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}), }),
'required': list([ 'required': list([
'name', 'name',
@ -9340,6 +9431,13 @@
'title': 'Name', 'title': 'Name',
'type': 'string', 'type': 'string',
}), }),
'type': dict({
'enum': list([
'tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}), }),
'required': list([ 'required': list([
'name', 'name',
@ -9880,6 +9978,13 @@
'title': 'Name', 'title': 'Name',
'type': 'string', 'type': 'string',
}), }),
'type': dict({
'enum': list([
'invalid_tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}), }),
'required': list([ 'required': list([
'name', 'name',
@ -10004,6 +10109,13 @@
'title': 'Name', 'title': 'Name',
'type': 'string', 'type': 'string',
}), }),
'type': dict({
'enum': list([
'tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}), }),
'required': list([ 'required': list([
'name', 'name',

View File

@ -8,7 +8,7 @@ import textwrap
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from functools import partial from functools import partial
from typing import Any, Callable, Dict, List, Optional, Type, Union from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union
import pytest import pytest
from typing_extensions import Annotated, TypedDict from typing_extensions import Annotated, TypedDict
@ -17,6 +17,7 @@ from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun, AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun, CallbackManagerForToolRun,
) )
from langchain_core.messages import ToolMessage
from langchain_core.pydantic_v1 import BaseModel, ValidationError from langchain_core.pydantic_v1 import BaseModel, ValidationError
from langchain_core.runnables import Runnable, RunnableLambda, ensure_config from langchain_core.runnables import Runnable, RunnableLambda, ensure_config
from langchain_core.tools import ( from langchain_core.tools import (
@ -1067,6 +1068,65 @@ def test_tool_annotated_descriptions() -> None:
} }
def test_tool_call_input_tool_message_output() -> None:
    """Invoking a tool with a ToolCall dict should yield a ToolMessage."""
    tool = _MockStructuredTool()
    tool_call = {
        "name": "structured_api",
        "args": {"arg1": 1, "arg2": True, "arg3": {"img": "base64string..."}},
        "id": "123",
        "type": "tool_call",
    }
    expected = ToolMessage("1 True {'img': 'base64string...'}", tool_call_id="123")
    assert tool.invoke(tool_call) == expected

    # Without the explicit "type" marker the dict is no longer treated as a
    # ToolCall and fails input validation.
    tool_call.pop("type")
    with pytest.raises(ValidationError):
        tool.invoke(tool_call)
class _MockStructuredToolWithRawOutput(BaseTool):
    # Mock tool whose _run returns a (content, raw_output) pair, exercising the
    # "content_and_raw_output" response format.
    name: str = "structured_api"
    args_schema: Type[BaseModel] = _MockSchema
    description: str = "A Structured Tool"
    response_format: Literal["content_and_raw_output"] = "content_and_raw_output"

    def _run(
        self, arg1: int, arg2: bool, arg3: Optional[dict] = None
    ) -> Tuple[str, dict]:
        content = f"{arg1} {arg2}"
        raw_output = {"arg1": arg1, "arg2": arg2, "arg3": arg3}
        return content, raw_output
@tool("structured_api", response_format="content_and_raw_output")
def _mock_structured_tool_with_raw_output(
    arg1: int, arg2: bool, arg3: Optional[dict] = None
) -> Tuple[str, dict]:
    """A Structured Tool"""
    # NOTE: the docstring above doubles as the tool description, so it is kept
    # minimal. Returns a (content, raw_output) pair as required by the
    # "content_and_raw_output" response format: a readable summary string plus
    # the raw argument mapping.
    return f"{arg1} {arg2}", {"arg1": arg1, "arg2": arg2, "arg3": arg3}
@pytest.mark.parametrize(
    "tool", [_MockStructuredToolWithRawOutput(), _mock_structured_tool_with_raw_output]
)
def test_tool_call_input_tool_message_with_raw_output(tool: BaseTool) -> None:
    """Tools using the content_and_raw_output format populate raw_output."""
    call_args = {"arg1": 1, "arg2": True, "arg3": {"img": "base64string..."}}
    tool_call: Dict = {
        "name": "structured_api",
        "args": call_args,
        "id": "123",
        "type": "tool_call",
    }
    expected = ToolMessage("1 True", raw_output=call_args, tool_call_id="123")
    assert tool.invoke(tool_call) == expected

    # Dropping "type" demotes the dict to plain input, which fails validation.
    tool_call.pop("type")
    with pytest.raises(ValidationError):
        tool.invoke(tool_call)

    # Passing only the args returns bare content (no ToolMessage wrapper).
    assert tool.invoke(tool_call["args"]) == expected.content
def test_convert_from_runnable_dict() -> None: def test_convert_from_runnable_dict() -> None:
# Test with typed dict input # Test with typed dict input
class Args(TypedDict): class Args(TypedDict):

View File

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. # This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
[[package]] [[package]]
name = "aiohttp" name = "aiohttp"
@ -1760,7 +1760,7 @@ files = [
[[package]] [[package]]
name = "langchain-core" name = "langchain-core"
version = "0.2.12" version = "0.2.13"
description = "Building applications with LLMs through composability" description = "Building applications with LLMs through composability"
optional = false optional = false
python-versions = ">=3.8.1,<4.0" python-versions = ">=3.8.1,<4.0"
@ -1784,7 +1784,7 @@ url = "../core"
[[package]] [[package]]
name = "langchain-openai" name = "langchain-openai"
version = "0.1.14" version = "0.1.15"
description = "An integration package connecting OpenAI and LangChain" description = "An integration package connecting OpenAI and LangChain"
optional = true optional = true
python-versions = ">=3.8.1,<4.0" python-versions = ">=3.8.1,<4.0"
@ -1792,7 +1792,7 @@ files = []
develop = true develop = true
[package.dependencies] [package.dependencies]
langchain-core = ">=0.2.2,<0.3" langchain-core = "^0.2.13"
openai = "^1.32.0" openai = "^1.32.0"
tiktoken = ">=0.7,<1" tiktoken = ">=0.7,<1"
@ -1834,13 +1834,13 @@ types-requests = ">=2.31.0.2,<3.0.0.0"
[[package]] [[package]]
name = "langsmith" name = "langsmith"
version = "0.1.84" version = "0.1.85"
description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
optional = false optional = false
python-versions = "<4.0,>=3.8.1" python-versions = "<4.0,>=3.8.1"
files = [ files = [
{file = "langsmith-0.1.84-py3-none-any.whl", hash = "sha256:01f3c6390dba26c583bac8dd0e551ce3d0509c7f55cad714db0b5c8d36e4c7ff"}, {file = "langsmith-0.1.85-py3-none-any.whl", hash = "sha256:c1f94384f10cea96f7b4d33fd3db7ec180c03c7468877d50846f881d2017ff94"},
{file = "langsmith-0.1.84.tar.gz", hash = "sha256:5220c0439838b9a5bd320fd3686be505c5083dcee22d2452006c23891153bea1"}, {file = "langsmith-0.1.85.tar.gz", hash = "sha256:acff31f9e53efa48586cf8e32f65625a335c74d7c4fa306d1655ac18452296f6"},
] ]
[package.dependencies] [package.dependencies]
@ -2350,13 +2350,13 @@ files = [
[[package]] [[package]]
name = "openai" name = "openai"
version = "1.35.10" version = "1.35.13"
description = "The official Python library for the openai API" description = "The official Python library for the openai API"
optional = true optional = true
python-versions = ">=3.7.1" python-versions = ">=3.7.1"
files = [ files = [
{file = "openai-1.35.10-py3-none-any.whl", hash = "sha256:962cb5c23224b5cbd16078308dabab97a08b0a5ad736a4fdb3dc2ffc44ac974f"}, {file = "openai-1.35.13-py3-none-any.whl", hash = "sha256:36ec3e93e0d1f243f69be85c89b9221a471c3e450dfd9df16c9829e3cdf63e60"},
{file = "openai-1.35.10.tar.gz", hash = "sha256:85966949f4f960f3e4b239a659f9fd64d3a97ecc43c44dc0a044b5c7f11cccc6"}, {file = "openai-1.35.13.tar.gz", hash = "sha256:c684f3945608baf7d2dcc0ef3ee6f3e27e4c66f21076df0b47be45d57e6ae6e4"},
] ]
[package.dependencies] [package.dependencies]
@ -4141,13 +4141,13 @@ urllib3 = ">=2"
[[package]] [[package]]
name = "types-setuptools" name = "types-setuptools"
version = "70.2.0.20240704" version = "70.3.0.20240710"
description = "Typing stubs for setuptools" description = "Typing stubs for setuptools"
optional = false optional = false
python-versions = ">=3.8" python-versions = ">=3.8"
files = [ files = [
{file = "types-setuptools-70.2.0.20240704.tar.gz", hash = "sha256:2f8d28d16ca1607080f9fdf19595bd49c942884b2bbd6529c9b8a9a8fc8db911"}, {file = "types-setuptools-70.3.0.20240710.tar.gz", hash = "sha256:842cbf399812d2b65042c9d6ff35113bbf282dee38794779aa1f94e597bafc35"},
{file = "types_setuptools-70.2.0.20240704-py3-none-any.whl", hash = "sha256:6b892d5441c2ed58dd255724516e3df1db54892fb20597599aea66d04c3e4d7f"}, {file = "types_setuptools-70.3.0.20240710-py3-none-any.whl", hash = "sha256:bd0db2a4b9f2c49ac5564be4e0fb3125c4c46b1f73eafdcbceffa5b005cceca4"},
] ]
[[package]] [[package]]

View File

@ -43,6 +43,7 @@ from langchain_core.messages import (
ToolMessage, ToolMessage,
) )
from langchain_core.messages.ai import UsageMetadata from langchain_core.messages.ai import UsageMetadata
from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
from langchain_core.output_parsers import ( from langchain_core.output_parsers import (
JsonOutputKeyToolsParser, JsonOutputKeyToolsParser,
PydanticToolsParser, PydanticToolsParser,
@ -1102,12 +1103,12 @@ def _make_message_chunk_from_anthropic_event(
warnings.warn("Received unexpected tool content block.") warnings.warn("Received unexpected tool content block.")
content_block = event.content_block.model_dump() content_block = event.content_block.model_dump()
content_block["index"] = event.index content_block["index"] = event.index
tool_call_chunk = { tool_call_chunk = create_tool_call_chunk(
"index": event.index, index=event.index,
"id": event.content_block.id, id=event.content_block.id,
"name": event.content_block.name, name=event.content_block.name,
"args": "", args="",
} )
message_chunk = AIMessageChunk( message_chunk = AIMessageChunk(
content=[content_block], content=[content_block],
tool_call_chunks=[tool_call_chunk], # type: ignore tool_call_chunks=[tool_call_chunk], # type: ignore

View File

@ -1,6 +1,7 @@
from typing import Any, List, Optional, Type, Union, cast from typing import Any, List, Optional, Type, Union, cast
from langchain_core.messages import AIMessage, ToolCall from langchain_core.messages import AIMessage, ToolCall
from langchain_core.messages.tool import tool_call
from langchain_core.output_parsers import BaseGenerationOutputParser from langchain_core.output_parsers import BaseGenerationOutputParser
from langchain_core.outputs import ChatGeneration, Generation from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.pydantic_v1 import BaseModel from langchain_core.pydantic_v1 import BaseModel
@ -79,7 +80,7 @@ def extract_tool_calls(content: Union[str, List[Union[str, dict]]]) -> List[Tool
if block["type"] != "tool_use": if block["type"] != "tool_use":
continue continue
tool_calls.append( tool_calls.append(
ToolCall(name=block["name"], args=block["input"], id=block["id"]) tool_call(name=block["name"], args=block["input"], id=block["id"])
) )
return tool_calls return tool_calls
else: else:

View File

@ -365,10 +365,7 @@ async def test_astreaming() -> None:
def test_tool_use() -> None: def test_tool_use() -> None:
llm = ChatAnthropic( # type: ignore[call-arg] llm = ChatAnthropic(model=MODEL_NAME) # type: ignore[call-arg]
model=MODEL_NAME,
)
llm_with_tools = llm.bind_tools( llm_with_tools = llm.bind_tools(
[ [
{ {
@ -478,6 +475,7 @@ def test_anthropic_with_empty_text_block() -> None:
"name": "type_letter", "name": "type_letter",
"args": {"letter": "d"}, "args": {"letter": "d"},
"id": "toolu_01V6d6W32QGGSmQm4BT98EKk", "id": "toolu_01V6d6W32QGGSmQm4BT98EKk",
"type": "tool_call",
}, },
], ],
), ),

View File

@ -33,8 +33,20 @@ class _Foo2(BaseModel):
def test_tools_output_parser() -> None: def test_tools_output_parser() -> None:
output_parser = ToolsOutputParser() output_parser = ToolsOutputParser()
expected = [ expected = [
{"name": "_Foo1", "args": {"bar": 0}, "id": "1", "index": 1}, {
{"name": "_Foo2", "args": {"baz": "a"}, "id": "2", "index": 3}, "name": "_Foo1",
"args": {"bar": 0},
"id": "1",
"index": 1,
"type": "tool_call",
},
{
"name": "_Foo2",
"args": {"baz": "a"},
"id": "2",
"index": 3,
"type": "tool_call",
},
] ]
actual = output_parser.parse_result(_RESULT) actual = output_parser.parse_result(_RESULT)
assert expected == actual assert expected == actual
@ -56,7 +68,13 @@ def test_tools_output_parser_args_only() -> None:
def test_tools_output_parser_first_tool_only() -> None: def test_tools_output_parser_first_tool_only() -> None:
output_parser = ToolsOutputParser(first_tool_only=True) output_parser = ToolsOutputParser(first_tool_only=True)
expected: Any = {"name": "_Foo1", "args": {"bar": 0}, "id": "1", "index": 1} expected: Any = {
"name": "_Foo1",
"args": {"bar": 0},
"id": "1",
"index": 1,
"type": "tool_call",
}
actual = output_parser.parse_result(_RESULT) actual = output_parser.parse_result(_RESULT)
assert expected == actual assert expected == actual
@ -81,7 +99,14 @@ def test_tools_output_parser_empty_content() -> None:
) )
message = AIMessage( message = AIMessage(
"", "",
tool_calls=[{"name": "ChartType", "args": {"chart_type": "pie"}, "id": "foo"}], tool_calls=[
{
"name": "ChartType",
"args": {"chart_type": "pie"},
"id": "foo",
"type": "tool_call",
}
],
) )
actual = output_parser.invoke(message) actual = output_parser.invoke(message)
expected = ChartType(chart_type="pie") expected = ChartType(chart_type="pie")

View File

@ -9,10 +9,11 @@ import json
import os import os
import re import re
import urllib import urllib
from copy import deepcopy
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from io import BytesIO from io import BytesIO
from typing import Any, BinaryIO, Callable, List, Optional from typing import Any, BinaryIO, Callable, List, Literal, Optional, Tuple
from uuid import uuid4 from uuid import uuid4
import requests import requests
@ -126,6 +127,8 @@ class SessionsPythonREPLTool(BaseTool):
session_id: str = str(uuid4()) session_id: str = str(uuid4())
"""The session ID to use for the code interpreter. Defaults to a random UUID.""" """The session ID to use for the code interpreter. Defaults to a random UUID."""
response_format: Literal["content_and_raw_output"] = "content_and_raw_output"
def _build_url(self, path: str) -> str: def _build_url(self, path: str) -> str:
pool_management_endpoint = self.pool_management_endpoint pool_management_endpoint = self.pool_management_endpoint
if not pool_management_endpoint: if not pool_management_endpoint:
@ -164,16 +167,16 @@ class SessionsPythonREPLTool(BaseTool):
properties = response_json.get("properties", {}) properties = response_json.get("properties", {})
return properties return properties
def _run(self, python_code: str) -> Any: def _run(self, python_code: str, **kwargs: Any) -> Tuple[str, dict]:
response = self.execute(python_code) response = self.execute(python_code)
# if the result is an image, remove the base64 data # if the result is an image, remove the base64 data
result = response.get("result") result = deepcopy(response.get("result"))
if isinstance(result, dict): if isinstance(result, dict):
if result.get("type") == "image" and "base64_data" in result: if result.get("type") == "image" and "base64_data" in result:
result.pop("base64_data") result.pop("base64_data")
return json.dumps( content = json.dumps(
{ {
"result": result, "result": result,
"stdout": response.get("stdout"), "stdout": response.get("stdout"),
@ -181,6 +184,7 @@ class SessionsPythonREPLTool(BaseTool):
}, },
indent=2, indent=2,
) )
return content, response
def upload_file( def upload_file(
self, self,

View File

@ -54,6 +54,12 @@ from langchain_core.messages import (
ToolMessage, ToolMessage,
ToolMessageChunk, ToolMessageChunk,
) )
from langchain_core.messages.tool import (
ToolCallChunk,
)
from langchain_core.messages.tool import (
tool_call_chunk as create_tool_call_chunk,
)
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
from langchain_core.output_parsers.base import OutputParserLike from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.openai_tools import ( from langchain_core.output_parsers.openai_tools import (
@ -199,6 +205,7 @@ def _convert_chunk_to_message_chunk(
role = cast(str, _dict.get("role")) role = cast(str, _dict.get("role"))
content = cast(str, _dict.get("content") or "") content = cast(str, _dict.get("content") or "")
additional_kwargs: Dict = {} additional_kwargs: Dict = {}
tool_call_chunks: List[ToolCallChunk] = []
if _dict.get("function_call"): if _dict.get("function_call"):
function_call = dict(_dict["function_call"]) function_call = dict(_dict["function_call"])
if "name" in function_call and function_call["name"] is None: if "name" in function_call and function_call["name"] is None:
@ -206,21 +213,18 @@ def _convert_chunk_to_message_chunk(
additional_kwargs["function_call"] = function_call additional_kwargs["function_call"] = function_call
if raw_tool_calls := _dict.get("tool_calls"): if raw_tool_calls := _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = raw_tool_calls additional_kwargs["tool_calls"] = raw_tool_calls
try: for rtc in raw_tool_calls:
tool_call_chunks = [ try:
{ tool_call_chunks.append(
"name": rtc["function"].get("name"), create_tool_call_chunk(
"args": rtc["function"].get("arguments"), name=rtc["function"].get("name"),
"id": rtc.get("id"), args=rtc["function"].get("arguments"),
"index": rtc["index"], id=rtc.get("id"),
} index=rtc.get("index"),
for rtc in raw_tool_calls )
] )
except KeyError: except KeyError:
pass pass
else:
tool_call_chunks = []
if role == "user" or default_class == HumanMessageChunk: if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content) return HumanMessageChunk(content=content)
elif role == "assistant" or default_class == AIMessageChunk: elif role == "assistant" or default_class == AIMessageChunk:
@ -237,7 +241,7 @@ def _convert_chunk_to_message_chunk(
return AIMessageChunk( return AIMessageChunk(
content=content, content=content,
additional_kwargs=additional_kwargs, additional_kwargs=additional_kwargs,
tool_call_chunks=tool_call_chunks, # type: ignore[arg-type] tool_call_chunks=tool_call_chunks,
usage_metadata=usage_metadata, # type: ignore[arg-type] usage_metadata=usage_metadata, # type: ignore[arg-type]
) )
elif role == "system" or default_class == SystemMessageChunk: elif role == "system" or default_class == SystemMessageChunk:

View File

@ -53,6 +53,7 @@ from langchain_core.messages import (
ToolMessage, ToolMessage,
ToolMessageChunk, ToolMessageChunk,
) )
from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
from langchain_core.output_parsers import ( from langchain_core.output_parsers import (
JsonOutputParser, JsonOutputParser,
PydanticOutputParser, PydanticOutputParser,
@ -511,19 +512,19 @@ class ChatGroq(BaseChatModel):
generation = chat_result.generations[0] generation = chat_result.generations[0]
message = cast(AIMessage, generation.message) message = cast(AIMessage, generation.message)
tool_call_chunks = [ tool_call_chunks = [
{ create_tool_call_chunk(
"name": rtc["function"].get("name"), name=rtc["function"].get("name"),
"args": rtc["function"].get("arguments"), args=rtc["function"].get("arguments"),
"id": rtc.get("id"), id=rtc.get("id"),
"index": rtc.get("index"), index=rtc.get("index"),
} )
for rtc in message.additional_kwargs.get("tool_calls", []) for rtc in message.additional_kwargs.get("tool_calls", [])
] ]
chunk_ = ChatGenerationChunk( chunk_ = ChatGenerationChunk(
message=AIMessageChunk( message=AIMessageChunk(
content=message.content, content=message.content,
additional_kwargs=message.additional_kwargs, additional_kwargs=message.additional_kwargs,
tool_call_chunks=tool_call_chunks, # type: ignore[arg-type] tool_call_chunks=tool_call_chunks,
usage_metadata=message.usage_metadata, usage_metadata=message.usage_metadata,
), ),
generation_info=generation.generation_info, generation_info=generation.generation_info,

View File

@ -77,6 +77,7 @@ def test__convert_dict_to_message_tool_call() -> None:
name="GenerateUsername", name="GenerateUsername",
args={"name": "Sally", "hair_color": "green"}, args={"name": "Sally", "hair_color": "green"},
id="call_wm0JY6CdwOMZ4eTxHWUThDNz", id="call_wm0JY6CdwOMZ4eTxHWUThDNz",
type="tool_call",
) )
], ],
) )
@ -112,6 +113,7 @@ def test__convert_dict_to_message_tool_call() -> None:
args="oops", args="oops",
id="call_wm0JY6CdwOMZ4eTxHWUThDNz", id="call_wm0JY6CdwOMZ4eTxHWUThDNz",
error="Function GenerateUsername arguments:\n\noops\n\nare not valid JSON. Received JSONDecodeError Expecting value: line 1 column 1 (char 0)", # noqa: E501 error="Function GenerateUsername arguments:\n\noops\n\nare not valid JSON. Received JSONDecodeError Expecting value: line 1 column 1 (char 0)", # noqa: E501
type="invalid_tool_call",
), ),
], ],
tool_calls=[ tool_calls=[
@ -119,6 +121,7 @@ def test__convert_dict_to_message_tool_call() -> None:
name="GenerateUsername", name="GenerateUsername",
args={"name": "Sally", "hair_color": "green"}, args={"name": "Sally", "hair_color": "green"},
id="call_abc123", id="call_abc123",
type="tool_call",
), ),
], ],
) )

View File

@ -42,10 +42,12 @@ from langchain_core.messages import (
HumanMessageChunk, HumanMessageChunk,
SystemMessage, SystemMessage,
SystemMessageChunk, SystemMessageChunk,
ToolCallChunk,
ToolMessage, ToolMessage,
ToolMessageChunk, ToolMessageChunk,
convert_to_messages, convert_to_messages,
) )
from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
from langchain_core.output_parsers.base import OutputParserLike from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.openai_tools import ( from langchain_core.output_parsers.openai_tools import (
@ -174,6 +176,7 @@ def _convert_delta_to_message_chunk(
role = cast(str, _dict.get("role")) role = cast(str, _dict.get("role"))
content = cast(str, _dict.get("content") or "") content = cast(str, _dict.get("content") or "")
additional_kwargs: Dict = {} additional_kwargs: Dict = {}
tool_call_chunks: List[ToolCallChunk] = []
if _dict.get("function_call"): if _dict.get("function_call"):
function_call = dict(_dict["function_call"]) function_call = dict(_dict["function_call"])
if "name" in function_call and function_call["name"] is None: if "name" in function_call and function_call["name"] is None:
@ -181,21 +184,18 @@ def _convert_delta_to_message_chunk(
additional_kwargs["function_call"] = function_call additional_kwargs["function_call"] = function_call
if raw_tool_calls := _dict.get("tool_calls"): if raw_tool_calls := _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = raw_tool_calls additional_kwargs["tool_calls"] = raw_tool_calls
try: for rtc in raw_tool_calls:
tool_call_chunks = [ try:
{ tool_call_chunks.append(
"name": rtc["function"].get("name"), create_tool_call_chunk(
"args": rtc["function"].get("arguments"), name=rtc["function"].get("name"),
"id": rtc.get("id"), args=rtc["function"].get("arguments"),
"index": rtc["index"], id=rtc.get("id"),
} index=rtc.get("index"),
for rtc in raw_tool_calls )
] )
except KeyError: except KeyError:
pass pass
else:
tool_call_chunks = []
if role == "user" or default_class == HumanMessageChunk: if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content) return HumanMessageChunk(content=content)
elif role == "assistant" or default_class == AIMessageChunk: elif role == "assistant" or default_class == AIMessageChunk:

View File

@ -50,6 +50,7 @@ from langchain_core.messages import (
ToolCall, ToolCall,
ToolMessage, ToolMessage,
) )
from langchain_core.messages.tool import tool_call_chunk
from langchain_core.output_parsers import ( from langchain_core.output_parsers import (
JsonOutputParser, JsonOutputParser,
PydanticOutputParser, PydanticOutputParser,
@ -103,19 +104,10 @@ def _convert_mistral_chat_message_to_message(
dict, parse_tool_call(raw_tool_call, return_id=True) dict, parse_tool_call(raw_tool_call, return_id=True)
) )
if not parsed["id"]: if not parsed["id"]:
tool_call_id = uuid.uuid4().hex[:] parsed["id"] = uuid.uuid4().hex[:]
tool_calls.append( tool_calls.append(parsed)
{
**parsed,
**{"id": tool_call_id},
},
)
else:
tool_calls.append(parsed)
except Exception as e: except Exception as e:
invalid_tool_calls.append( invalid_tool_calls.append(make_invalid_tool_call(raw_tool_call, str(e)))
dict(make_invalid_tool_call(raw_tool_call, str(e)))
)
return AIMessage( return AIMessage(
content=content, content=content,
additional_kwargs=additional_kwargs, additional_kwargs=additional_kwargs,
@ -206,12 +198,12 @@ def _convert_chunk_to_message_chunk(
else: else:
tool_call_id = raw_tool_call.get("id") tool_call_id = raw_tool_call.get("id")
tool_call_chunks.append( tool_call_chunks.append(
{ tool_call_chunk(
"name": raw_tool_call["function"].get("name"), name=raw_tool_call["function"].get("name"),
"args": raw_tool_call["function"].get("arguments"), args=raw_tool_call["function"].get("arguments"),
"id": tool_call_id, id=tool_call_id,
"index": raw_tool_call.get("index"), index=raw_tool_call.get("index"),
} )
) )
except KeyError: except KeyError:
pass pass

View File

@ -144,6 +144,7 @@ def test__convert_dict_to_message_tool_call() -> None:
name="GenerateUsername", name="GenerateUsername",
args={"name": "Sally", "hair_color": "green"}, args={"name": "Sally", "hair_color": "green"},
id="abc123", id="abc123",
type="tool_call",
) )
], ],
) )
@ -178,6 +179,7 @@ def test__convert_dict_to_message_tool_call() -> None:
args="oops", args="oops",
error="Function GenerateUsername arguments:\n\noops\n\nare not valid JSON. Received JSONDecodeError Expecting value: line 1 column 1 (char 0)", # noqa: E501 error="Function GenerateUsername arguments:\n\noops\n\nare not valid JSON. Received JSONDecodeError Expecting value: line 1 column 1 (char 0)", # noqa: E501
id="abc123", id="abc123",
type="invalid_tool_call",
), ),
], ],
tool_calls=[ tool_calls=[
@ -185,6 +187,7 @@ def test__convert_dict_to_message_tool_call() -> None:
name="GenerateUsername", name="GenerateUsername",
args={"name": "Sally", "hair_color": "green"}, args={"name": "Sally", "hair_color": "green"},
id="def456", id="def456",
type="tool_call",
), ),
], ],
) )

View File

@ -63,6 +63,7 @@ from langchain_core.messages import (
ToolMessageChunk, ToolMessageChunk,
) )
from langchain_core.messages.ai import UsageMetadata from langchain_core.messages.ai import UsageMetadata
from langchain_core.messages.tool import tool_call_chunk
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
from langchain_core.output_parsers.base import OutputParserLike from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.openai_tools import ( from langchain_core.output_parsers.openai_tools import (
@ -244,12 +245,12 @@ def _convert_delta_to_message_chunk(
additional_kwargs["tool_calls"] = raw_tool_calls additional_kwargs["tool_calls"] = raw_tool_calls
try: try:
tool_call_chunks = [ tool_call_chunks = [
{ tool_call_chunk(
"name": rtc["function"].get("name"), name=rtc["function"].get("name"),
"args": rtc["function"].get("arguments"), args=rtc["function"].get("arguments"),
"id": rtc.get("id"), id=rtc.get("id"),
"index": rtc["index"], index=rtc["index"],
} )
for rtc in raw_tool_calls for rtc in raw_tool_calls
] ]
except KeyError: except KeyError:

View File

@ -117,6 +117,7 @@ def test__convert_dict_to_message_tool_call() -> None:
name="GenerateUsername", name="GenerateUsername",
args={"name": "Sally", "hair_color": "green"}, args={"name": "Sally", "hair_color": "green"},
id="call_wm0JY6CdwOMZ4eTxHWUThDNz", id="call_wm0JY6CdwOMZ4eTxHWUThDNz",
type="tool_call",
) )
], ],
) )
@ -151,6 +152,7 @@ def test__convert_dict_to_message_tool_call() -> None:
args="oops", args="oops",
id="call_wm0JY6CdwOMZ4eTxHWUThDNz", id="call_wm0JY6CdwOMZ4eTxHWUThDNz",
error="Function GenerateUsername arguments:\n\noops\n\nare not valid JSON. Received JSONDecodeError Expecting value: line 1 column 1 (char 0)", # noqa: E501 error="Function GenerateUsername arguments:\n\noops\n\nare not valid JSON. Received JSONDecodeError Expecting value: line 1 column 1 (char 0)", # noqa: E501
type="invalid_tool_call",
) )
], ],
tool_calls=[ tool_calls=[
@ -158,6 +160,7 @@ def test__convert_dict_to_message_tool_call() -> None:
name="GenerateUsername", name="GenerateUsername",
args={"name": "Sally", "hair_color": "green"}, args={"name": "Sally", "hair_color": "green"},
id="call_abc123", id="call_abc123",
type="tool_call",
) )
], ],
) )
@ -353,7 +356,10 @@ def test_get_num_tokens_from_messages() -> None:
), ),
AIMessage("a nice bird"), AIMessage("a nice bird"),
AIMessage( AIMessage(
"", tool_calls=[ToolCall(id="foo", name="bar", args={"arg1": "arg1"})] "",
tool_calls=[
ToolCall(id="foo", name="bar", args={"arg1": "arg1"}, type="tool_call")
],
), ),
AIMessage( AIMessage(
"", "",
@ -362,7 +368,10 @@ def test_get_num_tokens_from_messages() -> None:
}, },
), ),
AIMessage( AIMessage(
"text", tool_calls=[ToolCall(id="foo", name="bar", args={"arg1": "arg1"})] "text",
tool_calls=[
ToolCall(id="foo", name="bar", args={"arg1": "arg1"}, type="tool_call")
],
), ),
ToolMessage("foobar", tool_call_id="foo"), ToolMessage("foobar", tool_call_id="foo"),
] ]