mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-16 06:53:16 +00:00
core[minor], integrations...[patch]: Support ToolCall as Tool input and ToolMessage as Tool output (#24038)
Changes: - ToolCall, InvalidToolCall and ToolCallChunk can all accept a "type" parameter now - LLM integration packages add "type" to all the above - Tool supports ToolCall inputs that have "type" specified - Tool outputs ToolMessage when a ToolCall is passed as input - Tools can separately specify ToolMessage.content and ToolMessage.raw_output - Tools emit events for validation errors (using on_tool_error and on_tool_end) Example: ```python @tool("structured_api", response_format="content_and_raw_output") def _mock_structured_tool_with_raw_output( arg1: int, arg2: bool, arg3: Optional[dict] = None ) -> Tuple[str, dict]: """A Structured Tool""" return f"{arg1} {arg2}", {"arg1": arg1, "arg2": arg2, "arg3": arg3} def test_tool_call_input_tool_message_with_raw_output() -> None: tool_call: Dict = { "name": "structured_api", "args": {"arg1": 1, "arg2": True, "arg3": {"img": "base64string..."}}, "id": "123", "type": "tool_call", } expected = ToolMessage("1 True", raw_output=tool_call["args"], tool_call_id="123") tool = _mock_structured_tool_with_raw_output actual = tool.invoke(tool_call) assert actual == expected tool_call.pop("type") with pytest.raises(ValidationError): tool.invoke(tool_call) actual_content = tool.invoke(tool_call["args"]) assert actual_content == expected.content ``` --------- Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
from langchain_core.messages.base import BaseMessage, BaseMessageChunk, merge_content
|
||||
from langchain_core.utils._merge import merge_dicts, merge_obj
|
||||
@@ -146,6 +146,11 @@ class ToolCall(TypedDict):
|
||||
An identifier is needed to associate a tool call request with a tool
|
||||
call result in events when multiple concurrent tool calls are made.
|
||||
"""
|
||||
type: NotRequired[Literal["tool_call"]]
|
||||
|
||||
|
||||
def tool_call(*, name: str, args: Dict[str, Any], id: Optional[str]) -> ToolCall:
|
||||
return ToolCall(name=name, args=args, id=id, type="tool_call")
|
||||
|
||||
|
||||
class ToolCallChunk(TypedDict):
|
||||
@@ -176,6 +181,19 @@ class ToolCallChunk(TypedDict):
|
||||
"""An identifier associated with the tool call."""
|
||||
index: Optional[int]
|
||||
"""The index of the tool call in a sequence."""
|
||||
type: NotRequired[Literal["tool_call_chunk"]]
|
||||
|
||||
|
||||
def tool_call_chunk(
|
||||
*,
|
||||
name: Optional[str] = None,
|
||||
args: Optional[str] = None,
|
||||
id: Optional[str] = None,
|
||||
index: Optional[int] = None,
|
||||
) -> ToolCallChunk:
|
||||
return ToolCallChunk(
|
||||
name=name, args=args, id=id, index=index, type="tool_call_chunk"
|
||||
)
|
||||
|
||||
|
||||
class InvalidToolCall(TypedDict):
|
||||
@@ -193,6 +211,19 @@ class InvalidToolCall(TypedDict):
|
||||
"""An identifier associated with the tool call."""
|
||||
error: Optional[str]
|
||||
"""An error message associated with the tool call."""
|
||||
type: NotRequired[Literal["invalid_tool_call"]]
|
||||
|
||||
|
||||
def invalid_tool_call(
|
||||
*,
|
||||
name: Optional[str] = None,
|
||||
args: Optional[str] = None,
|
||||
id: Optional[str] = None,
|
||||
error: Optional[str] = None,
|
||||
) -> InvalidToolCall:
|
||||
return InvalidToolCall(
|
||||
name=name, args=args, id=id, error=error, type="invalid_tool_call"
|
||||
)
|
||||
|
||||
|
||||
def default_tool_parser(
|
||||
|
@@ -5,6 +5,12 @@ from typing import Any, Dict, List, Optional, Type
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.messages import AIMessage, InvalidToolCall
|
||||
from langchain_core.messages.tool import (
|
||||
invalid_tool_call,
|
||||
)
|
||||
from langchain_core.messages.tool import (
|
||||
tool_call as create_tool_call,
|
||||
)
|
||||
from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser
|
||||
from langchain_core.outputs import ChatGeneration, Generation
|
||||
from langchain_core.pydantic_v1 import BaseModel, ValidationError
|
||||
@@ -59,6 +65,7 @@ def parse_tool_call(
|
||||
}
|
||||
if return_id:
|
||||
parsed["id"] = raw_tool_call.get("id")
|
||||
parsed = create_tool_call(**parsed) # type: ignore
|
||||
return parsed
|
||||
|
||||
|
||||
@@ -75,7 +82,7 @@ def make_invalid_tool_call(
|
||||
Returns:
|
||||
An InvalidToolCall instance with the error message.
|
||||
"""
|
||||
return InvalidToolCall(
|
||||
return invalid_tool_call(
|
||||
name=raw_tool_call["function"]["name"],
|
||||
args=raw_tool_call["function"]["arguments"],
|
||||
id=raw_tool_call.get("id"),
|
||||
|
@@ -21,6 +21,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import textwrap
|
||||
import uuid
|
||||
import warnings
|
||||
@@ -34,6 +35,7 @@ from typing import (
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
@@ -42,7 +44,7 @@ from typing import (
|
||||
get_type_hints,
|
||||
)
|
||||
|
||||
from typing_extensions import Annotated, get_args, get_origin
|
||||
from typing_extensions import Annotated, cast, get_args, get_origin
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.callbacks import (
|
||||
@@ -56,6 +58,7 @@ from langchain_core.callbacks.manager import (
|
||||
Callbacks,
|
||||
)
|
||||
from langchain_core.load.serializable import Serializable
|
||||
from langchain_core.messages.tool import ToolCall, ToolMessage
|
||||
from langchain_core.prompts import (
|
||||
BasePromptTemplate,
|
||||
PromptTemplate,
|
||||
@@ -306,7 +309,7 @@ class ToolException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class BaseTool(RunnableSerializable[Union[str, Dict], Any]):
|
||||
class BaseTool(RunnableSerializable[Union[str, Dict, ToolCall], Any]):
|
||||
"""Interface LangChain tools must implement."""
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any) -> None:
|
||||
@@ -378,6 +381,14 @@ class ChildTool(BaseTool):
|
||||
] = False
|
||||
"""Handle the content of the ValidationError thrown."""
|
||||
|
||||
response_format: Literal["content", "content_and_raw_output"] = "content"
|
||||
"""The tool response format.
|
||||
|
||||
If "content" then the output of the tool is interpreted as the contents of a
|
||||
ToolMessage. If "content_and_raw_output" then the output is expected to be a
|
||||
two-tuple corresponding to the (content, raw_output) of a ToolMessage.
|
||||
"""
|
||||
|
||||
class Config(Serializable.Config):
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
@@ -410,46 +421,25 @@ class ChildTool(BaseTool):
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
input: Union[str, Dict],
|
||||
input: Union[str, Dict, ToolCall],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
config = ensure_config(config)
|
||||
return self.run(
|
||||
input,
|
||||
callbacks=config.get("callbacks"),
|
||||
tags=config.get("tags"),
|
||||
metadata=config.get("metadata"),
|
||||
run_name=config.get("run_name"),
|
||||
run_id=config.pop("run_id", None),
|
||||
config=config,
|
||||
**kwargs,
|
||||
)
|
||||
tool_input, kwargs = _prep_run_args(input, config, **kwargs)
|
||||
return self.run(tool_input, **kwargs)
|
||||
|
||||
async def ainvoke(
|
||||
self,
|
||||
input: Union[str, Dict],
|
||||
input: Union[str, Dict, ToolCall],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
config = ensure_config(config)
|
||||
return await self.arun(
|
||||
input,
|
||||
callbacks=config.get("callbacks"),
|
||||
tags=config.get("tags"),
|
||||
metadata=config.get("metadata"),
|
||||
run_name=config.get("run_name"),
|
||||
run_id=config.pop("run_id", None),
|
||||
config=config,
|
||||
**kwargs,
|
||||
)
|
||||
tool_input, kwargs = _prep_run_args(input, config, **kwargs)
|
||||
return await self.arun(tool_input, **kwargs)
|
||||
|
||||
# --- Tool ---
|
||||
|
||||
def _parse_input(
|
||||
self,
|
||||
tool_input: Union[str, Dict],
|
||||
) -> Union[str, Dict[str, Any]]:
|
||||
def _parse_input(self, tool_input: Union[str, Dict]) -> Union[str, Dict[str, Any]]:
|
||||
"""Convert tool input to pydantic model."""
|
||||
input_args = self.args_schema
|
||||
if isinstance(tool_input, str):
|
||||
@@ -465,7 +455,7 @@ class ChildTool(BaseTool):
|
||||
for k, v in result.dict().items()
|
||||
if k in tool_input
|
||||
}
|
||||
return tool_input
|
||||
return tool_input
|
||||
|
||||
@root_validator(pre=True)
|
||||
def raise_deprecation(cls, values: Dict) -> Dict:
|
||||
@@ -479,30 +469,27 @@ class ChildTool(BaseTool):
|
||||
return values
|
||||
|
||||
@abstractmethod
|
||||
def _run(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
def _run(self, *args: Any, **kwargs: Any) -> Any:
|
||||
"""Use the tool.
|
||||
|
||||
Add run_manager: Optional[CallbackManagerForToolRun] = None
|
||||
to child implementations to enable tracing,
|
||||
to child implementations to enable tracing.
|
||||
"""
|
||||
|
||||
async def _arun(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
|
||||
"""Use the tool asynchronously.
|
||||
|
||||
Add run_manager: Optional[AsyncCallbackManagerForToolRun] = None
|
||||
to child implementations to enable tracing,
|
||||
to child implementations to enable tracing.
|
||||
"""
|
||||
if kwargs.get("run_manager") and signature(self._run).parameters.get(
|
||||
"run_manager"
|
||||
):
|
||||
kwargs["run_manager"] = kwargs["run_manager"].get_sync()
|
||||
return await run_in_executor(None, self._run, *args, **kwargs)
|
||||
|
||||
def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
|
||||
tool_input = self._parse_input(tool_input)
|
||||
# For backwards compatibility, if run_input is a string,
|
||||
# pass as a positional argument.
|
||||
if isinstance(tool_input, str):
|
||||
@@ -523,24 +510,20 @@ class ChildTool(BaseTool):
|
||||
run_name: Optional[str] = None,
|
||||
run_id: Optional[uuid.UUID] = None,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
tool_call_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run the tool."""
|
||||
if not self.verbose and verbose is not None:
|
||||
verbose_ = verbose
|
||||
else:
|
||||
verbose_ = self.verbose
|
||||
callback_manager = CallbackManager.configure(
|
||||
callbacks,
|
||||
self.callbacks,
|
||||
verbose_,
|
||||
self.verbose or bool(verbose),
|
||||
tags,
|
||||
self.tags,
|
||||
metadata,
|
||||
self.metadata,
|
||||
)
|
||||
# TODO: maybe also pass through run_manager is _run supports kwargs
|
||||
new_arg_supported = signature(self._run).parameters.get("run_manager")
|
||||
|
||||
run_manager = callback_manager.on_tool_start(
|
||||
{"name": self.name, "description": self.description},
|
||||
tool_input if isinstance(tool_input, str) else str(tool_input),
|
||||
@@ -550,67 +533,52 @@ class ChildTool(BaseTool):
|
||||
# Inputs by definition should always be dicts.
|
||||
# For now, it's unclear whether this assumption is ever violated,
|
||||
# but if it is we will send a `None` value to the callback instead
|
||||
# And will need to address issue via a patch.
|
||||
inputs=None if isinstance(tool_input, str) else tool_input,
|
||||
# TODO: will need to address issue via a patch.
|
||||
inputs=tool_input if isinstance(tool_input, dict) else None,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
content = None
|
||||
raw_output = None
|
||||
error_to_raise: Union[Exception, KeyboardInterrupt, None] = None
|
||||
try:
|
||||
child_config = patch_config(
|
||||
config,
|
||||
callbacks=run_manager.get_child(),
|
||||
)
|
||||
child_config = patch_config(config, callbacks=run_manager.get_child())
|
||||
context = copy_context()
|
||||
context.run(_set_config_context, child_config)
|
||||
parsed_input = self._parse_input(tool_input)
|
||||
tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input)
|
||||
observation = (
|
||||
context.run(
|
||||
self._run, *tool_args, run_manager=run_manager, **tool_kwargs
|
||||
)
|
||||
if new_arg_supported
|
||||
else context.run(self._run, *tool_args, **tool_kwargs)
|
||||
)
|
||||
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
|
||||
if signature(self._run).parameters.get("run_manager"):
|
||||
tool_kwargs["run_manager"] = run_manager
|
||||
response = context.run(self._run, *tool_args, **tool_kwargs)
|
||||
if self.response_format == "content_and_raw_output":
|
||||
if not isinstance(response, tuple) or len(response) != 2:
|
||||
raise ValueError(
|
||||
"Since response_format='content_and_raw_output' "
|
||||
"a two-tuple of the message content and raw tool output is "
|
||||
f"expected. Instead generated response of type: "
|
||||
f"{type(response)}."
|
||||
)
|
||||
content, raw_output = response
|
||||
else:
|
||||
content = response
|
||||
except ValidationError as e:
|
||||
if not self.handle_validation_error:
|
||||
raise e
|
||||
elif isinstance(self.handle_validation_error, bool):
|
||||
observation = "Tool input validation error"
|
||||
elif isinstance(self.handle_validation_error, str):
|
||||
observation = self.handle_validation_error
|
||||
elif callable(self.handle_validation_error):
|
||||
observation = self.handle_validation_error(e)
|
||||
error_to_raise = e
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Got unexpected type of `handle_validation_error`. Expected bool, "
|
||||
f"str or callable. Received: {self.handle_validation_error}"
|
||||
)
|
||||
return observation
|
||||
content = _handle_validation_error(e, flag=self.handle_validation_error)
|
||||
except ToolException as e:
|
||||
if not self.handle_tool_error:
|
||||
run_manager.on_tool_error(e)
|
||||
raise e
|
||||
elif isinstance(self.handle_tool_error, bool):
|
||||
if e.args:
|
||||
observation = e.args[0]
|
||||
else:
|
||||
observation = "Tool execution error"
|
||||
elif isinstance(self.handle_tool_error, str):
|
||||
observation = self.handle_tool_error
|
||||
elif callable(self.handle_tool_error):
|
||||
observation = self.handle_tool_error(e)
|
||||
error_to_raise = e
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Got unexpected type of `handle_tool_error`. Expected bool, str "
|
||||
f"or callable. Received: {self.handle_tool_error}"
|
||||
)
|
||||
run_manager.on_tool_end(observation, color="red", name=self.name, **kwargs)
|
||||
return observation
|
||||
content = _handle_tool_error(e, flag=self.handle_tool_error)
|
||||
except (Exception, KeyboardInterrupt) as e:
|
||||
run_manager.on_tool_error(e)
|
||||
raise e
|
||||
else:
|
||||
run_manager.on_tool_end(observation, color=color, name=self.name, **kwargs)
|
||||
return observation
|
||||
error_to_raise = e
|
||||
|
||||
if error_to_raise:
|
||||
run_manager.on_tool_error(error_to_raise)
|
||||
raise error_to_raise
|
||||
output = _format_output(content, raw_output, tool_call_id)
|
||||
run_manager.on_tool_end(output, color=color, name=self.name, **kwargs)
|
||||
return output
|
||||
|
||||
async def arun(
|
||||
self,
|
||||
@@ -625,99 +593,80 @@ class ChildTool(BaseTool):
|
||||
run_name: Optional[str] = None,
|
||||
run_id: Optional[uuid.UUID] = None,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
tool_call_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run the tool asynchronously."""
|
||||
if not self.verbose and verbose is not None:
|
||||
verbose_ = verbose
|
||||
else:
|
||||
verbose_ = self.verbose
|
||||
callback_manager = AsyncCallbackManager.configure(
|
||||
callbacks,
|
||||
self.callbacks,
|
||||
verbose_,
|
||||
self.verbose or bool(verbose),
|
||||
tags,
|
||||
self.tags,
|
||||
metadata,
|
||||
self.metadata,
|
||||
)
|
||||
new_arg_supported = signature(self._arun).parameters.get("run_manager")
|
||||
run_manager = await callback_manager.on_tool_start(
|
||||
{"name": self.name, "description": self.description},
|
||||
tool_input if isinstance(tool_input, str) else str(tool_input),
|
||||
color=start_color,
|
||||
name=run_name,
|
||||
inputs=tool_input,
|
||||
run_id=run_id,
|
||||
# Inputs by definition should always be dicts.
|
||||
# For now, it's unclear whether this assumption is ever violated,
|
||||
# but if it is we will send a `None` value to the callback instead
|
||||
# TODO: will need to address issue via a patch.
|
||||
inputs=tool_input if isinstance(tool_input, dict) else None,
|
||||
**kwargs,
|
||||
)
|
||||
content = None
|
||||
raw_output = None
|
||||
error_to_raise: Optional[Union[Exception, KeyboardInterrupt]] = None
|
||||
try:
|
||||
parsed_input = self._parse_input(tool_input)
|
||||
# We then call the tool on the tool input to get an observation
|
||||
tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input)
|
||||
child_config = patch_config(
|
||||
config,
|
||||
callbacks=run_manager.get_child(),
|
||||
)
|
||||
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
|
||||
child_config = patch_config(config, callbacks=run_manager.get_child())
|
||||
context = copy_context()
|
||||
context.run(_set_config_context, child_config)
|
||||
coro = (
|
||||
context.run(
|
||||
self._arun, *tool_args, run_manager=run_manager, **tool_kwargs
|
||||
)
|
||||
if new_arg_supported
|
||||
else context.run(self._arun, *tool_args, **tool_kwargs)
|
||||
)
|
||||
if self.__class__._arun is BaseTool._arun or signature(
|
||||
self._arun
|
||||
).parameters.get("run_manager"):
|
||||
tool_kwargs["run_manager"] = run_manager
|
||||
coro = context.run(self._arun, *tool_args, **tool_kwargs)
|
||||
if accepts_context(asyncio.create_task):
|
||||
observation = await asyncio.create_task(coro, context=context) # type: ignore
|
||||
response = await asyncio.create_task(coro, context=context) # type: ignore
|
||||
else:
|
||||
observation = await coro
|
||||
|
||||
response = await coro
|
||||
if self.response_format == "content_and_raw_output":
|
||||
if not isinstance(response, tuple) or len(response) != 2:
|
||||
raise ValueError(
|
||||
"Since response_format='content_and_raw_output' "
|
||||
"a two-tuple of the message content and raw tool output is "
|
||||
f"expected. Instead generated response of type: "
|
||||
f"{type(response)}."
|
||||
)
|
||||
content, raw_output = response
|
||||
else:
|
||||
content = response
|
||||
except ValidationError as e:
|
||||
if not self.handle_validation_error:
|
||||
raise e
|
||||
elif isinstance(self.handle_validation_error, bool):
|
||||
observation = "Tool input validation error"
|
||||
elif isinstance(self.handle_validation_error, str):
|
||||
observation = self.handle_validation_error
|
||||
elif callable(self.handle_validation_error):
|
||||
observation = self.handle_validation_error(e)
|
||||
error_to_raise = e
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Got unexpected type of `handle_validation_error`. Expected bool, "
|
||||
f"str or callable. Received: {self.handle_validation_error}"
|
||||
)
|
||||
return observation
|
||||
content = _handle_validation_error(e, flag=self.handle_validation_error)
|
||||
except ToolException as e:
|
||||
if not self.handle_tool_error:
|
||||
await run_manager.on_tool_error(e)
|
||||
raise e
|
||||
elif isinstance(self.handle_tool_error, bool):
|
||||
if e.args:
|
||||
observation = e.args[0]
|
||||
else:
|
||||
observation = "Tool execution error"
|
||||
elif isinstance(self.handle_tool_error, str):
|
||||
observation = self.handle_tool_error
|
||||
elif callable(self.handle_tool_error):
|
||||
observation = self.handle_tool_error(e)
|
||||
error_to_raise = e
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Got unexpected type of `handle_tool_error`. Expected bool, str "
|
||||
f"or callable. Received: {self.handle_tool_error}"
|
||||
)
|
||||
await run_manager.on_tool_end(
|
||||
observation, color="red", name=self.name, **kwargs
|
||||
)
|
||||
return observation
|
||||
content = _handle_tool_error(e, flag=self.handle_tool_error)
|
||||
except (Exception, KeyboardInterrupt) as e:
|
||||
await run_manager.on_tool_error(e)
|
||||
raise e
|
||||
else:
|
||||
await run_manager.on_tool_end(
|
||||
observation, color=color, name=self.name, **kwargs
|
||||
)
|
||||
return observation
|
||||
error_to_raise = e
|
||||
|
||||
if error_to_raise:
|
||||
await run_manager.on_tool_error(error_to_raise)
|
||||
raise error_to_raise
|
||||
|
||||
output = _format_output(content, raw_output, tool_call_id)
|
||||
await run_manager.on_tool_end(output, color=color, name=self.name, **kwargs)
|
||||
return output
|
||||
|
||||
@deprecated("0.1.47", alternative="invoke", removal="0.3.0")
|
||||
def __call__(self, tool_input: str, callbacks: Callbacks = None) -> str:
|
||||
@@ -738,7 +687,7 @@ class Tool(BaseTool):
|
||||
|
||||
async def ainvoke(
|
||||
self,
|
||||
input: Union[str, Dict],
|
||||
input: Union[str, Dict, ToolCall],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
@@ -780,17 +729,10 @@ class Tool(BaseTool):
|
||||
) -> Any:
|
||||
"""Use the tool."""
|
||||
if self.func:
|
||||
new_argument_supported = signature(self.func).parameters.get("callbacks")
|
||||
return (
|
||||
self.func(
|
||||
*args,
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
**kwargs,
|
||||
)
|
||||
if new_argument_supported
|
||||
else self.func(*args, **kwargs)
|
||||
)
|
||||
raise NotImplementedError("Tool does not support sync")
|
||||
if run_manager and signature(self.func).parameters.get("callbacks"):
|
||||
kwargs["callbacks"] = run_manager.get_child()
|
||||
return self.func(*args, **kwargs)
|
||||
raise NotImplementedError("Tool does not support sync invocation.")
|
||||
|
||||
async def _arun(
|
||||
self,
|
||||
@@ -800,26 +742,13 @@ class Tool(BaseTool):
|
||||
) -> Any:
|
||||
"""Use the tool asynchronously."""
|
||||
if self.coroutine:
|
||||
new_argument_supported = signature(self.coroutine).parameters.get(
|
||||
"callbacks"
|
||||
)
|
||||
return (
|
||||
await self.coroutine(
|
||||
*args,
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
**kwargs,
|
||||
)
|
||||
if new_argument_supported
|
||||
else await self.coroutine(*args, **kwargs)
|
||||
)
|
||||
else:
|
||||
return await run_in_executor(
|
||||
None,
|
||||
self._run,
|
||||
run_manager=run_manager.get_sync() if run_manager else None,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
if run_manager and signature(self.coroutine).parameters.get("callbacks"):
|
||||
kwargs["callbacks"] = run_manager.get_child()
|
||||
return await self.coroutine(*args, **kwargs)
|
||||
|
||||
# NOTE: this code is unreachable since _arun is only called if coroutine is not
|
||||
# None.
|
||||
return await super()._arun(*args, run_manager=run_manager, **kwargs)
|
||||
|
||||
# TODO: this is for backwards compatibility, remove in future
|
||||
def __init__(
|
||||
@@ -870,9 +799,10 @@ class StructuredTool(BaseTool):
|
||||
|
||||
# --- Runnable ---
|
||||
|
||||
# TODO: Is this needed?
|
||||
async def ainvoke(
|
||||
self,
|
||||
input: Union[str, Dict],
|
||||
input: Union[str, Dict, ToolCall],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
@@ -897,45 +827,26 @@ class StructuredTool(BaseTool):
|
||||
) -> Any:
|
||||
"""Use the tool."""
|
||||
if self.func:
|
||||
new_argument_supported = signature(self.func).parameters.get("callbacks")
|
||||
return (
|
||||
self.func(
|
||||
*args,
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
**kwargs,
|
||||
)
|
||||
if new_argument_supported
|
||||
else self.func(*args, **kwargs)
|
||||
)
|
||||
raise NotImplementedError("Tool does not support sync")
|
||||
if run_manager and signature(self.func).parameters.get("callbacks"):
|
||||
kwargs["callbacks"] = run_manager.get_child()
|
||||
return self.func(*args, **kwargs)
|
||||
raise NotImplementedError("StructuredTool does not support sync invocation.")
|
||||
|
||||
async def _arun(
|
||||
self,
|
||||
*args: Any,
|
||||
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
) -> Any:
|
||||
"""Use the tool asynchronously."""
|
||||
if self.coroutine:
|
||||
new_argument_supported = signature(self.coroutine).parameters.get(
|
||||
"callbacks"
|
||||
)
|
||||
return (
|
||||
await self.coroutine(
|
||||
*args,
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
**kwargs,
|
||||
)
|
||||
if new_argument_supported
|
||||
else await self.coroutine(*args, **kwargs)
|
||||
)
|
||||
return await run_in_executor(
|
||||
None,
|
||||
self._run,
|
||||
run_manager=run_manager.get_sync() if run_manager else None,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
if run_manager and signature(self.coroutine).parameters.get("callbacks"):
|
||||
kwargs["callbacks"] = run_manager.get_child()
|
||||
return await self.coroutine(*args, **kwargs)
|
||||
|
||||
# NOTE: this code is unreachable since _arun is only called if coroutine is not
|
||||
# None.
|
||||
return await super()._arun(*args, run_manager=run_manager, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_function(
|
||||
@@ -947,6 +858,8 @@ class StructuredTool(BaseTool):
|
||||
return_direct: bool = False,
|
||||
args_schema: Optional[Type[BaseModel]] = None,
|
||||
infer_schema: bool = True,
|
||||
*,
|
||||
response_format: Literal["content", "content_and_raw_output"] = "content",
|
||||
parse_docstring: bool = False,
|
||||
error_on_invalid_docstring: bool = False,
|
||||
**kwargs: Any,
|
||||
@@ -963,6 +876,10 @@ class StructuredTool(BaseTool):
|
||||
return_direct: Whether to return the result directly or as a callback
|
||||
args_schema: The schema of the tool's input arguments
|
||||
infer_schema: Whether to infer the schema from the function's signature
|
||||
response_format: The tool response format. If "content" then the output of
|
||||
the tool is interpreted as the contents of a ToolMessage. If
|
||||
"content_and_raw_output" then the output is expected to be a two-tuple
|
||||
corresponding to the (content, raw_output) of a ToolMessage.
|
||||
parse_docstring: if ``infer_schema`` and ``parse_docstring``, will attempt
|
||||
to parse parameter descriptions from Google Style function docstrings.
|
||||
error_on_invalid_docstring: if ``parse_docstring`` is provided, configures
|
||||
@@ -1020,6 +937,7 @@ class StructuredTool(BaseTool):
|
||||
args_schema=_args_schema, # type: ignore[arg-type]
|
||||
description=description_,
|
||||
return_direct=return_direct,
|
||||
response_format=response_format,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -1029,6 +947,7 @@ def tool(
|
||||
return_direct: bool = False,
|
||||
args_schema: Optional[Type[BaseModel]] = None,
|
||||
infer_schema: bool = True,
|
||||
response_format: Literal["content", "content_and_raw_output"] = "content",
|
||||
parse_docstring: bool = False,
|
||||
error_on_invalid_docstring: bool = True,
|
||||
) -> Callable:
|
||||
@@ -1042,6 +961,10 @@ def tool(
|
||||
infer_schema: Whether to infer the schema of the arguments from
|
||||
the function's signature. This also makes the resultant tool
|
||||
accept a dictionary input to its `run()` function.
|
||||
response_format: The tool response format. If "content" then the output of
|
||||
the tool is interpreted as the contents of a ToolMessage. If
|
||||
"content_and_raw_output" then the output is expected to be a two-tuple
|
||||
corresponding to the (content, raw_output) of a ToolMessage.
|
||||
parse_docstring: if ``infer_schema`` and ``parse_docstring``, will attempt to
|
||||
parse parameter descriptions from Google Style function docstrings.
|
||||
error_on_invalid_docstring: if ``parse_docstring`` is provided, configures
|
||||
@@ -1064,8 +987,12 @@ def tool(
|
||||
# Searches the API for the query.
|
||||
return
|
||||
|
||||
.. versionadded:: 0.2.14
|
||||
Parse Google-style docstrings:
|
||||
@tool(response_format="content_and_raw_output")
|
||||
def search_api(query: str) -> Tuple[str, dict]:
|
||||
return "partial json of results", {"full": "object of results"}
|
||||
|
||||
.. versionadded:: 0.2.14
|
||||
Parse Google-style docstrings:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@@ -1179,6 +1106,7 @@ def tool(
|
||||
return_direct=return_direct,
|
||||
args_schema=schema,
|
||||
infer_schema=infer_schema,
|
||||
response_format=response_format,
|
||||
parse_docstring=parse_docstring,
|
||||
error_on_invalid_docstring=error_on_invalid_docstring,
|
||||
)
|
||||
@@ -1195,6 +1123,7 @@ def tool(
|
||||
description=f"{tool_name} tool",
|
||||
return_direct=return_direct,
|
||||
coroutine=coroutine,
|
||||
response_format=response_format,
|
||||
)
|
||||
|
||||
return _make_tool
|
||||
@@ -1350,6 +1279,103 @@ class BaseToolkit(BaseModel, ABC):
|
||||
"""Get the tools in the toolkit."""
|
||||
|
||||
|
||||
def _is_tool_call(x: Any) -> bool:
|
||||
return isinstance(x, dict) and x.get("type") == "tool_call"
|
||||
|
||||
|
||||
def _handle_validation_error(
|
||||
e: ValidationError,
|
||||
*,
|
||||
flag: Union[Literal[True], str, Callable[[ValidationError], str]],
|
||||
) -> str:
|
||||
if isinstance(flag, bool):
|
||||
content = "Tool input validation error"
|
||||
elif isinstance(flag, str):
|
||||
content = flag
|
||||
elif callable(flag):
|
||||
content = flag(e)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Got unexpected type of `handle_validation_error`. Expected bool, "
|
||||
f"str or callable. Received: {flag}"
|
||||
)
|
||||
return content
|
||||
|
||||
|
||||
def _handle_tool_error(
|
||||
e: ToolException,
|
||||
*,
|
||||
flag: Optional[Union[Literal[True], str, Callable[[ToolException], str]]],
|
||||
) -> str:
|
||||
if isinstance(flag, bool):
|
||||
if e.args:
|
||||
content = e.args[0]
|
||||
else:
|
||||
content = "Tool execution error"
|
||||
elif isinstance(flag, str):
|
||||
content = flag
|
||||
elif callable(flag):
|
||||
content = flag(e)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Got unexpected type of `handle_tool_error`. Expected bool, str "
|
||||
f"or callable. Received: {flag}"
|
||||
)
|
||||
return content
|
||||
|
||||
|
||||
def _prep_run_args(
|
||||
input: Union[str, dict, ToolCall],
|
||||
config: Optional[RunnableConfig],
|
||||
**kwargs: Any,
|
||||
) -> Tuple[Union[str, Dict], Dict]:
|
||||
config = ensure_config(config)
|
||||
if _is_tool_call(input):
|
||||
tool_call_id: Optional[str] = cast(ToolCall, input)["id"]
|
||||
tool_input: Union[str, dict] = cast(ToolCall, input)["args"]
|
||||
else:
|
||||
tool_call_id = None
|
||||
tool_input = cast(Union[str, dict], input)
|
||||
return (
|
||||
tool_input,
|
||||
dict(
|
||||
callbacks=config.get("callbacks"),
|
||||
tags=config.get("tags"),
|
||||
metadata=config.get("metadata"),
|
||||
run_name=config.get("run_name"),
|
||||
run_id=config.pop("run_id", None),
|
||||
config=config,
|
||||
tool_call_id=tool_call_id,
|
||||
**kwargs,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _format_output(
|
||||
content: Any, raw_output: Any, tool_call_id: Optional[str]
|
||||
) -> Union[ToolMessage, Any]:
|
||||
if tool_call_id:
|
||||
# NOTE: This will fail to stringify lists which aren't actually content blocks
|
||||
# but whose first element happens to be a string or dict. Tools should avoid
|
||||
# returning such contents.
|
||||
if not isinstance(content, str) and not (
|
||||
isinstance(content, list)
|
||||
and content
|
||||
and isinstance(content[0], (str, dict))
|
||||
):
|
||||
content = _stringify(content)
|
||||
return ToolMessage(content, raw_output=raw_output, tool_call_id=tool_call_id)
|
||||
else:
|
||||
return content
|
||||
|
||||
|
||||
def _stringify(content: Any) -> str:
|
||||
try:
|
||||
return json.dumps(content)
|
||||
except Exception:
|
||||
return str(content)
|
||||
|
||||
|
||||
def _get_description_from_runnable(runnable: Runnable) -> str:
|
||||
"""Generate a placeholder description of a runnable."""
|
||||
input_schema = runnable.input_schema.schema()
|
||||
|
@@ -317,6 +317,13 @@
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'enum': list([
|
||||
'invalid_tool_call',
|
||||
]),
|
||||
'title': 'Type',
|
||||
'type': 'string',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'name',
|
||||
@@ -419,6 +426,13 @@
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'enum': list([
|
||||
'tool_call',
|
||||
]),
|
||||
'title': 'Type',
|
||||
'type': 'string',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'name',
|
||||
@@ -908,6 +922,13 @@
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'enum': list([
|
||||
'invalid_tool_call',
|
||||
]),
|
||||
'title': 'Type',
|
||||
'type': 'string',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'name',
|
||||
@@ -1010,6 +1031,13 @@
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'enum': list([
|
||||
'tool_call',
|
||||
]),
|
||||
'title': 'Type',
|
||||
'type': 'string',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'name',
|
||||
|
@@ -674,6 +674,13 @@
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'enum': list([
|
||||
'invalid_tool_call',
|
||||
]),
|
||||
'title': 'Type',
|
||||
'type': 'string',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'name',
|
||||
@@ -776,6 +783,13 @@
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'enum': list([
|
||||
'tool_call',
|
||||
]),
|
||||
'title': 'Type',
|
||||
'type': 'string',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'name',
|
||||
|
@@ -5577,6 +5577,13 @@
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'enum': list([
|
||||
'invalid_tool_call',
|
||||
]),
|
||||
'title': 'Type',
|
||||
'type': 'string',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'name',
|
||||
@@ -5701,6 +5708,13 @@
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'enum': list([
|
||||
'tool_call',
|
||||
]),
|
||||
'title': 'Type',
|
||||
'type': 'string',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'name',
|
||||
@@ -6237,6 +6251,13 @@
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'enum': list([
|
||||
'invalid_tool_call',
|
||||
]),
|
||||
'title': 'Type',
|
||||
'type': 'string',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'name',
|
||||
@@ -6361,6 +6382,13 @@
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'enum': list([
|
||||
'tool_call',
|
||||
]),
|
||||
'title': 'Type',
|
||||
'type': 'string',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'name',
|
||||
@@ -6834,6 +6862,13 @@
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'enum': list([
|
||||
'invalid_tool_call',
|
||||
]),
|
||||
'title': 'Type',
|
||||
'type': 'string',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'name',
|
||||
@@ -6936,6 +6971,13 @@
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'enum': list([
|
||||
'tool_call',
|
||||
]),
|
||||
'title': 'Type',
|
||||
'type': 'string',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'name',
|
||||
@@ -7444,6 +7486,13 @@
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'enum': list([
|
||||
'invalid_tool_call',
|
||||
]),
|
||||
'title': 'Type',
|
||||
'type': 'string',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'name',
|
||||
@@ -7568,6 +7617,13 @@
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'enum': list([
|
||||
'tool_call',
|
||||
]),
|
||||
'title': 'Type',
|
||||
'type': 'string',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'name',
|
||||
@@ -8068,6 +8124,13 @@
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'enum': list([
|
||||
'invalid_tool_call',
|
||||
]),
|
||||
'title': 'Type',
|
||||
'type': 'string',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'name',
|
||||
@@ -8203,6 +8266,13 @@
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'enum': list([
|
||||
'tool_call',
|
||||
]),
|
||||
'title': 'Type',
|
||||
'type': 'string',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'name',
|
||||
@@ -8683,6 +8753,13 @@
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'enum': list([
|
||||
'invalid_tool_call',
|
||||
]),
|
||||
'title': 'Type',
|
||||
'type': 'string',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'name',
|
||||
@@ -8785,6 +8862,13 @@
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'enum': list([
|
||||
'tool_call',
|
||||
]),
|
||||
'title': 'Type',
|
||||
'type': 'string',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'name',
|
||||
@@ -9238,6 +9322,13 @@
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'enum': list([
|
||||
'invalid_tool_call',
|
||||
]),
|
||||
'title': 'Type',
|
||||
'type': 'string',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'name',
|
||||
@@ -9340,6 +9431,13 @@
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'enum': list([
|
||||
'tool_call',
|
||||
]),
|
||||
'title': 'Type',
|
||||
'type': 'string',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'name',
|
||||
@@ -9880,6 +9978,13 @@
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'enum': list([
|
||||
'invalid_tool_call',
|
||||
]),
|
||||
'title': 'Type',
|
||||
'type': 'string',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'name',
|
||||
@@ -10004,6 +10109,13 @@
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'enum': list([
|
||||
'tool_call',
|
||||
]),
|
||||
'title': 'Type',
|
||||
'type': 'string',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'name',
|
||||
|
@@ -8,7 +8,7 @@ import textwrap
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Dict, List, Optional, Type, Union
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union
|
||||
|
||||
import pytest
|
||||
from typing_extensions import Annotated, TypedDict
|
||||
@@ -17,6 +17,7 @@ from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForToolRun,
|
||||
CallbackManagerForToolRun,
|
||||
)
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.pydantic_v1 import BaseModel, ValidationError
|
||||
from langchain_core.runnables import Runnable, RunnableLambda, ensure_config
|
||||
from langchain_core.tools import (
|
||||
@@ -1067,6 +1068,65 @@ def test_tool_annotated_descriptions() -> None:
|
||||
}
|
||||
|
||||
|
||||
def test_tool_call_input_tool_message_output() -> None:
|
||||
tool_call = {
|
||||
"name": "structured_api",
|
||||
"args": {"arg1": 1, "arg2": True, "arg3": {"img": "base64string..."}},
|
||||
"id": "123",
|
||||
"type": "tool_call",
|
||||
}
|
||||
tool = _MockStructuredTool()
|
||||
expected = ToolMessage("1 True {'img': 'base64string...'}", tool_call_id="123")
|
||||
actual = tool.invoke(tool_call)
|
||||
assert actual == expected
|
||||
|
||||
tool_call.pop("type")
|
||||
with pytest.raises(ValidationError):
|
||||
tool.invoke(tool_call)
|
||||
|
||||
|
||||
class _MockStructuredToolWithRawOutput(BaseTool):
|
||||
name: str = "structured_api"
|
||||
args_schema: Type[BaseModel] = _MockSchema
|
||||
description: str = "A Structured Tool"
|
||||
response_format: Literal["content_and_raw_output"] = "content_and_raw_output"
|
||||
|
||||
def _run(
|
||||
self, arg1: int, arg2: bool, arg3: Optional[dict] = None
|
||||
) -> Tuple[str, dict]:
|
||||
return f"{arg1} {arg2}", {"arg1": arg1, "arg2": arg2, "arg3": arg3}
|
||||
|
||||
|
||||
@tool("structured_api", response_format="content_and_raw_output")
|
||||
def _mock_structured_tool_with_raw_output(
|
||||
arg1: int, arg2: bool, arg3: Optional[dict] = None
|
||||
) -> Tuple[str, dict]:
|
||||
"""A Structured Tool"""
|
||||
return f"{arg1} {arg2}", {"arg1": arg1, "arg2": arg2, "arg3": arg3}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"tool", [_MockStructuredToolWithRawOutput(), _mock_structured_tool_with_raw_output]
|
||||
)
|
||||
def test_tool_call_input_tool_message_with_raw_output(tool: BaseTool) -> None:
|
||||
tool_call: Dict = {
|
||||
"name": "structured_api",
|
||||
"args": {"arg1": 1, "arg2": True, "arg3": {"img": "base64string..."}},
|
||||
"id": "123",
|
||||
"type": "tool_call",
|
||||
}
|
||||
expected = ToolMessage("1 True", raw_output=tool_call["args"], tool_call_id="123")
|
||||
actual = tool.invoke(tool_call)
|
||||
assert actual == expected
|
||||
|
||||
tool_call.pop("type")
|
||||
with pytest.raises(ValidationError):
|
||||
tool.invoke(tool_call)
|
||||
|
||||
actual_content = tool.invoke(tool_call["args"])
|
||||
assert actual_content == expected.content
|
||||
|
||||
|
||||
def test_convert_from_runnable_dict() -> None:
|
||||
# Test with typed dict input
|
||||
class Args(TypedDict):
|
||||
|
Reference in New Issue
Block a user