diff --git a/libs/core/langchain_core/tools/base.py b/libs/core/langchain_core/tools/base.py index 1cea8876c32..3caf190ef81 100644 --- a/libs/core/langchain_core/tools/base.py +++ b/libs/core/langchain_core/tools/base.py @@ -1,4 +1,4 @@ -"""Base for Tools.""" +"""Base classes and utilities for LangChain tools.""" from __future__ import annotations @@ -77,14 +77,30 @@ FILTERED_ARGS = ("run_manager", "callbacks") class SchemaAnnotationError(TypeError): - """Raised when 'args_schema' is missing or has an incorrect type annotation.""" + """Raised when args_schema is missing or has an incorrect type annotation.""" def _is_annotated_type(typ: type[Any]) -> bool: + """Check if a type is an Annotated type. + + Args: + typ: The type to check. + + Returns: + True if the type is an Annotated type, False otherwise. + """ return get_origin(typ) is typing.Annotated def _get_annotation_description(arg_type: type) -> str | None: + """Extract description from an Annotated type. + + Args: + arg_type: The type to extract description from. + + Returns: + The description string if found, None otherwise. + """ if _is_annotated_type(arg_type): annotated_args = get_args(arg_type) for annotation in annotated_args[1:]: @@ -100,7 +116,17 @@ def _get_filtered_args( filter_args: Sequence[str], include_injected: bool = True, ) -> dict: - """Get the arguments from a function's signature.""" + """Get filtered arguments from a function's signature. + + Args: + inferred_model: The Pydantic model inferred from the function. + func: The function to extract arguments from. + filter_args: Arguments to exclude from the result. + include_injected: Whether to include injected arguments. + + Returns: + Dictionary of filtered arguments with their schema definitions. + """ schema = inferred_model.model_json_schema()["properties"] valid_keys = signature(func).parameters return { @@ -115,9 +141,17 @@ def _get_filtered_args( def _parse_python_function_docstring( function: Callable, annotations: dict, *, error_on_invalid_docstring: bool = False ) -> tuple[str, dict]: - """Parse the function and argument descriptions from the docstring of a function. + """Parse function and argument descriptions from a docstring. Assumes the function docstring follows Google Python style guide. + + Args: + function: The function to parse the docstring from. + annotations: Type annotations for the function parameters. + error_on_invalid_docstring: Whether to raise an error on invalid docstring. + + Returns: + A tuple containing the function description and argument descriptions. """ docstring = inspect.getdoc(function) return _parse_google_docstring( @@ -130,7 +164,15 @@ def _parse_python_function_docstring( def _validate_docstring_args_against_annotations( arg_descriptions: dict, annotations: dict ) -> None: - """Raise error if docstring arg is not in type annotations.""" + """Validate that docstring arguments match function annotations. + + Args: + arg_descriptions: Arguments described in the docstring. + annotations: Type annotations from the function signature. + + Raises: + ValueError: If a docstring argument is not found in function signature. + """ for docstring_arg in arg_descriptions: if docstring_arg not in annotations: msg = f"Arg {docstring_arg} in docstring not found in function signature." @@ -143,7 +185,16 @@ def _infer_arg_descriptions( parse_docstring: bool = False, error_on_invalid_docstring: bool = False, ) -> tuple[str, dict]: - """Infer argument descriptions from a function's docstring.""" + """Infer argument descriptions from function docstring and annotations. + + Args: + fn: The function to infer descriptions from. + parse_docstring: Whether to parse the docstring for descriptions. + error_on_invalid_docstring: Whether to raise error on invalid docstring. + + Returns: + A tuple containing the function description and argument descriptions. + """ annotations = typing.get_type_hints(fn, include_extras=True) if parse_docstring: description, arg_descriptions = _parse_python_function_docstring( @@ -163,7 +214,15 @@ def _infer_arg_descriptions( def _is_pydantic_annotation(annotation: Any, pydantic_version: str = "v2") -> bool: - """Determine if a type annotation is a Pydantic model.""" + """Check if a type annotation is a Pydantic model. + + Args: + annotation: The type annotation to check. + pydantic_version: The Pydantic version to check against ("v1" or "v2"). + + Returns: + True if the annotation is a Pydantic model, False otherwise. + """ base_model_class = BaseModelV1 if pydantic_version == "v1" else BaseModel try: return issubclass(annotation, base_model_class) @@ -174,7 +233,18 @@ def _is_pydantic_annotation(annotation: Any, pydantic_version: str = "v2") -> bo def _function_annotations_are_pydantic_v1( signature: inspect.Signature, func: Callable ) -> bool: - """Determine if all Pydantic annotations in a function signature are from V1.""" + """Check if all Pydantic annotations in a function are from V1. + + Args: + signature: The function signature to check. + func: The function being checked. + + Returns: + True if all Pydantic annotations are from V1, False otherwise. + + Raises: + NotImplementedError: If the function contains mixed V1 and V2 annotations. + """ any_v1_annotations = any( _is_pydantic_annotation(parameter.annotation, pydantic_version="v1") for parameter in signature.parameters.values() @@ -193,15 +263,11 @@ def _function_annotations_are_pydantic_v1( class _SchemaConfig: - """Configuration for the pydantic model. + """Configuration for Pydantic models generated from function signatures. - This is used to configure the pydantic model created from - a function's signature. - - Parameters: + Attributes: extra: Whether to allow extra fields in the model. arbitrary_types_allowed: Whether to allow arbitrary types in the model. - Defaults to True. """ extra: str = "forbid" @@ -309,12 +375,11 @@ def create_schema_from_function( class ToolException(Exception): # noqa: N818 - """Optional exception that tool throws when execution error occurs. + """Exception thrown when a tool execution error occurs. - When this exception is thrown, the agent will not stop working, - but it will handle the exception according to the handle_tool_error - variable of the tool, and the processing result will be returned - to the agent as observation, and printed in red on the console. + This exception allows tools to signal errors without stopping the agent. + The error is handled according to the tool's handle_tool_error setting, + and the result is returned as an observation to the agent. """ @@ -322,10 +387,21 @@ ArgsSchema = Union[TypeBaseModel, dict[str, Any]] class BaseTool(RunnableSerializable[Union[str, dict, ToolCall], Any]): - """Interface LangChain tools must implement.""" + """Base class for all LangChain tools. + + This abstract class defines the interface that all LangChain tools must implement. + Tools are components that can be called by agents to perform specific actions. + """ def __init_subclass__(cls, **kwargs: Any) -> None: - """Create the definition of the new tool class.""" + """Validate the tool class definition during subclass creation. + + Args: + **kwargs: Additional keyword arguments passed to the parent class. + + Raises: + SchemaAnnotationError: If args_schema has incorrect type annotation. + """ super().__init_subclass__(**kwargs) args_schema_type = cls.__annotations__.get("args_schema", None) @@ -444,13 +520,21 @@ class ChildTool(BaseTool): @property def is_single_input(self) -> bool: - """Whether the tool only accepts a single input.""" + """Check if the tool accepts only a single input argument. + + Returns: + True if the tool has only one input argument, False otherwise. + """ keys = {k for k in self.args if k != "kwargs"} return len(keys) == 1 @property def args(self) -> dict: - """The arguments of the tool.""" + """Get the tool's input arguments schema. + + Returns: + Dictionary containing the tool's argument properties. + """ if isinstance(self.args_schema, dict): json_schema = self.args_schema else: @@ -460,7 +544,11 @@ class ChildTool(BaseTool): @property def tool_call_schema(self) -> ArgsSchema: - """The schema for a tool call.""" + """Get the schema for tool calls, excluding injected arguments. + + Returns: + The schema that should be used for tool calls from language models. + """ if isinstance(self.args_schema, dict): if self.description: return { @@ -524,11 +612,19 @@ class ChildTool(BaseTool): def _parse_input( self, tool_input: Union[str, dict], tool_call_id: Optional[str] ) -> Union[str, dict[str, Any]]: - """Convert tool input to a pydantic model. + """Parse and validate tool input using the args schema. Args: - tool_input: The input to the tool. - tool_call_id: The id of the tool call. + tool_input: The raw input to the tool. + tool_call_id: The ID of the tool call, if available. + + Returns: + The parsed and validated input. + + Raises: + ValueError: If string input is provided with JSON schema or if + InjectedToolCallId is required but not provided. + NotImplementedError: If args_schema is not a supported type. """ input_args = self.args_schema if isinstance(tool_input, str): @@ -640,6 +736,18 @@ class ChildTool(BaseTool): def _to_args_and_kwargs( self, tool_input: Union[str, dict], tool_call_id: Optional[str] ) -> tuple[tuple, dict]: + """Convert tool input to positional and keyword arguments. + + Args: + tool_input: The input to the tool. + tool_call_id: The ID of the tool call, if available. + + Returns: + A tuple of (positional_args, keyword_args) for the tool. + + Raises: + TypeError: If the tool input type is invalid. + """ if ( self.args_schema is not None and isinstance(self.args_schema, type) @@ -892,11 +1000,27 @@ class ChildTool(BaseTool): @deprecated("0.1.47", alternative="invoke", removal="1.0") def __call__(self, tool_input: str, callbacks: Callbacks = None) -> str: - """Make tool callable.""" + """Make tool callable (deprecated). + + Args: + tool_input: The input to the tool. + callbacks: Callbacks to use during execution. + + Returns: + The tool's output. + """ return self.run(tool_input, callbacks=callbacks) def _is_tool_call(x: Any) -> bool: + """Check if the input is a tool call dictionary. + + Args: + x: The input to check. + + Returns: + True if the input is a tool call, False otherwise. + """ return isinstance(x, dict) and x.get("type") == "tool_call" @@ -907,6 +1031,18 @@ def _handle_validation_error( Literal[True], str, Callable[[Union[ValidationError, ValidationErrorV1]], str] ], ) -> str: + """Handle validation errors based on the configured flag. + + Args: + e: The validation error that occurred. + flag: How to handle the error (bool, string, or callable). + + Returns: + The error message to return. + + Raises: + ValueError: If the flag type is unexpected. + """ if isinstance(flag, bool): content = "Tool input validation error" elif isinstance(flag, str): @@ -927,6 +1063,18 @@ def _handle_tool_error( *, flag: Optional[Union[Literal[True], str, Callable[[ToolException], str]]], ) -> str: + """Handle tool execution errors based on the configured flag. + + Args: + e: The tool exception that occurred. + flag: How to handle the error (bool, string, or callable). + + Returns: + The error message to return. + + Raises: + ValueError: If the flag type is unexpected. + """ if isinstance(flag, bool): content = e.args[0] if e.args else "Tool execution error" elif isinstance(flag, str): @@ -947,6 +1095,16 @@ def _prep_run_args( config: Optional[RunnableConfig], **kwargs: Any, ) -> tuple[Union[str, dict], dict]: + """Prepare arguments for tool execution. + + Args: + value: The input value (string, dict, or ToolCall). + config: The runnable configuration. + **kwargs: Additional keyword arguments. + + Returns: + A tuple of (tool_input, run_kwargs). + """ config = ensure_config(config) if _is_tool_call(value): tool_call_id: Optional[str] = cast("ToolCall", value)["id"] @@ -976,6 +1134,18 @@ def _format_output( name: str, status: str, ) -> Union[ToolOutputMixin, Any]: + """Format tool output as a ToolMessage if appropriate. + + Args: + content: The main content of the tool output. + artifact: Any artifact data from the tool. + tool_call_id: The ID of the tool call. + name: The name of the tool. + status: The execution status. + + Returns: + The formatted output, either as a ToolMessage or the original content. + """ if isinstance(content, ToolOutputMixin) or tool_call_id is None: return content if not _is_message_content_type(content): @@ -990,14 +1160,32 @@ def _format_output( def _is_message_content_type(obj: Any) -> bool: - """Check for OpenAI or Anthropic format tool message content.""" + """Check if object is valid message content format. + + Validates content for OpenAI or Anthropic format tool messages. + + Args: + obj: The object to check. + + Returns: + True if the object is valid message content, False otherwise. + """ return isinstance(obj, str) or ( isinstance(obj, list) and all(_is_message_content_block(e) for e in obj) ) def _is_message_content_block(obj: Any) -> bool: - """Check for OpenAI or Anthropic format tool message content blocks.""" + """Check if object is a valid message content block. + + Validates content blocks for OpenAI or Anthropic format. + + Args: + obj: The object to check. + + Returns: + True if the object is a valid content block, False otherwise. + """ if isinstance(obj, str): return True if isinstance(obj, dict): @@ -1006,6 +1194,14 @@ def _is_message_content_block(obj: Any) -> bool: def _stringify(content: Any) -> str: + """Convert content to string, preferring JSON format. + + Args: + content: The content to stringify. + + Returns: + String representation of the content. + """ try: return json.dumps(content, ensure_ascii=False) except Exception: @@ -1013,6 +1209,14 @@ def _stringify(content: Any) -> str: def _get_type_hints(func: Callable) -> Optional[dict[str, type]]: + """Get type hints from a function, handling partial functions. + + Args: + func: The function to get type hints from. + + Returns: + Dictionary of type hints, or None if extraction fails. + """ if isinstance(func, functools.partial): func = func.func try: @@ -1022,6 +1226,14 @@ def _get_type_hints(func: Callable) -> Optional[dict[str, type]]: def _get_runnable_config_param(func: Callable) -> Optional[str]: + """Find the parameter name for RunnableConfig in a function. + + Args: + func: The function to check. + + Returns: + The parameter name for RunnableConfig, or None if not found. + """ type_hints = _get_type_hints(func) if not type_hints: return None @@ -1032,30 +1244,52 @@ def _get_runnable_config_param(func: Callable) -> Optional[str]: class InjectedToolArg: - """Annotation for a Tool arg that is **not** meant to be generated by a model.""" + """Annotation for tool arguments that are injected at runtime. + + Tool arguments annotated with this class are not included in the tool + schema sent to language models and are instead injected during execution. + """ class InjectedToolCallId(InjectedToolArg): - """Annotation for injecting the tool_call_id. + """Annotation for injecting the tool call ID. + + This annotation is used to mark a tool parameter that should receive + the tool call ID at runtime. Example: - ..code-block:: python + ```python + from typing_extensions import Annotated + from langchain_core.messages import ToolMessage + from langchain_core.tools import tool, InjectedToolCallId - from typing_extensions import Annotated - - from langchain_core.messages import ToolMessage - from langchain_core.tools import tool, InjectedToolCallID - - @tool - def foo(x: int, tool_call_id: Annotated[str, InjectedToolCallID]) -> ToolMessage: - \"\"\"Return x.\"\"\" - return ToolMessage(str(x), artifact=x, name="foo", tool_call_id=tool_call_id) - """ # noqa: E501 + @tool + def foo( + x: int, tool_call_id: Annotated[str, InjectedToolCallId] + ) -> ToolMessage: + \"\"\"Return x.\"\"\" + return ToolMessage( + str(x), + artifact=x, + name="foo", + tool_call_id=tool_call_id + ) + ``` + """ def _is_injected_arg_type( type_: type, injected_type: Optional[type[InjectedToolArg]] = None ) -> bool: + """Check if a type annotation indicates an injected argument. + + Args: + type_: The type annotation to check. + injected_type: The specific injected type to check for. + + Returns: + True if the type is an injected argument, False otherwise. + """ injected_type = injected_type or InjectedToolArg return any( isinstance(arg, injected_type) @@ -1138,6 +1372,16 @@ def _replace_type_vars( *, default_to_bound: bool = True, ) -> type: + """Replace TypeVars in a type annotation with concrete types. + + Args: + type_: The type annotation to process. + generic_map: Mapping of TypeVars to concrete types. + default_to_bound: Whether to use TypeVar bounds as defaults. + + Returns: + The type with TypeVars replaced. + """ generic_map = generic_map or {} if isinstance(type_, TypeVar): if type_ in generic_map: @@ -1155,8 +1399,16 @@ def _replace_type_vars( class BaseToolkit(BaseModel, ABC): - """Base Toolkit representing a collection of related tools.""" + """Base class for toolkits containing related tools. + + A toolkit is a collection of related tools that can be used together + to accomplish a specific task or work with a particular system. + """ @abstractmethod def get_tools(self) -> list[BaseTool]: - """Get the tools in the toolkit.""" + """Get all tools in the toolkit. + + Returns: + List of tools contained in this toolkit. + """