core[patch]: Add doc-strings to tools/base.py (#31684)

Add doc-strings
This commit is contained in:
Eugene Yurtsev 2025-06-20 11:16:57 -04:00 committed by GitHub
parent 5d0bea8378
commit 2842e0c8c1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,4 +1,4 @@
"""Base for Tools."""
"""Base classes and utilities for LangChain tools."""
from __future__ import annotations
@ -77,14 +77,30 @@ FILTERED_ARGS = ("run_manager", "callbacks")
class SchemaAnnotationError(TypeError):
"""Raised when 'args_schema' is missing or has an incorrect type annotation."""
"""Raised when args_schema is missing or has an incorrect type annotation."""
def _is_annotated_type(typ: type[Any]) -> bool:
"""Check if a type is an Annotated type.
Args:
typ: The type to check.
Returns:
True if the type is an Annotated type, False otherwise.
"""
return get_origin(typ) is typing.Annotated
def _get_annotation_description(arg_type: type) -> str | None:
"""Extract description from an Annotated type.
Args:
arg_type: The type to extract description from.
Returns:
The description string if found, None otherwise.
"""
if _is_annotated_type(arg_type):
annotated_args = get_args(arg_type)
for annotation in annotated_args[1:]:
@ -100,7 +116,17 @@ def _get_filtered_args(
filter_args: Sequence[str],
include_injected: bool = True,
) -> dict:
"""Get the arguments from a function's signature."""
"""Get filtered arguments from a function's signature.
Args:
inferred_model: The Pydantic model inferred from the function.
func: The function to extract arguments from.
filter_args: Arguments to exclude from the result.
include_injected: Whether to include injected arguments.
Returns:
Dictionary of filtered arguments with their schema definitions.
"""
schema = inferred_model.model_json_schema()["properties"]
valid_keys = signature(func).parameters
return {
@ -115,9 +141,17 @@ def _get_filtered_args(
def _parse_python_function_docstring(
function: Callable, annotations: dict, *, error_on_invalid_docstring: bool = False
) -> tuple[str, dict]:
"""Parse the function and argument descriptions from the docstring of a function.
"""Parse function and argument descriptions from a docstring.
Assumes the function docstring follows Google Python style guide.
Args:
function: The function to parse the docstring from.
annotations: Type annotations for the function parameters.
error_on_invalid_docstring: Whether to raise an error on invalid docstring.
Returns:
A tuple containing the function description and argument descriptions.
"""
docstring = inspect.getdoc(function)
return _parse_google_docstring(
@ -130,7 +164,15 @@ def _parse_python_function_docstring(
def _validate_docstring_args_against_annotations(
arg_descriptions: dict, annotations: dict
) -> None:
"""Raise error if docstring arg is not in type annotations."""
"""Validate that docstring arguments match function annotations.
Args:
arg_descriptions: Arguments described in the docstring.
annotations: Type annotations from the function signature.
Raises:
ValueError: If a docstring argument is not found in function signature.
"""
for docstring_arg in arg_descriptions:
if docstring_arg not in annotations:
msg = f"Arg {docstring_arg} in docstring not found in function signature."
@ -143,7 +185,16 @@ def _infer_arg_descriptions(
parse_docstring: bool = False,
error_on_invalid_docstring: bool = False,
) -> tuple[str, dict]:
"""Infer argument descriptions from a function's docstring."""
"""Infer argument descriptions from function docstring and annotations.
Args:
fn: The function to infer descriptions from.
parse_docstring: Whether to parse the docstring for descriptions.
error_on_invalid_docstring: Whether to raise error on invalid docstring.
Returns:
A tuple containing the function description and argument descriptions.
"""
annotations = typing.get_type_hints(fn, include_extras=True)
if parse_docstring:
description, arg_descriptions = _parse_python_function_docstring(
@ -163,7 +214,15 @@ def _infer_arg_descriptions(
def _is_pydantic_annotation(annotation: Any, pydantic_version: str = "v2") -> bool:
"""Determine if a type annotation is a Pydantic model."""
"""Check if a type annotation is a Pydantic model.
Args:
annotation: The type annotation to check.
pydantic_version: The Pydantic version to check against ("v1" or "v2").
Returns:
True if the annotation is a Pydantic model, False otherwise.
"""
base_model_class = BaseModelV1 if pydantic_version == "v1" else BaseModel
try:
return issubclass(annotation, base_model_class)
@ -174,7 +233,18 @@ def _is_pydantic_annotation(annotation: Any, pydantic_version: str = "v2") -> bo
def _function_annotations_are_pydantic_v1(
signature: inspect.Signature, func: Callable
) -> bool:
"""Determine if all Pydantic annotations in a function signature are from V1."""
"""Check if all Pydantic annotations in a function are from V1.
Args:
signature: The function signature to check.
func: The function being checked.
Returns:
True if all Pydantic annotations are from V1, False otherwise.
Raises:
NotImplementedError: If the function contains mixed V1 and V2 annotations.
"""
any_v1_annotations = any(
_is_pydantic_annotation(parameter.annotation, pydantic_version="v1")
for parameter in signature.parameters.values()
@ -193,15 +263,11 @@ def _function_annotations_are_pydantic_v1(
class _SchemaConfig:
"""Configuration for the pydantic model.
"""Configuration for Pydantic models generated from function signatures.
This is used to configure the pydantic model created from
a function's signature.
Parameters:
Attributes:
extra: Whether to allow extra fields in the model.
arbitrary_types_allowed: Whether to allow arbitrary types in the model.
Defaults to True.
"""
extra: str = "forbid"
@ -309,12 +375,11 @@ def create_schema_from_function(
class ToolException(Exception): # noqa: N818
"""Optional exception that tool throws when execution error occurs.
"""Exception thrown when a tool execution error occurs.
When this exception is thrown, the agent will not stop working,
but it will handle the exception according to the handle_tool_error
variable of the tool, and the processing result will be returned
to the agent as observation, and printed in red on the console.
This exception allows tools to signal errors without stopping the agent.
The error is handled according to the tool's handle_tool_error setting,
and the result is returned as an observation to the agent.
"""
@ -322,10 +387,21 @@ ArgsSchema = Union[TypeBaseModel, dict[str, Any]]
class BaseTool(RunnableSerializable[Union[str, dict, ToolCall], Any]):
"""Interface LangChain tools must implement."""
"""Base class for all LangChain tools.
This abstract class defines the interface that all LangChain tools must implement.
Tools are components that can be called by agents to perform specific actions.
"""
def __init_subclass__(cls, **kwargs: Any) -> None:
"""Create the definition of the new tool class."""
"""Validate the tool class definition during subclass creation.
Args:
**kwargs: Additional keyword arguments passed to the parent class.
Raises:
SchemaAnnotationError: If args_schema has incorrect type annotation.
"""
super().__init_subclass__(**kwargs)
args_schema_type = cls.__annotations__.get("args_schema", None)
@ -444,13 +520,21 @@ class ChildTool(BaseTool):
@property
def is_single_input(self) -> bool:
"""Whether the tool only accepts a single input."""
"""Check if the tool accepts only a single input argument.
Returns:
True if the tool has only one input argument, False otherwise.
"""
keys = {k for k in self.args if k != "kwargs"}
return len(keys) == 1
@property
def args(self) -> dict:
"""The arguments of the tool."""
"""Get the tool's input arguments schema.
Returns:
Dictionary containing the tool's argument properties.
"""
if isinstance(self.args_schema, dict):
json_schema = self.args_schema
else:
@ -460,7 +544,11 @@ class ChildTool(BaseTool):
@property
def tool_call_schema(self) -> ArgsSchema:
"""The schema for a tool call."""
"""Get the schema for tool calls, excluding injected arguments.
Returns:
The schema that should be used for tool calls from language models.
"""
if isinstance(self.args_schema, dict):
if self.description:
return {
@ -524,11 +612,19 @@ class ChildTool(BaseTool):
def _parse_input(
self, tool_input: Union[str, dict], tool_call_id: Optional[str]
) -> Union[str, dict[str, Any]]:
"""Convert tool input to a pydantic model.
"""Parse and validate tool input using the args schema.
Args:
tool_input: The input to the tool.
tool_call_id: The id of the tool call.
tool_input: The raw input to the tool.
tool_call_id: The ID of the tool call, if available.
Returns:
The parsed and validated input.
Raises:
ValueError: If string input is provided with JSON schema or if
InjectedToolCallId is required but not provided.
NotImplementedError: If args_schema is not a supported type.
"""
input_args = self.args_schema
if isinstance(tool_input, str):
@ -640,6 +736,18 @@ class ChildTool(BaseTool):
def _to_args_and_kwargs(
self, tool_input: Union[str, dict], tool_call_id: Optional[str]
) -> tuple[tuple, dict]:
"""Convert tool input to positional and keyword arguments.
Args:
tool_input: The input to the tool.
tool_call_id: The ID of the tool call, if available.
Returns:
A tuple of (positional_args, keyword_args) for the tool.
Raises:
TypeError: If the tool input type is invalid.
"""
if (
self.args_schema is not None
and isinstance(self.args_schema, type)
@ -892,11 +1000,27 @@ class ChildTool(BaseTool):
@deprecated("0.1.47", alternative="invoke", removal="1.0")
def __call__(self, tool_input: str, callbacks: Callbacks = None) -> str:
"""Make tool callable."""
"""Make tool callable (deprecated).
Args:
tool_input: The input to the tool.
callbacks: Callbacks to use during execution.
Returns:
The tool's output.
"""
return self.run(tool_input, callbacks=callbacks)
def _is_tool_call(x: Any) -> bool:
"""Check if the input is a tool call dictionary.
Args:
x: The input to check.
Returns:
True if the input is a tool call, False otherwise.
"""
return isinstance(x, dict) and x.get("type") == "tool_call"
@ -907,6 +1031,18 @@ def _handle_validation_error(
Literal[True], str, Callable[[Union[ValidationError, ValidationErrorV1]], str]
],
) -> str:
"""Handle validation errors based on the configured flag.
Args:
e: The validation error that occurred.
flag: How to handle the error (bool, string, or callable).
Returns:
The error message to return.
Raises:
ValueError: If the flag type is unexpected.
"""
if isinstance(flag, bool):
content = "Tool input validation error"
elif isinstance(flag, str):
@ -927,6 +1063,18 @@ def _handle_tool_error(
*,
flag: Optional[Union[Literal[True], str, Callable[[ToolException], str]]],
) -> str:
"""Handle tool execution errors based on the configured flag.
Args:
e: The tool exception that occurred.
flag: How to handle the error (bool, string, or callable).
Returns:
The error message to return.
Raises:
ValueError: If the flag type is unexpected.
"""
if isinstance(flag, bool):
content = e.args[0] if e.args else "Tool execution error"
elif isinstance(flag, str):
@ -947,6 +1095,16 @@ def _prep_run_args(
config: Optional[RunnableConfig],
**kwargs: Any,
) -> tuple[Union[str, dict], dict]:
"""Prepare arguments for tool execution.
Args:
value: The input value (string, dict, or ToolCall).
config: The runnable configuration.
**kwargs: Additional keyword arguments.
Returns:
A tuple of (tool_input, run_kwargs).
"""
config = ensure_config(config)
if _is_tool_call(value):
tool_call_id: Optional[str] = cast("ToolCall", value)["id"]
@ -976,6 +1134,18 @@ def _format_output(
name: str,
status: str,
) -> Union[ToolOutputMixin, Any]:
"""Format tool output as a ToolMessage if appropriate.
Args:
content: The main content of the tool output.
artifact: Any artifact data from the tool.
tool_call_id: The ID of the tool call.
name: The name of the tool.
status: The execution status.
Returns:
The formatted output, either as a ToolMessage or the original content.
"""
if isinstance(content, ToolOutputMixin) or tool_call_id is None:
return content
if not _is_message_content_type(content):
@ -990,14 +1160,32 @@ def _format_output(
def _is_message_content_type(obj: Any) -> bool:
"""Check for OpenAI or Anthropic format tool message content."""
"""Check if object is valid message content format.
Validates content for OpenAI or Anthropic format tool messages.
Args:
obj: The object to check.
Returns:
True if the object is valid message content, False otherwise.
"""
return isinstance(obj, str) or (
isinstance(obj, list) and all(_is_message_content_block(e) for e in obj)
)
def _is_message_content_block(obj: Any) -> bool:
"""Check for OpenAI or Anthropic format tool message content blocks."""
"""Check if object is a valid message content block.
Validates content blocks for OpenAI or Anthropic format.
Args:
obj: The object to check.
Returns:
True if the object is a valid content block, False otherwise.
"""
if isinstance(obj, str):
return True
if isinstance(obj, dict):
@ -1006,6 +1194,14 @@ def _is_message_content_block(obj: Any) -> bool:
def _stringify(content: Any) -> str:
"""Convert content to string, preferring JSON format.
Args:
content: The content to stringify.
Returns:
String representation of the content.
"""
try:
return json.dumps(content, ensure_ascii=False)
except Exception:
@ -1013,6 +1209,14 @@ def _stringify(content: Any) -> str:
def _get_type_hints(func: Callable) -> Optional[dict[str, type]]:
"""Get type hints from a function, handling partial functions.
Args:
func: The function to get type hints from.
Returns:
Dictionary of type hints, or None if extraction fails.
"""
if isinstance(func, functools.partial):
func = func.func
try:
@ -1022,6 +1226,14 @@ def _get_type_hints(func: Callable) -> Optional[dict[str, type]]:
def _get_runnable_config_param(func: Callable) -> Optional[str]:
"""Find the parameter name for RunnableConfig in a function.
Args:
func: The function to check.
Returns:
The parameter name for RunnableConfig, or None if not found.
"""
type_hints = _get_type_hints(func)
if not type_hints:
return None
@ -1032,30 +1244,52 @@ def _get_runnable_config_param(func: Callable) -> Optional[str]:
class InjectedToolArg:
"""Annotation for a Tool arg that is **not** meant to be generated by a model."""
"""Annotation for tool arguments that are injected at runtime.
Tool arguments annotated with this class are not included in the tool
schema sent to language models and are instead injected during execution.
"""
class InjectedToolCallId(InjectedToolArg):
"""Annotation for injecting the tool_call_id.
"""Annotation for injecting the tool call ID.
This annotation is used to mark a tool parameter that should receive
the tool call ID at runtime.
Example:
..code-block:: python
```python
from typing_extensions import Annotated
from langchain_core.messages import ToolMessage
from langchain_core.tools import tool, InjectedToolCallId
from typing_extensions import Annotated
from langchain_core.messages import ToolMessage
from langchain_core.tools import tool, InjectedToolCallID
@tool
def foo(x: int, tool_call_id: Annotated[str, InjectedToolCallID]) -> ToolMessage:
\"\"\"Return x.\"\"\"
return ToolMessage(str(x), artifact=x, name="foo", tool_call_id=tool_call_id)
""" # noqa: E501
@tool
def foo(
x: int, tool_call_id: Annotated[str, InjectedToolCallId]
) -> ToolMessage:
\"\"\"Return x.\"\"\"
return ToolMessage(
str(x),
artifact=x,
name="foo",
tool_call_id=tool_call_id
)
```
"""
def _is_injected_arg_type(
type_: type, injected_type: Optional[type[InjectedToolArg]] = None
) -> bool:
"""Check if a type annotation indicates an injected argument.
Args:
type_: The type annotation to check.
injected_type: The specific injected type to check for.
Returns:
True if the type is an injected argument, False otherwise.
"""
injected_type = injected_type or InjectedToolArg
return any(
isinstance(arg, injected_type)
@ -1138,6 +1372,16 @@ def _replace_type_vars(
*,
default_to_bound: bool = True,
) -> type:
"""Replace TypeVars in a type annotation with concrete types.
Args:
type_: The type annotation to process.
generic_map: Mapping of TypeVars to concrete types.
default_to_bound: Whether to use TypeVar bounds as defaults.
Returns:
The type with TypeVars replaced.
"""
generic_map = generic_map or {}
if isinstance(type_, TypeVar):
if type_ in generic_map:
@ -1155,8 +1399,16 @@ def _replace_type_vars(
class BaseToolkit(BaseModel, ABC):
"""Base Toolkit representing a collection of related tools."""
"""Base class for toolkits containing related tools.
A toolkit is a collection of related tools that can be used together
to accomplish a specific task or work with a particular system.
"""
@abstractmethod
def get_tools(self) -> list[BaseTool]:
"""Get the tools in the toolkit."""
"""Get all tools in the toolkit.
Returns:
List of tools contained in this toolkit.
"""