mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-28 06:48:50 +00:00
catch exceptions properly
This commit is contained in:
@@ -239,15 +239,6 @@ class ToolException(Exception):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def convert_exception_to_tool_exception(e: Exception) -> ToolException:
|
|
||||||
tool_e = ToolException()
|
|
||||||
tool_e.args = e.args
|
|
||||||
tool_e.__cause__ = e.__cause__
|
|
||||||
tool_e.__context__ = e.__context__
|
|
||||||
tool_e.__traceback__ = e.__traceback__
|
|
||||||
return tool_e
|
|
||||||
|
|
||||||
|
|
||||||
class BaseTool(RunnableSerializable[Union[str, Dict, ToolCall], Any]):
|
class BaseTool(RunnableSerializable[Union[str, Dict, ToolCall], Any]):
|
||||||
"""Interface LangChain tools must implement."""
|
"""Interface LangChain tools must implement."""
|
||||||
|
|
||||||
@@ -552,44 +543,54 @@ class ChildTool(BaseTool):
|
|||||||
artifact = None
|
artifact = None
|
||||||
error_to_raise: Union[Exception, KeyboardInterrupt, None] = None
|
error_to_raise: Union[Exception, KeyboardInterrupt, None] = None
|
||||||
try:
|
try:
|
||||||
child_config = patch_config(config, callbacks=run_manager.get_child())
|
try:
|
||||||
context = copy_context()
|
child_config = patch_config(config, callbacks=run_manager.get_child())
|
||||||
context.run(_set_config_context, child_config)
|
context = copy_context()
|
||||||
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
|
context.run(_set_config_context, child_config)
|
||||||
if signature(self._run).parameters.get("run_manager"):
|
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
|
||||||
tool_kwargs["run_manager"] = run_manager
|
if signature(self._run).parameters.get("run_manager"):
|
||||||
|
tool_kwargs["run_manager"] = run_manager
|
||||||
|
|
||||||
if config_param := _get_runnable_config_param(self._run):
|
if config_param := _get_runnable_config_param(self._run):
|
||||||
tool_kwargs[config_param] = config
|
tool_kwargs[config_param] = config
|
||||||
response = context.run(self._run, *tool_args, **tool_kwargs)
|
response = context.run(self._run, *tool_args, **tool_kwargs)
|
||||||
if self.response_format == "content_and_artifact":
|
if self.response_format == "content_and_artifact":
|
||||||
if not isinstance(response, tuple) or len(response) != 2:
|
if not isinstance(response, tuple) or len(response) != 2:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Since response_format='content_and_artifact' "
|
"Since response_format='content_and_artifact' "
|
||||||
"a two-tuple of the message content and raw tool output is "
|
"a two-tuple of the message content and raw tool output is "
|
||||||
f"expected. Instead generated response of type: "
|
f"expected. Instead generated response of type: "
|
||||||
f"{type(response)}."
|
f"{type(response)}."
|
||||||
|
)
|
||||||
|
content, artifact = response
|
||||||
|
else:
|
||||||
|
content = response
|
||||||
|
status = "success"
|
||||||
|
except ValidationError as e:
|
||||||
|
if not self.handle_validation_error:
|
||||||
|
error_to_raise = e
|
||||||
|
else:
|
||||||
|
content = _handle_validation_error(
|
||||||
|
e, flag=self.handle_validation_error
|
||||||
)
|
)
|
||||||
content, artifact = response
|
status = "error"
|
||||||
else:
|
except ToolException as e:
|
||||||
content = response
|
if not self.handle_tool_error:
|
||||||
status = "success"
|
error_to_raise = e
|
||||||
except ValidationError as e:
|
else:
|
||||||
if not self.handle_validation_error:
|
content = _handle_tool_error(e, flag=self.handle_tool_error)
|
||||||
|
status = "error"
|
||||||
|
except Exception as e:
|
||||||
|
raise ToolException(str(e)) from e
|
||||||
|
except KeyboardInterrupt as e:
|
||||||
error_to_raise = e
|
error_to_raise = e
|
||||||
else:
|
status = "error"
|
||||||
content = _handle_validation_error(e, flag=self.handle_validation_error)
|
except ToolException as e:
|
||||||
status = "error"
|
|
||||||
except Exception as e:
|
|
||||||
if not self.handle_tool_error:
|
if not self.handle_tool_error:
|
||||||
error_to_raise = e
|
error_to_raise = e
|
||||||
else:
|
else:
|
||||||
e = convert_exception_to_tool_exception(e)
|
|
||||||
content = _handle_tool_error(e, flag=self.handle_tool_error)
|
content = _handle_tool_error(e, flag=self.handle_tool_error)
|
||||||
status = "error"
|
status = "error"
|
||||||
except KeyboardInterrupt as e:
|
|
||||||
error_to_raise = e
|
|
||||||
status = "error"
|
|
||||||
|
|
||||||
if error_to_raise:
|
if error_to_raise:
|
||||||
run_manager.on_tool_error(error_to_raise)
|
run_manager.on_tool_error(error_to_raise)
|
||||||
@@ -662,51 +663,61 @@ class ChildTool(BaseTool):
|
|||||||
artifact = None
|
artifact = None
|
||||||
error_to_raise: Optional[Union[Exception, KeyboardInterrupt]] = None
|
error_to_raise: Optional[Union[Exception, KeyboardInterrupt]] = None
|
||||||
try:
|
try:
|
||||||
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
|
try:
|
||||||
child_config = patch_config(config, callbacks=run_manager.get_child())
|
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
|
||||||
context = copy_context()
|
child_config = patch_config(config, callbacks=run_manager.get_child())
|
||||||
context.run(_set_config_context, child_config)
|
context = copy_context()
|
||||||
func_to_check = (
|
context.run(_set_config_context, child_config)
|
||||||
self._run if self.__class__._arun is BaseTool._arun else self._arun
|
func_to_check = (
|
||||||
)
|
self._run if self.__class__._arun is BaseTool._arun else self._arun
|
||||||
if signature(func_to_check).parameters.get("run_manager"):
|
)
|
||||||
tool_kwargs["run_manager"] = run_manager
|
if signature(func_to_check).parameters.get("run_manager"):
|
||||||
if config_param := _get_runnable_config_param(func_to_check):
|
tool_kwargs["run_manager"] = run_manager
|
||||||
tool_kwargs[config_param] = config
|
if config_param := _get_runnable_config_param(func_to_check):
|
||||||
|
tool_kwargs[config_param] = config
|
||||||
|
|
||||||
coro = context.run(self._arun, *tool_args, **tool_kwargs)
|
coro = context.run(self._arun, *tool_args, **tool_kwargs)
|
||||||
if asyncio_accepts_context():
|
if asyncio_accepts_context():
|
||||||
response = await asyncio.create_task(coro, context=context) # type: ignore
|
response = await asyncio.create_task(coro, context=context) # type: ignore
|
||||||
else:
|
else:
|
||||||
response = await coro
|
response = await coro
|
||||||
if self.response_format == "content_and_artifact":
|
if self.response_format == "content_and_artifact":
|
||||||
if not isinstance(response, tuple) or len(response) != 2:
|
if not isinstance(response, tuple) or len(response) != 2:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Since response_format='content_and_artifact' "
|
"Since response_format='content_and_artifact' "
|
||||||
"a two-tuple of the message content and raw tool output is "
|
"a two-tuple of the message content and raw tool output is "
|
||||||
f"expected. Instead generated response of type: "
|
f"expected. Instead generated response of type: "
|
||||||
f"{type(response)}."
|
f"{type(response)}."
|
||||||
|
)
|
||||||
|
content, artifact = response
|
||||||
|
else:
|
||||||
|
content = response
|
||||||
|
status = "success"
|
||||||
|
except ValidationError as e:
|
||||||
|
if not self.handle_validation_error:
|
||||||
|
error_to_raise = e
|
||||||
|
else:
|
||||||
|
content = _handle_validation_error(
|
||||||
|
e, flag=self.handle_validation_error
|
||||||
)
|
)
|
||||||
content, artifact = response
|
status = "error"
|
||||||
else:
|
except ToolException as e:
|
||||||
content = response
|
if not self.handle_tool_error:
|
||||||
status = "success"
|
error_to_raise = e
|
||||||
except ValidationError as e:
|
else:
|
||||||
if not self.handle_validation_error:
|
content = _handle_tool_error(e, flag=self.handle_tool_error)
|
||||||
|
status = "error"
|
||||||
|
except Exception as e:
|
||||||
|
raise ToolException(str(e)) from e
|
||||||
|
except KeyboardInterrupt as e:
|
||||||
error_to_raise = e
|
error_to_raise = e
|
||||||
else:
|
status = "error"
|
||||||
content = _handle_validation_error(e, flag=self.handle_validation_error)
|
except ToolException as e:
|
||||||
status = "error"
|
|
||||||
except Exception as e:
|
|
||||||
if not self.handle_tool_error:
|
if not self.handle_tool_error:
|
||||||
error_to_raise = e
|
error_to_raise = e
|
||||||
else:
|
else:
|
||||||
e = convert_exception_to_tool_exception(e)
|
|
||||||
content = _handle_tool_error(e, flag=self.handle_tool_error)
|
content = _handle_tool_error(e, flag=self.handle_tool_error)
|
||||||
status = "error"
|
status = "error"
|
||||||
except KeyboardInterrupt as e:
|
|
||||||
error_to_raise = e
|
|
||||||
status = "error"
|
|
||||||
|
|
||||||
if error_to_raise:
|
if error_to_raise:
|
||||||
await run_manager.on_tool_error(error_to_raise)
|
await run_manager.on_tool_error(error_to_raise)
|
||||||
|
|||||||
@@ -679,7 +679,7 @@ def test_exception_handling_callable() -> None:
|
|||||||
|
|
||||||
def test_exception_handling_non_tool_exception() -> None:
|
def test_exception_handling_non_tool_exception() -> None:
|
||||||
_tool = _FakeExceptionTool(exception=ValueError())
|
_tool = _FakeExceptionTool(exception=ValueError())
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ToolException):
|
||||||
_tool.run({})
|
_tool.run({})
|
||||||
|
|
||||||
|
|
||||||
@@ -710,7 +710,7 @@ async def test_async_exception_handling_callable() -> None:
|
|||||||
|
|
||||||
async def test_async_exception_handling_non_tool_exception() -> None:
|
async def test_async_exception_handling_non_tool_exception() -> None:
|
||||||
_tool = _FakeExceptionTool(exception=ValueError())
|
_tool = _FakeExceptionTool(exception=ValueError())
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ToolException):
|
||||||
await _tool.arun({})
|
await _tool.arun({})
|
||||||
|
|
||||||
|
|
||||||
@@ -806,7 +806,7 @@ def test_validation_error_handling_non_validation_error(
|
|||||||
return "dummy"
|
return "dummy"
|
||||||
|
|
||||||
_tool = _RaiseNonValidationErrorTool(handle_validation_error=handler) # type: ignore[call-arg]
|
_tool = _RaiseNonValidationErrorTool(handle_validation_error=handler) # type: ignore[call-arg]
|
||||||
with pytest.raises(NotImplementedError):
|
with pytest.raises(ToolException):
|
||||||
_tool.run({})
|
_tool.run({})
|
||||||
|
|
||||||
|
|
||||||
@@ -868,7 +868,7 @@ async def test_async_validation_error_handling_non_validation_error(
|
|||||||
return "dummy"
|
return "dummy"
|
||||||
|
|
||||||
_tool = _RaiseNonValidationErrorTool(handle_validation_error=handler) # type: ignore[call-arg]
|
_tool = _RaiseNonValidationErrorTool(handle_validation_error=handler) # type: ignore[call-arg]
|
||||||
with pytest.raises(NotImplementedError):
|
with pytest.raises(ToolException):
|
||||||
await _tool.arun({})
|
await _tool.arun({})
|
||||||
|
|
||||||
|
|
||||||
@@ -1408,10 +1408,8 @@ def test_tool_injected_arg_without_schema(tool_: BaseTool) -> None:
|
|||||||
assert tool_.invoke(
|
assert tool_.invoke(
|
||||||
{"name": "foo", "args": {"x": 5, "y": "bar"}, "id": "123", "type": "tool_call"}
|
{"name": "foo", "args": {"x": 5, "y": "bar"}, "id": "123", "type": "tool_call"}
|
||||||
) == ToolMessage("bar", tool_call_id="123", name="foo")
|
) == ToolMessage("bar", tool_call_id="123", name="foo")
|
||||||
expected_error = (
|
|
||||||
ValidationError if not isinstance(tool_, InjectedTool) else TypeError
|
with pytest.raises(ToolException):
|
||||||
)
|
|
||||||
with pytest.raises(expected_error):
|
|
||||||
tool_.invoke({"x": 5})
|
tool_.invoke({"x": 5})
|
||||||
|
|
||||||
assert convert_to_openai_function(tool_) == {
|
assert convert_to_openai_function(tool_) == {
|
||||||
@@ -1873,7 +1871,6 @@ def test__get_all_basemodel_annotations_v1() -> None:
|
|||||||
def test_tool_args_schema_pydantic_v2_with_metadata() -> None:
|
def test_tool_args_schema_pydantic_v2_with_metadata() -> None:
|
||||||
from pydantic import BaseModel as BaseModelV2 # pydantic: ignore
|
from pydantic import BaseModel as BaseModelV2 # pydantic: ignore
|
||||||
from pydantic import Field as FieldV2 # pydantic: ignore
|
from pydantic import Field as FieldV2 # pydantic: ignore
|
||||||
from pydantic import ValidationError as ValidationErrorV2 # pydantic: ignore
|
|
||||||
|
|
||||||
class Foo(BaseModelV2):
|
class Foo(BaseModelV2):
|
||||||
x: List[int] = FieldV2(
|
x: List[int] = FieldV2(
|
||||||
@@ -1903,7 +1900,7 @@ def test_tool_args_schema_pydantic_v2_with_metadata() -> None:
|
|||||||
}
|
}
|
||||||
|
|
||||||
assert foo.invoke({"x": [0] * 10})
|
assert foo.invoke({"x": [0] * 10})
|
||||||
with pytest.raises(ValidationErrorV2):
|
with pytest.raises(ToolException):
|
||||||
foo.invoke({"x": [0] * 9})
|
foo.invoke({"x": [0] * 9})
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user