mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-04 08:10:25 +00:00
Compare commits
3 Commits
langchain-
...
isaac/tool
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
93580fb5ec | ||
|
|
f9cd109280 | ||
|
|
765350b630 |
@@ -90,26 +90,6 @@ def _get_annotation_description(arg_type: Type) -> str | None:
|
||||
return annotation
|
||||
return None
|
||||
|
||||
|
||||
def _get_filtered_args(
|
||||
inferred_model: Type[BaseModel],
|
||||
func: Callable,
|
||||
*,
|
||||
filter_args: Sequence[str],
|
||||
include_injected: bool = True,
|
||||
) -> dict:
|
||||
"""Get the arguments from a function's signature."""
|
||||
schema = inferred_model.model_json_schema()["properties"]
|
||||
valid_keys = signature(func).parameters
|
||||
return {
|
||||
k: schema[k]
|
||||
for i, (k, param) in enumerate(valid_keys.items())
|
||||
if k not in filter_args
|
||||
and (i > 0 or param.name not in ("self", "cls"))
|
||||
and (include_injected or not _is_injected_arg_type(param.annotation))
|
||||
}
|
||||
|
||||
|
||||
def _parse_python_function_docstring(
|
||||
function: Callable, annotations: dict, error_on_invalid_docstring: bool = False
|
||||
) -> Tuple[str, dict]:
|
||||
@@ -649,43 +629,50 @@ class ChildTool(BaseTool):
|
||||
artifact = None
|
||||
error_to_raise: Union[Exception, KeyboardInterrupt, None] = None
|
||||
try:
|
||||
child_config = patch_config(config, callbacks=run_manager.get_child())
|
||||
context = copy_context()
|
||||
context.run(_set_config_context, child_config)
|
||||
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
|
||||
if signature(self._run).parameters.get("run_manager"):
|
||||
tool_kwargs["run_manager"] = run_manager
|
||||
try:
|
||||
child_config = patch_config(config, callbacks=run_manager.get_child())
|
||||
context = copy_context()
|
||||
context.run(_set_config_context, child_config)
|
||||
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
|
||||
if signature(self._run).parameters.get("run_manager"):
|
||||
tool_kwargs["run_manager"] = run_manager
|
||||
|
||||
if config_param := _get_runnable_config_param(self._run):
|
||||
tool_kwargs[config_param] = config
|
||||
response = context.run(self._run, *tool_args, **tool_kwargs)
|
||||
if self.response_format == "content_and_artifact":
|
||||
if not isinstance(response, tuple) or len(response) != 2:
|
||||
raise ValueError(
|
||||
"Since response_format='content_and_artifact' "
|
||||
"a two-tuple of the message content and raw tool output is "
|
||||
f"expected. Instead generated response of type: "
|
||||
f"{type(response)}."
|
||||
if config_param := _get_runnable_config_param(self._run):
|
||||
tool_kwargs[config_param] = config
|
||||
response = context.run(self._run, *tool_args, **tool_kwargs)
|
||||
if self.response_format == "content_and_artifact":
|
||||
if not isinstance(response, tuple) or len(response) != 2:
|
||||
raise ValueError(
|
||||
"Since response_format='content_and_artifact' "
|
||||
"a two-tuple of the message content and raw tool output is "
|
||||
f"expected. Instead generated response of type: "
|
||||
f"{type(response)}."
|
||||
)
|
||||
content, artifact = response
|
||||
else:
|
||||
content = response
|
||||
status = "success"
|
||||
except ValidationError as e:
|
||||
if not self.handle_validation_error:
|
||||
error_to_raise = e
|
||||
else:
|
||||
content = _handle_validation_error(
|
||||
e, flag=self.handle_validation_error
|
||||
)
|
||||
content, artifact = response
|
||||
else:
|
||||
content = response
|
||||
status = "success"
|
||||
except ValidationError as e:
|
||||
if not self.handle_validation_error:
|
||||
status = "error"
|
||||
except ToolException as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
raise ToolException(repr(e)) from e
|
||||
except KeyboardInterrupt as e:
|
||||
error_to_raise = e
|
||||
else:
|
||||
content = _handle_validation_error(e, flag=self.handle_validation_error)
|
||||
status = "error"
|
||||
status = "error"
|
||||
except ToolException as e:
|
||||
if not self.handle_tool_error:
|
||||
error_to_raise = e
|
||||
else:
|
||||
content = _handle_tool_error(e, flag=self.handle_tool_error)
|
||||
status = "error"
|
||||
except (Exception, KeyboardInterrupt) as e:
|
||||
error_to_raise = e
|
||||
status = "error"
|
||||
|
||||
if error_to_raise:
|
||||
run_manager.on_tool_error(error_to_raise)
|
||||
@@ -758,50 +745,54 @@ class ChildTool(BaseTool):
|
||||
artifact = None
|
||||
error_to_raise: Optional[Union[Exception, KeyboardInterrupt]] = None
|
||||
try:
|
||||
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
|
||||
child_config = patch_config(config, callbacks=run_manager.get_child())
|
||||
context = copy_context()
|
||||
context.run(_set_config_context, child_config)
|
||||
func_to_check = (
|
||||
self._run if self.__class__._arun is BaseTool._arun else self._arun
|
||||
)
|
||||
if signature(func_to_check).parameters.get("run_manager"):
|
||||
tool_kwargs["run_manager"] = run_manager
|
||||
if config_param := _get_runnable_config_param(func_to_check):
|
||||
tool_kwargs[config_param] = config
|
||||
try:
|
||||
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
|
||||
child_config = patch_config(config, callbacks=run_manager.get_child())
|
||||
context = copy_context()
|
||||
context.run(_set_config_context, child_config)
|
||||
func_to_check = (
|
||||
self._run if self.__class__._arun is BaseTool._arun else self._arun
|
||||
)
|
||||
if signature(func_to_check).parameters.get("run_manager"):
|
||||
tool_kwargs["run_manager"] = run_manager
|
||||
if config_param := _get_runnable_config_param(func_to_check):
|
||||
tool_kwargs[config_param] = config
|
||||
|
||||
coro = context.run(self._arun, *tool_args, **tool_kwargs)
|
||||
if asyncio_accepts_context():
|
||||
response = await asyncio.create_task(coro, context=context) # type: ignore
|
||||
else:
|
||||
response = await coro
|
||||
if self.response_format == "content_and_artifact":
|
||||
if not isinstance(response, tuple) or len(response) != 2:
|
||||
raise ValueError(
|
||||
"Since response_format='content_and_artifact' "
|
||||
"a two-tuple of the message content and raw tool output is "
|
||||
f"expected. Instead generated response of type: "
|
||||
f"{type(response)}."
|
||||
coro = context.run(self._arun, *tool_args, **tool_kwargs)
|
||||
if asyncio_accepts_context():
|
||||
response = await asyncio.create_task(coro, context=context) # type: ignore
|
||||
else:
|
||||
response = await coro
|
||||
if self.response_format == "content_and_artifact":
|
||||
if not isinstance(response, tuple) or len(response) != 2:
|
||||
raise ValueError(
|
||||
"Since response_format='content_and_artifact' "
|
||||
"a two-tuple of the message content and raw tool output is "
|
||||
f"expected. Instead generated response of type: "
|
||||
f"{type(response)}."
|
||||
)
|
||||
content, artifact = response
|
||||
else:
|
||||
content = response
|
||||
status = "success"
|
||||
except ValidationError as e:
|
||||
if not self.handle_validation_error:
|
||||
error_to_raise = e
|
||||
else:
|
||||
content = _handle_validation_error(
|
||||
e, flag=self.handle_validation_error
|
||||
)
|
||||
content, artifact = response
|
||||
else:
|
||||
content = response
|
||||
status = "success"
|
||||
except ValidationError as e:
|
||||
if not self.handle_validation_error:
|
||||
error_to_raise = e
|
||||
else:
|
||||
content = _handle_validation_error(e, flag=self.handle_validation_error)
|
||||
status = "error"
|
||||
status = "error"
|
||||
except ToolException as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
raise ToolException(repr(e)) from e
|
||||
except ToolException as e:
|
||||
if not self.handle_tool_error:
|
||||
error_to_raise = e
|
||||
else:
|
||||
content = _handle_tool_error(e, flag=self.handle_tool_error)
|
||||
status = "error"
|
||||
except (Exception, KeyboardInterrupt) as e:
|
||||
error_to_raise = e
|
||||
status = "error"
|
||||
|
||||
if error_to_raise:
|
||||
await run_manager.on_tool_error(error_to_raise)
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, Literal, Optional, Type, Union, get_type_hints
|
||||
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
from pydantic import BaseModel, Field, ValidationError, create_model
|
||||
|
||||
from langchain_core.callbacks import Callbacks
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.tools.base import BaseTool
|
||||
from langchain_core.tools.base import BaseTool, ToolException
|
||||
from langchain_core.tools.simple import Tool
|
||||
from langchain_core.tools.structured import StructuredTool
|
||||
|
||||
@@ -18,6 +18,12 @@ def tool(
|
||||
response_format: Literal["content", "content_and_artifact"] = "content",
|
||||
parse_docstring: bool = False,
|
||||
error_on_invalid_docstring: bool = True,
|
||||
handle_tool_error: Optional[
|
||||
Union[bool, str, Callable[[ToolException], str]]
|
||||
] = False,
|
||||
handle_validation_error: Optional[
|
||||
Union[bool, str, Callable[[ValidationError], str]]
|
||||
] = False,
|
||||
) -> Callable:
|
||||
"""Make tools out of functions, can be used with or without arguments.
|
||||
|
||||
@@ -42,6 +48,14 @@ def tool(
|
||||
error_on_invalid_docstring: if ``parse_docstring`` is provided, configure
|
||||
whether to raise ValueError on invalid Google Style docstrings.
|
||||
Defaults to True.
|
||||
handle_tool_error: Handle the content of the ToolException thrown. If False
|
||||
do nothing, If True returns 'Tool execution error'. If string then
|
||||
returns that string directly, if callable converts error to string to be returned.
|
||||
Defaults to False.
|
||||
handle_validation_error: Handle the content of the ValidationError thrown. If False
|
||||
do nothing, If True returns 'Tool input validation error'. If string then
|
||||
returns that string directly, if callable converts error to string to be returned.
|
||||
Defaults to False.
|
||||
|
||||
Returns:
|
||||
The tool.
|
||||
@@ -138,7 +152,7 @@ def tool(
|
||||
monkey: The baz.
|
||||
\"\"\"
|
||||
return bar
|
||||
"""
|
||||
""" # noqa: E501
|
||||
|
||||
def _make_with_name(tool_name: str) -> Callable:
|
||||
def _make_tool(dec_func: Union[Callable, Runnable]) -> BaseTool:
|
||||
@@ -185,6 +199,8 @@ def tool(
|
||||
response_format=response_format,
|
||||
parse_docstring=parse_docstring,
|
||||
error_on_invalid_docstring=error_on_invalid_docstring,
|
||||
handle_tool_error=handle_tool_error,
|
||||
handle_validation_error=handle_validation_error,
|
||||
)
|
||||
# If someone doesn't want a schema applied, we must treat it as
|
||||
# a simple string->string function
|
||||
@@ -200,6 +216,8 @@ def tool(
|
||||
return_direct=return_direct,
|
||||
coroutine=coroutine,
|
||||
response_format=response_format,
|
||||
handle_tool_error=handle_tool_error,
|
||||
handle_validation_error=handle_validation_error,
|
||||
)
|
||||
|
||||
return _make_tool
|
||||
|
||||
@@ -168,6 +168,54 @@ def test_subclass_annotated_base_tool_accepted() -> None:
|
||||
assert tool.args_schema == _MockSchema
|
||||
|
||||
|
||||
def test_decorator_with_error_handling() -> None:
|
||||
"""Test that error handling works when passed through decorator."""
|
||||
|
||||
@tool()
|
||||
def tool_func(arg1: int, arg2: int) -> float:
|
||||
"""foo bar tool"""
|
||||
return arg1 / arg2
|
||||
|
||||
with pytest.raises(ToolException):
|
||||
tool_func.invoke({"arg1": 1, "arg2": 0})
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
tool_func.invoke({"arg1": 1})
|
||||
|
||||
@tool(handle_tool_error="foo", handle_validation_error="bar")
|
||||
def tool_func_2(arg1: int, arg2: int) -> float:
|
||||
"""foo bar tool"""
|
||||
return arg1 / arg2
|
||||
|
||||
tool_exception = tool_func_2.invoke({"arg1": 1, "arg2": 0})
|
||||
assert tool_exception == "foo"
|
||||
|
||||
validation_error = tool_func_2.invoke({"arg1": 1})
|
||||
assert validation_error == "bar"
|
||||
|
||||
@tool(handle_tool_error=lambda e: "foo", handle_validation_error=lambda e: "bar")
|
||||
def tool_func_3(arg1: int, arg2: int) -> float:
|
||||
"""foo bar tool"""
|
||||
return arg1 / arg2
|
||||
|
||||
tool_exception = tool_func_3.invoke({"arg1": 1, "arg2": 0})
|
||||
assert tool_exception == "foo"
|
||||
|
||||
validation_error = tool_func_3.invoke({"arg1": 1})
|
||||
assert validation_error == "bar"
|
||||
|
||||
@tool(handle_tool_error=True, handle_validation_error=True)
|
||||
def tool_func_4(arg1: int, arg2: int) -> float:
|
||||
"""foo bar tool"""
|
||||
return arg1 / arg2
|
||||
|
||||
tool_exception = tool_func_4.invoke({"arg1": 1, "arg2": 0})
|
||||
assert tool_exception == "division by zero"
|
||||
|
||||
validation_error = tool_func_4.invoke({"arg1": 1})
|
||||
assert validation_error == "Tool input validation error"
|
||||
|
||||
|
||||
def test_decorator_with_specified_schema() -> None:
|
||||
"""Test that manually specified schemata are passed through to the tool."""
|
||||
|
||||
@@ -740,7 +788,7 @@ def test_exception_handling_callable() -> None:
|
||||
|
||||
def test_exception_handling_non_tool_exception() -> None:
|
||||
_tool = _FakeExceptionTool(exception=ValueError())
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(ToolException):
|
||||
_tool.run({})
|
||||
|
||||
|
||||
@@ -771,7 +819,7 @@ async def test_async_exception_handling_callable() -> None:
|
||||
|
||||
async def test_async_exception_handling_non_tool_exception() -> None:
|
||||
_tool = _FakeExceptionTool(exception=ValueError())
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(ToolException):
|
||||
await _tool.arun({})
|
||||
|
||||
|
||||
@@ -867,7 +915,7 @@ def test_validation_error_handling_non_validation_error(
|
||||
return "dummy"
|
||||
|
||||
_tool = _RaiseNonValidationErrorTool(handle_validation_error=handler) # type: ignore[call-arg]
|
||||
with pytest.raises(NotImplementedError):
|
||||
with pytest.raises(ToolException):
|
||||
_tool.run({})
|
||||
|
||||
|
||||
@@ -929,7 +977,7 @@ async def test_async_validation_error_handling_non_validation_error(
|
||||
return "dummy"
|
||||
|
||||
_tool = _RaiseNonValidationErrorTool(handle_validation_error=handler) # type: ignore[call-arg]
|
||||
with pytest.raises(NotImplementedError):
|
||||
with pytest.raises(ToolException):
|
||||
await _tool.arun({})
|
||||
|
||||
|
||||
@@ -1470,10 +1518,7 @@ def test_tool_injected_arg_without_schema(tool_: BaseTool) -> None:
|
||||
assert tool_.invoke(
|
||||
{"name": "foo", "args": {"x": 5, "y": "bar"}, "id": "123", "type": "tool_call"}
|
||||
) == ToolMessage("bar", tool_call_id="123", name="foo")
|
||||
expected_error = (
|
||||
ValidationError if not isinstance(tool_, InjectedTool) else TypeError
|
||||
)
|
||||
with pytest.raises(expected_error):
|
||||
with pytest.raises(ToolException):
|
||||
tool_.invoke({"x": 5})
|
||||
|
||||
assert convert_to_openai_function(tool_) == {
|
||||
@@ -2092,5 +2137,5 @@ def test_structured_tool_direct_init() -> None:
|
||||
|
||||
tool = StructuredTool(name="foo", args_schema=fooSchema, coroutine=asyncFoo)
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
with pytest.raises(ToolException):
|
||||
assert tool.invoke("hello") == "hello"
|
||||
|
||||
Reference in New Issue
Block a user