mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-02 19:34:04 +00:00
parent
5d0bea8378
commit
2842e0c8c1
@ -1,4 +1,4 @@
|
||||
"""Base for Tools."""
|
||||
"""Base classes and utilities for LangChain tools."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@ -77,14 +77,30 @@ FILTERED_ARGS = ("run_manager", "callbacks")
|
||||
|
||||
|
||||
class SchemaAnnotationError(TypeError):
|
||||
"""Raised when 'args_schema' is missing or has an incorrect type annotation."""
|
||||
"""Raised when args_schema is missing or has an incorrect type annotation."""
|
||||
|
||||
|
||||
def _is_annotated_type(typ: type[Any]) -> bool:
|
||||
"""Check if a type is an Annotated type.
|
||||
|
||||
Args:
|
||||
typ: The type to check.
|
||||
|
||||
Returns:
|
||||
True if the type is an Annotated type, False otherwise.
|
||||
"""
|
||||
return get_origin(typ) is typing.Annotated
|
||||
|
||||
|
||||
def _get_annotation_description(arg_type: type) -> str | None:
|
||||
"""Extract description from an Annotated type.
|
||||
|
||||
Args:
|
||||
arg_type: The type to extract description from.
|
||||
|
||||
Returns:
|
||||
The description string if found, None otherwise.
|
||||
"""
|
||||
if _is_annotated_type(arg_type):
|
||||
annotated_args = get_args(arg_type)
|
||||
for annotation in annotated_args[1:]:
|
||||
@ -100,7 +116,17 @@ def _get_filtered_args(
|
||||
filter_args: Sequence[str],
|
||||
include_injected: bool = True,
|
||||
) -> dict:
|
||||
"""Get the arguments from a function's signature."""
|
||||
"""Get filtered arguments from a function's signature.
|
||||
|
||||
Args:
|
||||
inferred_model: The Pydantic model inferred from the function.
|
||||
func: The function to extract arguments from.
|
||||
filter_args: Arguments to exclude from the result.
|
||||
include_injected: Whether to include injected arguments.
|
||||
|
||||
Returns:
|
||||
Dictionary of filtered arguments with their schema definitions.
|
||||
"""
|
||||
schema = inferred_model.model_json_schema()["properties"]
|
||||
valid_keys = signature(func).parameters
|
||||
return {
|
||||
@ -115,9 +141,17 @@ def _get_filtered_args(
|
||||
def _parse_python_function_docstring(
|
||||
function: Callable, annotations: dict, *, error_on_invalid_docstring: bool = False
|
||||
) -> tuple[str, dict]:
|
||||
"""Parse the function and argument descriptions from the docstring of a function.
|
||||
"""Parse function and argument descriptions from a docstring.
|
||||
|
||||
Assumes the function docstring follows Google Python style guide.
|
||||
|
||||
Args:
|
||||
function: The function to parse the docstring from.
|
||||
annotations: Type annotations for the function parameters.
|
||||
error_on_invalid_docstring: Whether to raise an error on invalid docstring.
|
||||
|
||||
Returns:
|
||||
A tuple containing the function description and argument descriptions.
|
||||
"""
|
||||
docstring = inspect.getdoc(function)
|
||||
return _parse_google_docstring(
|
||||
@ -130,7 +164,15 @@ def _parse_python_function_docstring(
|
||||
def _validate_docstring_args_against_annotations(
|
||||
arg_descriptions: dict, annotations: dict
|
||||
) -> None:
|
||||
"""Raise error if docstring arg is not in type annotations."""
|
||||
"""Validate that docstring arguments match function annotations.
|
||||
|
||||
Args:
|
||||
arg_descriptions: Arguments described in the docstring.
|
||||
annotations: Type annotations from the function signature.
|
||||
|
||||
Raises:
|
||||
ValueError: If a docstring argument is not found in function signature.
|
||||
"""
|
||||
for docstring_arg in arg_descriptions:
|
||||
if docstring_arg not in annotations:
|
||||
msg = f"Arg {docstring_arg} in docstring not found in function signature."
|
||||
@ -143,7 +185,16 @@ def _infer_arg_descriptions(
|
||||
parse_docstring: bool = False,
|
||||
error_on_invalid_docstring: bool = False,
|
||||
) -> tuple[str, dict]:
|
||||
"""Infer argument descriptions from a function's docstring."""
|
||||
"""Infer argument descriptions from function docstring and annotations.
|
||||
|
||||
Args:
|
||||
fn: The function to infer descriptions from.
|
||||
parse_docstring: Whether to parse the docstring for descriptions.
|
||||
error_on_invalid_docstring: Whether to raise error on invalid docstring.
|
||||
|
||||
Returns:
|
||||
A tuple containing the function description and argument descriptions.
|
||||
"""
|
||||
annotations = typing.get_type_hints(fn, include_extras=True)
|
||||
if parse_docstring:
|
||||
description, arg_descriptions = _parse_python_function_docstring(
|
||||
@ -163,7 +214,15 @@ def _infer_arg_descriptions(
|
||||
|
||||
|
||||
def _is_pydantic_annotation(annotation: Any, pydantic_version: str = "v2") -> bool:
|
||||
"""Determine if a type annotation is a Pydantic model."""
|
||||
"""Check if a type annotation is a Pydantic model.
|
||||
|
||||
Args:
|
||||
annotation: The type annotation to check.
|
||||
pydantic_version: The Pydantic version to check against ("v1" or "v2").
|
||||
|
||||
Returns:
|
||||
True if the annotation is a Pydantic model, False otherwise.
|
||||
"""
|
||||
base_model_class = BaseModelV1 if pydantic_version == "v1" else BaseModel
|
||||
try:
|
||||
return issubclass(annotation, base_model_class)
|
||||
@ -174,7 +233,18 @@ def _is_pydantic_annotation(annotation: Any, pydantic_version: str = "v2") -> bo
|
||||
def _function_annotations_are_pydantic_v1(
|
||||
signature: inspect.Signature, func: Callable
|
||||
) -> bool:
|
||||
"""Determine if all Pydantic annotations in a function signature are from V1."""
|
||||
"""Check if all Pydantic annotations in a function are from V1.
|
||||
|
||||
Args:
|
||||
signature: The function signature to check.
|
||||
func: The function being checked.
|
||||
|
||||
Returns:
|
||||
True if all Pydantic annotations are from V1, False otherwise.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If the function contains mixed V1 and V2 annotations.
|
||||
"""
|
||||
any_v1_annotations = any(
|
||||
_is_pydantic_annotation(parameter.annotation, pydantic_version="v1")
|
||||
for parameter in signature.parameters.values()
|
||||
@ -193,15 +263,11 @@ def _function_annotations_are_pydantic_v1(
|
||||
|
||||
|
||||
class _SchemaConfig:
|
||||
"""Configuration for the pydantic model.
|
||||
"""Configuration for Pydantic models generated from function signatures.
|
||||
|
||||
This is used to configure the pydantic model created from
|
||||
a function's signature.
|
||||
|
||||
Parameters:
|
||||
Attributes:
|
||||
extra: Whether to allow extra fields in the model.
|
||||
arbitrary_types_allowed: Whether to allow arbitrary types in the model.
|
||||
Defaults to True.
|
||||
"""
|
||||
|
||||
extra: str = "forbid"
|
||||
@ -309,12 +375,11 @@ def create_schema_from_function(
|
||||
|
||||
|
||||
class ToolException(Exception): # noqa: N818
|
||||
"""Optional exception that tool throws when execution error occurs.
|
||||
"""Exception thrown when a tool execution error occurs.
|
||||
|
||||
When this exception is thrown, the agent will not stop working,
|
||||
but it will handle the exception according to the handle_tool_error
|
||||
variable of the tool, and the processing result will be returned
|
||||
to the agent as observation, and printed in red on the console.
|
||||
This exception allows tools to signal errors without stopping the agent.
|
||||
The error is handled according to the tool's handle_tool_error setting,
|
||||
and the result is returned as an observation to the agent.
|
||||
"""
|
||||
|
||||
|
||||
@ -322,10 +387,21 @@ ArgsSchema = Union[TypeBaseModel, dict[str, Any]]
|
||||
|
||||
|
||||
class BaseTool(RunnableSerializable[Union[str, dict, ToolCall], Any]):
|
||||
"""Interface LangChain tools must implement."""
|
||||
"""Base class for all LangChain tools.
|
||||
|
||||
This abstract class defines the interface that all LangChain tools must implement.
|
||||
Tools are components that can be called by agents to perform specific actions.
|
||||
"""
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any) -> None:
|
||||
"""Create the definition of the new tool class."""
|
||||
"""Validate the tool class definition during subclass creation.
|
||||
|
||||
Args:
|
||||
**kwargs: Additional keyword arguments passed to the parent class.
|
||||
|
||||
Raises:
|
||||
SchemaAnnotationError: If args_schema has incorrect type annotation.
|
||||
"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
args_schema_type = cls.__annotations__.get("args_schema", None)
|
||||
@ -444,13 +520,21 @@ class ChildTool(BaseTool):
|
||||
|
||||
@property
|
||||
def is_single_input(self) -> bool:
|
||||
"""Whether the tool only accepts a single input."""
|
||||
"""Check if the tool accepts only a single input argument.
|
||||
|
||||
Returns:
|
||||
True if the tool has only one input argument, False otherwise.
|
||||
"""
|
||||
keys = {k for k in self.args if k != "kwargs"}
|
||||
return len(keys) == 1
|
||||
|
||||
@property
|
||||
def args(self) -> dict:
|
||||
"""The arguments of the tool."""
|
||||
"""Get the tool's input arguments schema.
|
||||
|
||||
Returns:
|
||||
Dictionary containing the tool's argument properties.
|
||||
"""
|
||||
if isinstance(self.args_schema, dict):
|
||||
json_schema = self.args_schema
|
||||
else:
|
||||
@ -460,7 +544,11 @@ class ChildTool(BaseTool):
|
||||
|
||||
@property
|
||||
def tool_call_schema(self) -> ArgsSchema:
|
||||
"""The schema for a tool call."""
|
||||
"""Get the schema for tool calls, excluding injected arguments.
|
||||
|
||||
Returns:
|
||||
The schema that should be used for tool calls from language models.
|
||||
"""
|
||||
if isinstance(self.args_schema, dict):
|
||||
if self.description:
|
||||
return {
|
||||
@ -524,11 +612,19 @@ class ChildTool(BaseTool):
|
||||
def _parse_input(
|
||||
self, tool_input: Union[str, dict], tool_call_id: Optional[str]
|
||||
) -> Union[str, dict[str, Any]]:
|
||||
"""Convert tool input to a pydantic model.
|
||||
"""Parse and validate tool input using the args schema.
|
||||
|
||||
Args:
|
||||
tool_input: The input to the tool.
|
||||
tool_call_id: The id of the tool call.
|
||||
tool_input: The raw input to the tool.
|
||||
tool_call_id: The ID of the tool call, if available.
|
||||
|
||||
Returns:
|
||||
The parsed and validated input.
|
||||
|
||||
Raises:
|
||||
ValueError: If string input is provided with JSON schema or if
|
||||
InjectedToolCallId is required but not provided.
|
||||
NotImplementedError: If args_schema is not a supported type.
|
||||
"""
|
||||
input_args = self.args_schema
|
||||
if isinstance(tool_input, str):
|
||||
@ -640,6 +736,18 @@ class ChildTool(BaseTool):
|
||||
def _to_args_and_kwargs(
|
||||
self, tool_input: Union[str, dict], tool_call_id: Optional[str]
|
||||
) -> tuple[tuple, dict]:
|
||||
"""Convert tool input to positional and keyword arguments.
|
||||
|
||||
Args:
|
||||
tool_input: The input to the tool.
|
||||
tool_call_id: The ID of the tool call, if available.
|
||||
|
||||
Returns:
|
||||
A tuple of (positional_args, keyword_args) for the tool.
|
||||
|
||||
Raises:
|
||||
TypeError: If the tool input type is invalid.
|
||||
"""
|
||||
if (
|
||||
self.args_schema is not None
|
||||
and isinstance(self.args_schema, type)
|
||||
@ -892,11 +1000,27 @@ class ChildTool(BaseTool):
|
||||
|
||||
@deprecated("0.1.47", alternative="invoke", removal="1.0")
|
||||
def __call__(self, tool_input: str, callbacks: Callbacks = None) -> str:
|
||||
"""Make tool callable."""
|
||||
"""Make tool callable (deprecated).
|
||||
|
||||
Args:
|
||||
tool_input: The input to the tool.
|
||||
callbacks: Callbacks to use during execution.
|
||||
|
||||
Returns:
|
||||
The tool's output.
|
||||
"""
|
||||
return self.run(tool_input, callbacks=callbacks)
|
||||
|
||||
|
||||
def _is_tool_call(x: Any) -> bool:
|
||||
"""Check if the input is a tool call dictionary.
|
||||
|
||||
Args:
|
||||
x: The input to check.
|
||||
|
||||
Returns:
|
||||
True if the input is a tool call, False otherwise.
|
||||
"""
|
||||
return isinstance(x, dict) and x.get("type") == "tool_call"
|
||||
|
||||
|
||||
@ -907,6 +1031,18 @@ def _handle_validation_error(
|
||||
Literal[True], str, Callable[[Union[ValidationError, ValidationErrorV1]], str]
|
||||
],
|
||||
) -> str:
|
||||
"""Handle validation errors based on the configured flag.
|
||||
|
||||
Args:
|
||||
e: The validation error that occurred.
|
||||
flag: How to handle the error (bool, string, or callable).
|
||||
|
||||
Returns:
|
||||
The error message to return.
|
||||
|
||||
Raises:
|
||||
ValueError: If the flag type is unexpected.
|
||||
"""
|
||||
if isinstance(flag, bool):
|
||||
content = "Tool input validation error"
|
||||
elif isinstance(flag, str):
|
||||
@ -927,6 +1063,18 @@ def _handle_tool_error(
|
||||
*,
|
||||
flag: Optional[Union[Literal[True], str, Callable[[ToolException], str]]],
|
||||
) -> str:
|
||||
"""Handle tool execution errors based on the configured flag.
|
||||
|
||||
Args:
|
||||
e: The tool exception that occurred.
|
||||
flag: How to handle the error (bool, string, or callable).
|
||||
|
||||
Returns:
|
||||
The error message to return.
|
||||
|
||||
Raises:
|
||||
ValueError: If the flag type is unexpected.
|
||||
"""
|
||||
if isinstance(flag, bool):
|
||||
content = e.args[0] if e.args else "Tool execution error"
|
||||
elif isinstance(flag, str):
|
||||
@ -947,6 +1095,16 @@ def _prep_run_args(
|
||||
config: Optional[RunnableConfig],
|
||||
**kwargs: Any,
|
||||
) -> tuple[Union[str, dict], dict]:
|
||||
"""Prepare arguments for tool execution.
|
||||
|
||||
Args:
|
||||
value: The input value (string, dict, or ToolCall).
|
||||
config: The runnable configuration.
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
A tuple of (tool_input, run_kwargs).
|
||||
"""
|
||||
config = ensure_config(config)
|
||||
if _is_tool_call(value):
|
||||
tool_call_id: Optional[str] = cast("ToolCall", value)["id"]
|
||||
@ -976,6 +1134,18 @@ def _format_output(
|
||||
name: str,
|
||||
status: str,
|
||||
) -> Union[ToolOutputMixin, Any]:
|
||||
"""Format tool output as a ToolMessage if appropriate.
|
||||
|
||||
Args:
|
||||
content: The main content of the tool output.
|
||||
artifact: Any artifact data from the tool.
|
||||
tool_call_id: The ID of the tool call.
|
||||
name: The name of the tool.
|
||||
status: The execution status.
|
||||
|
||||
Returns:
|
||||
The formatted output, either as a ToolMessage or the original content.
|
||||
"""
|
||||
if isinstance(content, ToolOutputMixin) or tool_call_id is None:
|
||||
return content
|
||||
if not _is_message_content_type(content):
|
||||
@ -990,14 +1160,32 @@ def _format_output(
|
||||
|
||||
|
||||
def _is_message_content_type(obj: Any) -> bool:
|
||||
"""Check for OpenAI or Anthropic format tool message content."""
|
||||
"""Check if object is valid message content format.
|
||||
|
||||
Validates content for OpenAI or Anthropic format tool messages.
|
||||
|
||||
Args:
|
||||
obj: The object to check.
|
||||
|
||||
Returns:
|
||||
True if the object is valid message content, False otherwise.
|
||||
"""
|
||||
return isinstance(obj, str) or (
|
||||
isinstance(obj, list) and all(_is_message_content_block(e) for e in obj)
|
||||
)
|
||||
|
||||
|
||||
def _is_message_content_block(obj: Any) -> bool:
|
||||
"""Check for OpenAI or Anthropic format tool message content blocks."""
|
||||
"""Check if object is a valid message content block.
|
||||
|
||||
Validates content blocks for OpenAI or Anthropic format.
|
||||
|
||||
Args:
|
||||
obj: The object to check.
|
||||
|
||||
Returns:
|
||||
True if the object is a valid content block, False otherwise.
|
||||
"""
|
||||
if isinstance(obj, str):
|
||||
return True
|
||||
if isinstance(obj, dict):
|
||||
@ -1006,6 +1194,14 @@ def _is_message_content_block(obj: Any) -> bool:
|
||||
|
||||
|
||||
def _stringify(content: Any) -> str:
|
||||
"""Convert content to string, preferring JSON format.
|
||||
|
||||
Args:
|
||||
content: The content to stringify.
|
||||
|
||||
Returns:
|
||||
String representation of the content.
|
||||
"""
|
||||
try:
|
||||
return json.dumps(content, ensure_ascii=False)
|
||||
except Exception:
|
||||
@ -1013,6 +1209,14 @@ def _stringify(content: Any) -> str:
|
||||
|
||||
|
||||
def _get_type_hints(func: Callable) -> Optional[dict[str, type]]:
|
||||
"""Get type hints from a function, handling partial functions.
|
||||
|
||||
Args:
|
||||
func: The function to get type hints from.
|
||||
|
||||
Returns:
|
||||
Dictionary of type hints, or None if extraction fails.
|
||||
"""
|
||||
if isinstance(func, functools.partial):
|
||||
func = func.func
|
||||
try:
|
||||
@ -1022,6 +1226,14 @@ def _get_type_hints(func: Callable) -> Optional[dict[str, type]]:
|
||||
|
||||
|
||||
def _get_runnable_config_param(func: Callable) -> Optional[str]:
|
||||
"""Find the parameter name for RunnableConfig in a function.
|
||||
|
||||
Args:
|
||||
func: The function to check.
|
||||
|
||||
Returns:
|
||||
The parameter name for RunnableConfig, or None if not found.
|
||||
"""
|
||||
type_hints = _get_type_hints(func)
|
||||
if not type_hints:
|
||||
return None
|
||||
@ -1032,30 +1244,52 @@ def _get_runnable_config_param(func: Callable) -> Optional[str]:
|
||||
|
||||
|
||||
class InjectedToolArg:
|
||||
"""Annotation for a Tool arg that is **not** meant to be generated by a model."""
|
||||
"""Annotation for tool arguments that are injected at runtime.
|
||||
|
||||
Tool arguments annotated with this class are not included in the tool
|
||||
schema sent to language models and are instead injected during execution.
|
||||
"""
|
||||
|
||||
|
||||
class InjectedToolCallId(InjectedToolArg):
|
||||
"""Annotation for injecting the tool_call_id.
|
||||
"""Annotation for injecting the tool call ID.
|
||||
|
||||
This annotation is used to mark a tool parameter that should receive
|
||||
the tool call ID at runtime.
|
||||
|
||||
Example:
|
||||
..code-block:: python
|
||||
```python
|
||||
from typing_extensions import Annotated
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.tools import tool, InjectedToolCallId
|
||||
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.tools import tool, InjectedToolCallID
|
||||
|
||||
@tool
|
||||
def foo(x: int, tool_call_id: Annotated[str, InjectedToolCallID]) -> ToolMessage:
|
||||
\"\"\"Return x.\"\"\"
|
||||
return ToolMessage(str(x), artifact=x, name="foo", tool_call_id=tool_call_id)
|
||||
""" # noqa: E501
|
||||
@tool
|
||||
def foo(
|
||||
x: int, tool_call_id: Annotated[str, InjectedToolCallId]
|
||||
) -> ToolMessage:
|
||||
\"\"\"Return x.\"\"\"
|
||||
return ToolMessage(
|
||||
str(x),
|
||||
artifact=x,
|
||||
name="foo",
|
||||
tool_call_id=tool_call_id
|
||||
)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def _is_injected_arg_type(
|
||||
type_: type, injected_type: Optional[type[InjectedToolArg]] = None
|
||||
) -> bool:
|
||||
"""Check if a type annotation indicates an injected argument.
|
||||
|
||||
Args:
|
||||
type_: The type annotation to check.
|
||||
injected_type: The specific injected type to check for.
|
||||
|
||||
Returns:
|
||||
True if the type is an injected argument, False otherwise.
|
||||
"""
|
||||
injected_type = injected_type or InjectedToolArg
|
||||
return any(
|
||||
isinstance(arg, injected_type)
|
||||
@ -1138,6 +1372,16 @@ def _replace_type_vars(
|
||||
*,
|
||||
default_to_bound: bool = True,
|
||||
) -> type:
|
||||
"""Replace TypeVars in a type annotation with concrete types.
|
||||
|
||||
Args:
|
||||
type_: The type annotation to process.
|
||||
generic_map: Mapping of TypeVars to concrete types.
|
||||
default_to_bound: Whether to use TypeVar bounds as defaults.
|
||||
|
||||
Returns:
|
||||
The type with TypeVars replaced.
|
||||
"""
|
||||
generic_map = generic_map or {}
|
||||
if isinstance(type_, TypeVar):
|
||||
if type_ in generic_map:
|
||||
@ -1155,8 +1399,16 @@ def _replace_type_vars(
|
||||
|
||||
|
||||
class BaseToolkit(BaseModel, ABC):
|
||||
"""Base Toolkit representing a collection of related tools."""
|
||||
"""Base class for toolkits containing related tools.
|
||||
|
||||
A toolkit is a collection of related tools that can be used together
|
||||
to accomplish a specific task or work with a particular system.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_tools(self) -> list[BaseTool]:
|
||||
"""Get the tools in the toolkit."""
|
||||
"""Get all tools in the toolkit.
|
||||
|
||||
Returns:
|
||||
List of tools contained in this toolkit.
|
||||
"""
|
||||
|
Loading…
Reference in New Issue
Block a user