core[patch]: docstrings langchain_core/ files update (#24285)

Added missed docstrings. Formatted docstrings to the consistent form.
This commit is contained in:
Leonid Ganeline 2024-07-16 06:21:51 -07:00 committed by GitHub
parent 7aeaa1974d
commit 198b85334f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 345 additions and 109 deletions

View File

@ -65,12 +65,15 @@ class AgentAction(Serializable):
@classmethod @classmethod
def is_lc_serializable(cls) -> bool: def is_lc_serializable(cls) -> bool:
"""Return whether or not the class is serializable.""" """Return whether or not the class is serializable.
Default is True.
"""
return True return True
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object.""" """Get the namespace of the langchain object.
Default is ["langchain", "schema", "agent"]."""
return ["langchain", "schema", "agent"] return ["langchain", "schema", "agent"]
@property @property
@ -80,7 +83,7 @@ class AgentAction(Serializable):
class AgentActionMessageLog(AgentAction): class AgentActionMessageLog(AgentAction):
"""A representation of an action to be executed by an agent. """Representation of an action to be executed by an agent.
This is similar to AgentAction, but includes a message log consisting of This is similar to AgentAction, but includes a message log consisting of
chat messages. This is useful when working with ChatModels, and is used chat messages. This is useful when working with ChatModels, and is used
@ -102,7 +105,7 @@ class AgentActionMessageLog(AgentAction):
class AgentStep(Serializable): class AgentStep(Serializable):
"""The result of running an AgentAction.""" """Result of running an AgentAction."""
action: AgentAction action: AgentAction
"""The AgentAction that was executed.""" """The AgentAction that was executed."""
@ -111,12 +114,12 @@ class AgentStep(Serializable):
@property @property
def messages(self) -> Sequence[BaseMessage]: def messages(self) -> Sequence[BaseMessage]:
"""Return the messages that correspond to this observation.""" """Messages that correspond to this observation."""
return _convert_agent_observation_to_messages(self.action, self.observation) return _convert_agent_observation_to_messages(self.action, self.observation)
class AgentFinish(Serializable): class AgentFinish(Serializable):
"""The final return value of an ActionAgent. """Final return value of an ActionAgent.
Agents return an AgentFinish when they have reached a stopping condition. Agents return an AgentFinish when they have reached a stopping condition.
""" """
@ -148,7 +151,7 @@ class AgentFinish(Serializable):
@property @property
def messages(self) -> Sequence[BaseMessage]: def messages(self) -> Sequence[BaseMessage]:
"""Return the messages that correspond to this observation.""" """Messages that correspond to this observation."""
return [AIMessage(content=self.log)] return [AIMessage(content=self.log)]
@ -180,6 +183,7 @@ def _convert_agent_observation_to_messages(
Args: Args:
agent_action: Agent action to convert. agent_action: Agent action to convert.
observation: Observation to convert to a message.
Returns: Returns:
AIMessage that corresponds to the original tool invocation. AIMessage that corresponds to the original tool invocation.
@ -196,11 +200,11 @@ def _create_function_message(
"""Convert agent action and observation into a function message. """Convert agent action and observation into a function message.
Args: Args:
agent_action: the tool invocation request from the agent agent_action: the tool invocation request from the agent.
observation: the result of the tool invocation observation: the result of the tool invocation.
Returns: Returns:
FunctionMessage that corresponds to the original tool invocation FunctionMessage that corresponds to the original tool invocation.
""" """
if not isinstance(observation, str): if not isinstance(observation, str):
try: try:

View File

@ -44,7 +44,7 @@ class BaseCache(ABC):
The default implementation of the async methods is to run the synchronous The default implementation of the async methods is to run the synchronous
method in an executor. It's recommended to override the async methods method in an executor. It's recommended to override the async methods
and provide an async implementations to avoid unnecessary overhead. and provide async implementations to avoid unnecessary overhead.
""" """
@abstractmethod @abstractmethod
@ -152,6 +152,10 @@ class InMemoryCache(BaseCache):
maxsize: The maximum number of items to store in the cache. maxsize: The maximum number of items to store in the cache.
If None, the cache has no maximum size. If None, the cache has no maximum size.
If the cache exceeds the maximum size, the oldest items are removed. If the cache exceeds the maximum size, the oldest items are removed.
Default is None.
Raises:
ValueError: If maxsize is less than or equal to 0.
""" """
self._cache: Dict[Tuple[str, str], RETURN_VAL_TYPE] = {} self._cache: Dict[Tuple[str, str], RETURN_VAL_TYPE] = {}
if maxsize is not None and maxsize <= 0: if maxsize is not None and maxsize <= 0:

View File

@ -116,7 +116,7 @@ class BaseChatMessageHistory(ABC):
This method may be deprecated in a future release. This method may be deprecated in a future release.
Args: Args:
message: The human message to add message: The human message to add to the store.
""" """
if isinstance(message, HumanMessage): if isinstance(message, HumanMessage):
self.add_message(message) self.add_message(message)
@ -200,22 +200,38 @@ class BaseChatMessageHistory(ABC):
class InMemoryChatMessageHistory(BaseChatMessageHistory, BaseModel): class InMemoryChatMessageHistory(BaseChatMessageHistory, BaseModel):
"""In memory implementation of chat message history. """In memory implementation of chat message history.
Stores messages in an in memory list. Stores messages in a memory list.
""" """
messages: List[BaseMessage] = Field(default_factory=list) messages: List[BaseMessage] = Field(default_factory=list)
"""A list of messages stored in memory.""" """A list of messages stored in memory."""
async def aget_messages(self) -> List[BaseMessage]: async def aget_messages(self) -> List[BaseMessage]:
"""Async version of getting messages.""" """Async version of getting messages.
Can over-ride this method to provide an efficient async implementation.
In general, fetching messages may involve IO to the underlying
persistence layer.
Returns:
List of messages.
"""
return self.messages return self.messages
def add_message(self, message: BaseMessage) -> None: def add_message(self, message: BaseMessage) -> None:
"""Add a self-created message to the store.""" """Add a self-created message to the store.
Args:
message: The message to add.
"""
self.messages.append(message) self.messages.append(message)
async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None: async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:
"""Async add messages to the store""" """Async add messages to the store.
Args:
messages: The messages to add.
"""
self.add_messages(messages) self.add_messages(messages)
def clear(self) -> None: def clear(self) -> None:

View File

@ -4,7 +4,11 @@ from functools import lru_cache
@lru_cache(maxsize=1) @lru_cache(maxsize=1)
def get_runtime_environment() -> dict: def get_runtime_environment() -> dict:
"""Get information about the LangChain runtime environment.""" """Get information about the LangChain runtime environment.
Returns:
A dictionary with information about the runtime environment.
"""
# Lazy import to avoid circular imports # Lazy import to avoid circular imports
from langchain_core import __version__ from langchain_core import __version__

View File

@ -22,13 +22,14 @@ class OutputParserException(ValueError, LangChainException):
Parameters: Parameters:
error: The error that's being re-raised or an error message. error: The error that's being re-raised or an error message.
observation: String explanation of error which can be passed to a observation: String explanation of error which can be passed to a
model to try and remediate the issue. model to try and remediate the issue. Defaults to None.
llm_output: String model output which is error-ing. llm_output: String model output which is error-ing.
Defaults to None.
send_to_llm: Whether to send the observation and llm_output back to an Agent send_to_llm: Whether to send the observation and llm_output back to an Agent
after an OutputParserException has been raised. This gives the underlying after an OutputParserException has been raised. This gives the underlying
model driving the agent the context that the previous output was improperly model driving the agent the context that the previous output was improperly
structured, in the hopes that it will update the output to the correct structured, in the hopes that it will update the output to the correct
format. format. Defaults to False.
""" """
def __init__( def __init__(

View File

@ -18,7 +18,11 @@ _llm_cache: Optional["BaseCache"] = None
def set_verbose(value: bool) -> None: def set_verbose(value: bool) -> None:
"""Set a new value for the `verbose` global setting.""" """Set a new value for the `verbose` global setting.
Args:
value: The new value for the `verbose` global setting.
"""
try: try:
import langchain # type: ignore[import] import langchain # type: ignore[import]
@ -46,7 +50,11 @@ def set_verbose(value: bool) -> None:
def get_verbose() -> bool: def get_verbose() -> bool:
"""Get the value of the `verbose` global setting.""" """Get the value of the `verbose` global setting.
Returns:
The value of the `verbose` global setting.
"""
try: try:
import langchain # type: ignore[import] import langchain # type: ignore[import]
@ -79,7 +87,11 @@ def get_verbose() -> bool:
def set_debug(value: bool) -> None: def set_debug(value: bool) -> None:
"""Set a new value for the `debug` global setting.""" """Set a new value for the `debug` global setting.
Args:
value: The new value for the `debug` global setting.
"""
try: try:
import langchain # type: ignore[import] import langchain # type: ignore[import]
@ -105,7 +117,11 @@ def set_debug(value: bool) -> None:
def get_debug() -> bool: def get_debug() -> bool:
"""Get the value of the `debug` global setting.""" """Get the value of the `debug` global setting.
Returns:
The value of the `debug` global setting.
"""
try: try:
import langchain # type: ignore[import] import langchain # type: ignore[import]
@ -168,7 +184,11 @@ def set_llm_cache(value: Optional["BaseCache"]) -> None:
def get_llm_cache() -> "BaseCache": def get_llm_cache() -> "BaseCache":
"""Get the value of the `llm_cache` global setting.""" """Get the value of the `llm_cache` global setting.
Returns:
The value of the `llm_cache` global setting.
"""
try: try:
import langchain # type: ignore[import] import langchain # type: ignore[import]

View File

@ -62,13 +62,20 @@ class BaseMemory(Serializable, ABC):
"""Return key-value pairs given the text input to the chain. """Return key-value pairs given the text input to the chain.
Args: Args:
inputs: The inputs to the chain.""" inputs: The inputs to the chain.
Returns:
A dictionary of key-value pairs.
"""
async def aload_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]: async def aload_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Async return key-value pairs given the text input to the chain. """Async return key-value pairs given the text input to the chain.
Args: Args:
inputs: The inputs to the chain. inputs: The inputs to the chain.
Returns:
A dictionary of key-value pairs.
""" """
return await run_in_executor(None, self.load_memory_variables, inputs) return await run_in_executor(None, self.load_memory_variables, inputs)

View File

@ -29,12 +29,15 @@ class PromptValue(Serializable, ABC):
@classmethod @classmethod
def is_lc_serializable(cls) -> bool: def is_lc_serializable(cls) -> bool:
"""Return whether this class is serializable.""" """Return whether this class is serializable. Defaults to True."""
return True return True
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object.""" """Get the namespace of the langchain object.
This is used to determine the namespace of the object when serializing.
Defaults to ["langchain", "schema", "prompt"].
"""
return ["langchain", "schema", "prompt"] return ["langchain", "schema", "prompt"]
@abstractmethod @abstractmethod
@ -55,7 +58,10 @@ class StringPromptValue(PromptValue):
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object.""" """Get the namespace of the langchain object.
This is used to determine the namespace of the object when serializing.
Defaults to ["langchain", "prompts", "base"].
"""
return ["langchain", "prompts", "base"] return ["langchain", "prompts", "base"]
def to_string(self) -> str: def to_string(self) -> str:
@ -86,7 +92,10 @@ class ChatPromptValue(PromptValue):
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object.""" """Get the namespace of the langchain object.
This is used to determine the namespace of the object when serializing.
Defaults to ["langchain", "prompts", "chat"].
"""
return ["langchain", "prompts", "chat"] return ["langchain", "prompts", "chat"]
@ -94,7 +103,8 @@ class ImageURL(TypedDict, total=False):
"""Image URL.""" """Image URL."""
detail: Literal["auto", "low", "high"] detail: Literal["auto", "low", "high"]
"""Specifies the detail level of the image.""" """Specifies the detail level of the image. Defaults to "auto".
Can be "auto", "low", or "high"."""
url: str url: str
"""Either a URL of the image or the base64 encoded image data.""" """Either a URL of the image or the base64 encoded image data."""
@ -127,5 +137,8 @@ class ChatPromptValueConcrete(ChatPromptValue):
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object.""" """Get the namespace of the langchain object.
This is used to determine the namespace of the object when serializing.
Defaults to ["langchain", "prompts", "chat"].
"""
return ["langchain", "prompts", "chat"] return ["langchain", "prompts", "chat"]

View File

@ -53,14 +53,13 @@ RetrieverOutputLike = Runnable[Any, RetrieverOutput]
class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC): class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
"""Abstract base class for a Document retrieval system. """Abstract base class for a Document retrieval system.
A retrieval system is defined as something that can take string queries and return A retrieval system is defined as something that can take string queries and return
the most 'relevant' Documents from some source. the most 'relevant' Documents from some source.
Usage: Usage:
A retriever follows the standard Runnable interface, and should be used A retriever follows the standard Runnable interface, and should be used
via the standard runnable methods of `invoke`, `ainvoke`, `batch`, `abatch`. via the standard Runnable methods of `invoke`, `ainvoke`, `batch`, `abatch`.
Implementation: Implementation:
@ -89,7 +88,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
\"\"\"(Optional) async native implementation.\"\"\" \"\"\"(Optional) async native implementation.\"\"\"
return self.docs[:self.k] return self.docs[:self.k]
Example: A simple retriever based on a scitkit learn vectorizer Example: A simple retriever based on a scikit-learn vectorizer
.. code-block:: python .. code-block:: python
@ -178,12 +177,12 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
Main entry point for synchronous retriever invocations. Main entry point for synchronous retriever invocations.
Args: Args:
input: The query string input: The query string.
config: Configuration for the retriever config: Configuration for the retriever. Defaults to None.
**kwargs: Additional arguments to pass to the retriever **kwargs: Additional arguments to pass to the retriever.
Returns: Returns:
List of relevant documents List of relevant documents.
Examples: Examples:
@ -237,12 +236,12 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
Main entry point for asynchronous retriever invocations. Main entry point for asynchronous retriever invocations.
Args: Args:
input: The query string input: The query string.
config: Configuration for the retriever config: Configuration for the retriever. Defaults to None.
**kwargs: Additional arguments to pass to the retriever **kwargs: Additional arguments to pass to the retriever.
Returns: Returns:
List of relevant documents List of relevant documents.
Examples: Examples:
@ -292,10 +291,10 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
"""Get documents relevant to a query. """Get documents relevant to a query.
Args: Args:
query: String to find relevant documents for query: String to find relevant documents for.
run_manager: The callback handler to use run_manager: The callback handler to use.
Returns: Returns:
List of relevant documents List of relevant documents.
""" """
async def _aget_relevant_documents( async def _aget_relevant_documents(
@ -333,18 +332,21 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
`get_relevant_documents directly`. `get_relevant_documents directly`.
Args: Args:
query: string to find relevant documents for query: string to find relevant documents for.
callbacks: Callback manager or list of callbacks callbacks: Callback manager or list of callbacks. Defaults to None.
tags: Optional list of tags associated with the retriever. Defaults to None tags: Optional list of tags associated with the retriever.
These tags will be associated with each call to this retriever, These tags will be associated with each call to this retriever,
and passed as arguments to the handlers defined in `callbacks`. and passed as arguments to the handlers defined in `callbacks`.
metadata: Optional metadata associated with the retriever. Defaults to None Defaults to None.
metadata: Optional metadata associated with the retriever.
This metadata will be associated with each call to this retriever, This metadata will be associated with each call to this retriever,
and passed as arguments to the handlers defined in `callbacks`. and passed as arguments to the handlers defined in `callbacks`.
run_name: Optional name for the run. Defaults to None.
run_name: Optional name for the run. Defaults to None.
**kwargs: Additional arguments to pass to the retriever.
Returns: Returns:
List of relevant documents List of relevant documents.
""" """
config: RunnableConfig = {} config: RunnableConfig = {}
if callbacks: if callbacks:
@ -374,18 +376,21 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
`aget_relevant_documents directly`. `aget_relevant_documents directly`.
Args: Args:
query: string to find relevant documents for query: string to find relevant documents for.
callbacks: Callback manager or list of callbacks callbacks: Callback manager or list of callbacks.
tags: Optional list of tags associated with the retriever. Defaults to None tags: Optional list of tags associated with the retriever.
These tags will be associated with each call to this retriever, These tags will be associated with each call to this retriever,
and passed as arguments to the handlers defined in `callbacks`. and passed as arguments to the handlers defined in `callbacks`.
metadata: Optional metadata associated with the retriever. Defaults to None Defaults to None.
metadata: Optional metadata associated with the retriever.
This metadata will be associated with each call to this retriever, This metadata will be associated with each call to this retriever,
and passed as arguments to the handlers defined in `callbacks`. and passed as arguments to the handlers defined in `callbacks`.
run_name: Optional name for the run. Defaults to None.
run_name: Optional name for the run. Defaults to None.
**kwargs: Additional arguments to pass to the retriever.
Returns: Returns:
List of relevant documents List of relevant documents.
""" """
config: RunnableConfig = {} config: RunnableConfig = {}
if callbacks: if callbacks:

View File

@ -150,7 +150,6 @@ class BaseStore(Generic[K, V], ABC):
Yields: Yields:
Iterator[K | str]: An iterator over keys that match the given prefix. Iterator[K | str]: An iterator over keys that match the given prefix.
This method is allowed to return an iterator over either K or str This method is allowed to return an iterator over either K or str
depending on what makes more sense for the given store. depending on what makes more sense for the given store.
""" """
@ -165,7 +164,6 @@ class BaseStore(Generic[K, V], ABC):
Yields: Yields:
Iterator[K | str]: An iterator over keys that match the given prefix. Iterator[K | str]: An iterator over keys that match the given prefix.
This method is allowed to return an iterator over either K or str This method is allowed to return an iterator over either K or str
depending on what makes more sense for the given store. depending on what makes more sense for the given store.
""" """

View File

@ -10,10 +10,12 @@ from langchain_core.pydantic_v1 import BaseModel
class Visitor(ABC): class Visitor(ABC):
"""Defines interface for IR translation using visitor pattern.""" """Defines interface for IR translation using a visitor pattern."""
allowed_comparators: Optional[Sequence[Comparator]] = None allowed_comparators: Optional[Sequence[Comparator]] = None
"""Allowed comparators for the visitor."""
allowed_operators: Optional[Sequence[Operator]] = None allowed_operators: Optional[Sequence[Operator]] = None
"""Allowed operators for the visitor."""
def _validate_func(self, func: Union[Operator, Comparator]) -> None: def _validate_func(self, func: Union[Operator, Comparator]) -> None:
if isinstance(func, Operator) and self.allowed_operators is not None: if isinstance(func, Operator) and self.allowed_operators is not None:
@ -31,15 +33,27 @@ class Visitor(ABC):
@abstractmethod @abstractmethod
def visit_operation(self, operation: Operation) -> Any: def visit_operation(self, operation: Operation) -> Any:
"""Translate an Operation.""" """Translate an Operation.
Args:
operation: Operation to translate.
"""
@abstractmethod @abstractmethod
def visit_comparison(self, comparison: Comparison) -> Any: def visit_comparison(self, comparison: Comparison) -> Any:
"""Translate a Comparison.""" """Translate a Comparison.
Args:
comparison: Comparison to translate.
"""
@abstractmethod @abstractmethod
def visit_structured_query(self, structured_query: StructuredQuery) -> Any: def visit_structured_query(self, structured_query: StructuredQuery) -> Any:
"""Translate a StructuredQuery.""" """Translate a StructuredQuery.
Args:
structured_query: StructuredQuery to translate.
"""
def _to_snake_case(name: str) -> str: def _to_snake_case(name: str) -> str:
@ -60,10 +74,10 @@ class Expr(BaseModel):
"""Accept a visitor. """Accept a visitor.
Args: Args:
visitor: visitor to accept visitor: visitor to accept.
Returns: Returns:
result of visiting result of visiting.
""" """
return getattr(visitor, f"visit_{_to_snake_case(self.__class__.__name__)}")( return getattr(visitor, f"visit_{_to_snake_case(self.__class__.__name__)}")(
self self
@ -98,7 +112,13 @@ class FilterDirective(Expr, ABC):
class Comparison(FilterDirective): class Comparison(FilterDirective):
"""Comparison to a value.""" """Comparison to a value.
Parameters:
comparator: The comparator to use.
attribute: The attribute to compare.
value: The value to compare to.
"""
comparator: Comparator comparator: Comparator
attribute: str attribute: str
@ -113,7 +133,12 @@ class Comparison(FilterDirective):
class Operation(FilterDirective): class Operation(FilterDirective):
"""Llogical operation over other directives.""" """Logical operation over other directives.
Parameters:
operator: The operator to use.
arguments: The arguments to the operator.
"""
operator: Operator operator: Operator
arguments: List[FilterDirective] arguments: List[FilterDirective]

View File

@ -6,7 +6,11 @@ from typing import Sequence
def print_sys_info(*, additional_pkgs: Sequence[str] = tuple()) -> None: def print_sys_info(*, additional_pkgs: Sequence[str] = tuple()) -> None:
"""Print information about the environment for debugging purposes.""" """Print information about the environment for debugging purposes.
Args:
additional_pkgs: Additional packages to include in the output.
"""
import pkgutil import pkgutil
import platform import platform
import sys import sys

View File

@ -249,7 +249,16 @@ def _infer_arg_descriptions(
class _SchemaConfig: class _SchemaConfig:
"""Configuration for the pydantic model.""" """Configuration for the pydantic model.
This is used to configure the pydantic model created from
a function's signature.
Parameters:
extra: Whether to allow extra fields in the model.
arbitrary_types_allowed: Whether to allow arbitrary types in the model.
Defaults to True.
"""
extra: Any = Extra.forbid extra: Any = Extra.forbid
arbitrary_types_allowed: bool = True arbitrary_types_allowed: bool = True
@ -265,15 +274,18 @@ def create_schema_from_function(
) -> Type[BaseModel]: ) -> Type[BaseModel]:
"""Create a pydantic schema from a function's signature. """Create a pydantic schema from a function's signature.
Args: Args:
model_name: Name to assign to the generated pydandic schema model_name: Name to assign to the generated pydantic schema.
func: Function to generate the schema from func: Function to generate the schema from.
filter_args: Optional list of arguments to exclude from the schema filter_args: Optional list of arguments to exclude from the schema.
Defaults to FILTERED_ARGS.
parse_docstring: Whether to parse the function's docstring for descriptions parse_docstring: Whether to parse the function's docstring for descriptions
for each argument. for each argument. Defaults to False.
error_on_invalid_docstring: if ``parse_docstring`` is provided, configures error_on_invalid_docstring: if ``parse_docstring`` is provided, configure
whether to raise ValueError on invalid Google Style docstrings. whether to raise ValueError on invalid Google Style docstrings.
Defaults to False.
Returns: Returns:
A pydantic model with the same arguments as the function A pydantic model with the same arguments as the function.
""" """
# https://docs.pydantic.dev/latest/usage/validation_decorator/ # https://docs.pydantic.dev/latest/usage/validation_decorator/
validated = validate_arguments(func, config=_SchemaConfig) # type: ignore validated = validate_arguments(func, config=_SchemaConfig) # type: ignore
@ -348,8 +360,9 @@ class ChildTool(BaseTool):
args_schema: Optional[Type[BaseModel]] = None args_schema: Optional[Type[BaseModel]] = None
"""Pydantic model class to validate and parse the tool's input arguments.""" """Pydantic model class to validate and parse the tool's input arguments."""
return_direct: bool = False return_direct: bool = False
"""Whether to return the tool's output directly. Setting this to True means """Whether to return the tool's output directly.
Setting this to True means
that after the tool is called, the AgentExecutor will stop looping. that after the tool is called, the AgentExecutor will stop looping.
""" """
verbose: bool = False verbose: bool = False
@ -360,13 +373,13 @@ class ChildTool(BaseTool):
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True) callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
"""Deprecated. Please use callbacks instead.""" """Deprecated. Please use callbacks instead."""
tags: Optional[List[str]] = None tags: Optional[List[str]] = None
"""Optional list of tags associated with the tool. Defaults to None """Optional list of tags associated with the tool. Defaults to None.
These tags will be associated with each call to this tool, These tags will be associated with each call to this tool,
and passed as arguments to the handlers defined in `callbacks`. and passed as arguments to the handlers defined in `callbacks`.
You can use these to eg identify a specific instance of a tool with its use case. You can use these to eg identify a specific instance of a tool with its use case.
""" """
metadata: Optional[Dict[str, Any]] = None metadata: Optional[Dict[str, Any]] = None
"""Optional metadata associated with the tool. Defaults to None """Optional metadata associated with the tool. Defaults to None.
This metadata will be associated with each call to this tool, This metadata will be associated with each call to this tool,
and passed as arguments to the handlers defined in `callbacks`. and passed as arguments to the handlers defined in `callbacks`.
You can use these to eg identify a specific instance of a tool with its use case. You can use these to eg identify a specific instance of a tool with its use case.
@ -383,7 +396,7 @@ class ChildTool(BaseTool):
"""Handle the content of the ValidationError thrown.""" """Handle the content of the ValidationError thrown."""
response_format: Literal["content", "content_and_artifact"] = "content" response_format: Literal["content", "content_and_artifact"] = "content"
"""The tool response format. """The tool response format. Defaults to 'content'.
If "content" then the output of the tool is interpreted as the contents of a If "content" then the output of the tool is interpreted as the contents of a
ToolMessage. If "content_and_artifact" then the output is expected to be a ToolMessage. If "content_and_artifact" then the output is expected to be a
@ -414,7 +427,14 @@ class ChildTool(BaseTool):
def get_input_schema( def get_input_schema(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]: ) -> Type[BaseModel]:
"""The tool's input schema.""" """The tool's input schema.
Args:
config: The configuration for the tool.
Returns:
The input schema for the tool.
"""
if self.args_schema is not None: if self.args_schema is not None:
return self.args_schema return self.args_schema
else: else:
@ -441,7 +461,11 @@ class ChildTool(BaseTool):
# --- Tool --- # --- Tool ---
def _parse_input(self, tool_input: Union[str, Dict]) -> Union[str, Dict[str, Any]]: def _parse_input(self, tool_input: Union[str, Dict]) -> Union[str, Dict[str, Any]]:
"""Convert tool input to pydantic model.""" """Convert tool input to a pydantic model.
Args:
tool_input: The input to the tool.
"""
input_args = self.args_schema input_args = self.args_schema
if isinstance(tool_input, str): if isinstance(tool_input, str):
if input_args is not None: if input_args is not None:
@ -460,7 +484,14 @@ class ChildTool(BaseTool):
@root_validator(pre=True) @root_validator(pre=True)
def raise_deprecation(cls, values: Dict) -> Dict: def raise_deprecation(cls, values: Dict) -> Dict:
"""Raise deprecation warning if callback_manager is used.""" """Raise deprecation warning if callback_manager is used.
Args:
values: The values to validate.
Returns:
The validated values.
"""
if values.get("callback_manager") is not None: if values.get("callback_manager") is not None:
warnings.warn( warnings.warn(
"callback_manager is deprecated. Please use callbacks instead.", "callback_manager is deprecated. Please use callbacks instead.",
@ -514,7 +545,28 @@ class ChildTool(BaseTool):
tool_call_id: Optional[str] = None, tool_call_id: Optional[str] = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
"""Run the tool.""" """Run the tool.
Args:
tool_input: The input to the tool.
verbose: Whether to log the tool's progress. Defaults to None.
start_color: The color to use when starting the tool. Defaults to 'green'.
color: The color to use when ending the tool. Defaults to 'green'.
callbacks: Callbacks to be called during tool execution. Defaults to None.
tags: Optional list of tags associated with the tool. Defaults to None.
metadata: Optional metadata associated with the tool. Defaults to None.
run_name: The name of the run. Defaults to None.
run_id: The id of the run. Defaults to None.
config: The configuration for the tool. Defaults to None.
tool_call_id: The id of the tool call. Defaults to None.
kwargs: Additional arguments to pass to the tool
Returns:
The output of the tool.
Raises:
ToolException: If an error occurs during tool execution.
"""
callback_manager = CallbackManager.configure( callback_manager = CallbackManager.configure(
callbacks, callbacks,
self.callbacks, self.callbacks,
@ -600,7 +652,28 @@ class ChildTool(BaseTool):
tool_call_id: Optional[str] = None, tool_call_id: Optional[str] = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
"""Run the tool asynchronously.""" """Run the tool asynchronously.
Args:
tool_input: The input to the tool.
verbose: Whether to log the tool's progress. Defaults to None.
start_color: The color to use when starting the tool. Defaults to 'green'.
color: The color to use when ending the tool. Defaults to 'green'.
callbacks: Callbacks to be called during tool execution. Defaults to None.
tags: Optional list of tags associated with the tool. Defaults to None.
metadata: Optional metadata associated with the tool. Defaults to None.
run_name: The name of the run. Defaults to None.
run_id: The id of the run. Defaults to None.
config: The configuration for the tool. Defaults to None.
tool_call_id: The id of the tool call. Defaults to None.
kwargs: Additional arguments to pass to the tool
Returns:
The output of the tool.
Raises:
ToolException: If an error occurs during tool execution.
"""
callback_manager = AsyncCallbackManager.configure( callback_manager = AsyncCallbackManager.configure(
callbacks, callbacks,
self.callbacks, self.callbacks,
@ -709,7 +782,11 @@ class Tool(BaseTool):
@property @property
def args(self) -> dict: def args(self) -> dict:
"""The tool's input arguments.""" """The tool's input arguments.
Returns:
The input arguments for the tool.
"""
if self.args_schema is not None: if self.args_schema is not None:
return self.args_schema.schema()["properties"] return self.args_schema.schema()["properties"]
# For backwards compatibility, if the function signature is ambiguous, # For backwards compatibility, if the function signature is ambiguous,
@ -788,7 +865,23 @@ class Tool(BaseTool):
] = None, # This is last for compatibility, but should be after func ] = None, # This is last for compatibility, but should be after func
**kwargs: Any, **kwargs: Any,
) -> Tool: ) -> Tool:
"""Initialize tool from a function.""" """Initialize tool from a function.
Args:
func: The function to create the tool from.
name: The name of the tool.
description: The description of the tool.
return_direct: Whether to return the output directly. Defaults to False.
args_schema: The schema of the tool's input arguments. Defaults to None.
coroutine: The asynchronous version of the function. Defaults to None.
**kwargs: Additional arguments to pass to the tool.
Returns:
The tool.
Raises:
ValueError: If the function is not provided.
"""
if func is None and coroutine is None: if func is None and coroutine is None:
raise ValueError("Function and/or coroutine must be provided") raise ValueError("Function and/or coroutine must be provided")
return cls( return cls(
@ -893,25 +986,34 @@ class StructuredTool(BaseTool):
A classmethod that helps to create a tool from a function. A classmethod that helps to create a tool from a function.
Args: Args:
func: The function from which to create a tool func: The function from which to create a tool.
coroutine: The async function from which to create a tool coroutine: The async function from which to create a tool.
name: The name of the tool. Defaults to the function name name: The name of the tool. Defaults to the function name.
description: The description of the tool. Defaults to the function docstring description: The description of the tool.
return_direct: Whether to return the result directly or as a callback Defaults to the function docstring.
args_schema: The schema of the tool's input arguments return_direct: Whether to return the result directly or as a callback.
infer_schema: Whether to infer the schema from the function's signature Defaults to False.
args_schema: The schema of the tool's input arguments. Defaults to None.
infer_schema: Whether to infer the schema from the function's signature.
Defaults to True.
response_format: The tool response format. If "content" then the output of response_format: The tool response format. If "content" then the output of
the tool is interpreted as the contents of a ToolMessage. If the tool is interpreted as the contents of a ToolMessage. If
"content_and_artifact" then the output is expected to be a two-tuple "content_and_artifact" then the output is expected to be a two-tuple
corresponding to the (content, artifact) of a ToolMessage. corresponding to the (content, artifact) of a ToolMessage.
Defaults to "content".
parse_docstring: if ``infer_schema`` and ``parse_docstring``, will attempt parse_docstring: if ``infer_schema`` and ``parse_docstring``, will attempt
to parse parameter descriptions from Google Style function docstrings. to parse parameter descriptions from Google Style function docstrings.
error_on_invalid_docstring: if ``parse_docstring`` is provided, configures Defaults to False.
error_on_invalid_docstring: if ``parse_docstring`` is provided, configure
whether to raise ValueError on invalid Google Style docstrings. whether to raise ValueError on invalid Google Style docstrings.
Defaults to False.
**kwargs: Additional arguments to pass to the tool **kwargs: Additional arguments to pass to the tool
Returns: Returns:
The tool The tool.
Raises:
ValueError: If the function is not provided.
Examples: Examples:
@ -989,19 +1091,27 @@ def tool(
Args: Args:
*args: The arguments to the tool. *args: The arguments to the tool.
return_direct: Whether to return directly from the tool rather return_direct: Whether to return directly from the tool rather
than continuing the agent loop. than continuing the agent loop. Defaults to False.
args_schema: optional argument schema for user to specify args_schema: optional argument schema for user to specify.
Defaults to None.
infer_schema: Whether to infer the schema of the arguments from infer_schema: Whether to infer the schema of the arguments from
the function's signature. This also makes the resultant tool the function's signature. This also makes the resultant tool
accept a dictionary input to its `run()` function. accept a dictionary input to its `run()` function.
Defaults to True.
response_format: The tool response format. If "content" then the output of response_format: The tool response format. If "content" then the output of
the tool is interpreted as the contents of a ToolMessage. If the tool is interpreted as the contents of a ToolMessage. If
"content_and_artifact" then the output is expected to be a two-tuple "content_and_artifact" then the output is expected to be a two-tuple
corresponding to the (content, artifact) of a ToolMessage. corresponding to the (content, artifact) of a ToolMessage.
Defaults to "content".
parse_docstring: if ``infer_schema`` and ``parse_docstring``, will attempt to parse_docstring: if ``infer_schema`` and ``parse_docstring``, will attempt to
parse parameter descriptions from Google Style function docstrings. parse parameter descriptions from Google Style function docstrings.
error_on_invalid_docstring: if ``parse_docstring`` is provided, configures Defaults to False.
error_on_invalid_docstring: if ``parse_docstring`` is provided, configure
whether to raise ValueError on invalid Google Style docstrings. whether to raise ValueError on invalid Google Style docstrings.
Defaults to True.
Returns:
The tool.
Requires: Requires:
- Function must be of type (str) -> str - Function must be of type (str) -> str
@ -1230,9 +1340,11 @@ def create_retriever_tool(
so should be unique and somewhat descriptive. so should be unique and somewhat descriptive.
description: The description for the tool. This will be passed to the language description: The description for the tool. This will be passed to the language
model, so should be descriptive. model, so should be descriptive.
document_prompt: The prompt to use for the document. Defaults to None.
document_separator: The separator to use between documents. Defaults to "\n\n".
Returns: Returns:
Tool class to pass to an agent Tool class to pass to an agent.
""" """
document_prompt = document_prompt or PromptTemplate.from_template("{page_content}") document_prompt = document_prompt or PromptTemplate.from_template("{page_content}")
func = partial( func = partial(
@ -1262,6 +1374,12 @@ ToolsRenderer = Callable[[List[BaseTool]], str]
def render_text_description(tools: List[BaseTool]) -> str: def render_text_description(tools: List[BaseTool]) -> str:
"""Render the tool name and description in plain text. """Render the tool name and description in plain text.
Args:
tools: The tools to render.
Returns:
The rendered text.
Output will be in the format of: Output will be in the format of:
.. code-block:: markdown .. code-block:: markdown
@ -1284,6 +1402,12 @@ def render_text_description(tools: List[BaseTool]) -> str:
def render_text_description_and_args(tools: List[BaseTool]) -> str: def render_text_description_and_args(tools: List[BaseTool]) -> str:
"""Render the tool name, description, and args in plain text. """Render the tool name, description, and args in plain text.
Args:
tools: The tools to render.
Returns:
The rendered text.
Output will be in the format of: Output will be in the format of:
.. code-block:: markdown .. code-block:: markdown
@ -1444,7 +1568,18 @@ def convert_runnable_to_tool(
description: Optional[str] = None, description: Optional[str] = None,
arg_types: Optional[Dict[str, Type]] = None, arg_types: Optional[Dict[str, Type]] = None,
) -> BaseTool: ) -> BaseTool:
"""Convert a Runnable into a BaseTool.""" """Convert a Runnable into a BaseTool.
Args:
runnable: The runnable to convert.
args_schema: The schema for the tool's input arguments. Defaults to None.
name: The name of the tool. Defaults to None.
description: The description of the tool. Defaults to None.
arg_types: The types of the arguments. Defaults to None.
Returns:
The tool.
"""
if args_schema: if args_schema:
runnable = runnable.with_types(input_type=args_schema) runnable = runnable.with_types(input_type=args_schema)
description = description or _get_description_from_runnable(runnable) description = description or _get_description_from_runnable(runnable)