core[patch]: callbacks docstrings (#23375)

Added missed docstrings. Formatted docstrings to the consistent form.
This commit is contained in:
Leonid Ganeline 2024-06-26 14:11:06 -07:00 committed by GitHub
parent 1141b08eb8
commit 2a5d59b3d7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 707 additions and 98 deletions

View File

@ -6,6 +6,7 @@
BaseCallbackHandler --> <name>CallbackHandler # Example: AimCallbackHandler BaseCallbackHandler --> <name>CallbackHandler # Example: AimCallbackHandler
""" """
from langchain_core.callbacks.base import ( from langchain_core.callbacks.base import (
AsyncCallbackHandler, AsyncCallbackHandler,
BaseCallbackHandler, BaseCallbackHandler,

View File

@ -1,4 +1,5 @@
"""Base callback handler that can be used to handle callbacks in langchain.""" """Base callback handler for LangChain."""
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, TypeVar, Union from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, TypeVar, Union
@ -54,7 +55,10 @@ class LLMManagerMixin:
Args: Args:
token (str): The new token. token (str): The new token.
chunk (GenerationChunk | ChatGenerationChunk): The new generated chunk, chunk (GenerationChunk | ChatGenerationChunk): The new generated chunk,
containing content and other information. containing content and other information.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
kwargs (Any): Additional keyword arguments.
""" """
def on_llm_end( def on_llm_end(
@ -65,7 +69,14 @@ class LLMManagerMixin:
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
"""Run when LLM ends running.""" """Run when LLM ends running.
Args:
response (LLMResult): The response which was generated.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
kwargs (Any): Additional keyword arguments.
"""
def on_llm_error( def on_llm_error(
self, self,
@ -76,11 +87,12 @@ class LLMManagerMixin:
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
"""Run when LLM errors. """Run when LLM errors.
Args: Args:
error (BaseException): The error that occurred. error (BaseException): The error that occurred.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
kwargs (Any): Additional keyword arguments. kwargs (Any): Additional keyword arguments.
- response (LLMResult): The response which was generated before
the error occurred.
""" """
@ -95,7 +107,13 @@ class ChainManagerMixin:
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
"""Run when chain ends running.""" """Run when chain ends running.
Args:
outputs (Dict[str, Any]): The outputs of the chain.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
kwargs (Any): Additional keyword arguments."""
def on_chain_error( def on_chain_error(
self, self,
@ -105,7 +123,13 @@ class ChainManagerMixin:
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
"""Run when chain errors.""" """Run when chain errors.
Args:
error (BaseException): The error that occurred.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
kwargs (Any): Additional keyword arguments."""
def on_agent_action( def on_agent_action(
self, self,
@ -115,7 +139,13 @@ class ChainManagerMixin:
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
"""Run on agent action.""" """Run on agent action.
Args:
action (AgentAction): The agent action.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
kwargs (Any): Additional keyword arguments."""
def on_agent_finish( def on_agent_finish(
self, self,
@ -125,7 +155,13 @@ class ChainManagerMixin:
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
"""Run on agent end.""" """Run on the agent end.
Args:
finish (AgentFinish): The agent finish.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
kwargs (Any): Additional keyword arguments."""
class ToolManagerMixin: class ToolManagerMixin:
@ -139,7 +175,13 @@ class ToolManagerMixin:
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
"""Run when tool ends running.""" """Run when the tool ends running.
Args:
output (Any): The output of the tool.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
kwargs (Any): Additional keyword arguments."""
def on_tool_error( def on_tool_error(
self, self,
@ -149,7 +191,13 @@ class ToolManagerMixin:
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
"""Run when tool errors.""" """Run when tool errors.
Args:
error (BaseException): The error that occurred.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
kwargs (Any): Additional keyword arguments."""
class CallbackManagerMixin: class CallbackManagerMixin:
@ -171,6 +219,15 @@ class CallbackManagerMixin:
**ATTENTION**: This method is called for non-chat models (regular LLMs). If **ATTENTION**: This method is called for non-chat models (regular LLMs). If
you're implementing a handler for a chat model, you're implementing a handler for a chat model,
you should use on_chat_model_start instead. you should use on_chat_model_start instead.
Args:
serialized (Dict[str, Any]): The serialized LLM.
prompts (List[str]): The prompts.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
tags (Optional[List[str]]): The tags.
metadata (Optional[Dict[str, Any]]): The metadata.
kwargs (Any): Additional keyword arguments.
""" """
def on_chat_model_start( def on_chat_model_start(
@ -188,6 +245,15 @@ class CallbackManagerMixin:
**ATTENTION**: This method is called for chat models. If you're implementing **ATTENTION**: This method is called for chat models. If you're implementing
a handler for a non-chat model, you should use on_llm_start instead. a handler for a non-chat model, you should use on_llm_start instead.
Args:
serialized (Dict[str, Any]): The serialized chat model.
messages (List[List[BaseMessage]]): The messages.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
tags (Optional[List[str]]): The tags.
metadata (Optional[Dict[str, Any]]): The metadata.
kwargs (Any): Additional keyword arguments.
""" """
# NotImplementedError is thrown intentionally # NotImplementedError is thrown intentionally
# Callback handler will fall back to on_llm_start if this is exception is thrown # Callback handler will fall back to on_llm_start if this is exception is thrown
@ -206,7 +272,17 @@ class CallbackManagerMixin:
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
"""Run when Retriever starts running.""" """Run when the Retriever starts running.
Args:
serialized (Dict[str, Any]): The serialized Retriever.
query (str): The query.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
tags (Optional[List[str]]): The tags.
metadata (Optional[Dict[str, Any]]): The metadata.
kwargs (Any): Additional keyword arguments.
"""
def on_chain_start( def on_chain_start(
self, self,
@ -219,7 +295,17 @@ class CallbackManagerMixin:
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
"""Run when chain starts running.""" """Run when a chain starts running.
Args:
serialized (Dict[str, Any]): The serialized chain.
inputs (Dict[str, Any]): The inputs.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
tags (Optional[List[str]]): The tags.
metadata (Optional[Dict[str, Any]]): The metadata.
kwargs (Any): Additional keyword arguments.
"""
def on_tool_start( def on_tool_start(
self, self,
@ -233,7 +319,18 @@ class CallbackManagerMixin:
inputs: Optional[Dict[str, Any]] = None, inputs: Optional[Dict[str, Any]] = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
"""Run when tool starts running.""" """Run when the tool starts running.
Args:
serialized (Dict[str, Any]): The serialized tool.
input_str (str): The input string.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
tags (Optional[List[str]]): The tags.
metadata (Optional[Dict[str, Any]]): The metadata.
inputs (Optional[Dict[str, Any]]): The inputs.
kwargs (Any): Additional keyword arguments.
"""
class RunManagerMixin: class RunManagerMixin:
@ -247,7 +344,14 @@ class RunManagerMixin:
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
"""Run on arbitrary text.""" """Run on an arbitrary text.
Args:
text (str): The text.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
kwargs (Any): Additional keyword arguments.
"""
def on_retry( def on_retry(
self, self,
@ -257,7 +361,14 @@ class RunManagerMixin:
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
"""Run on a retry event.""" """Run on a retry event.
Args:
retry_state (RetryCallState): The retry state.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
kwargs (Any): Additional keyword arguments.
"""
class BaseCallbackHandler( class BaseCallbackHandler(
@ -268,11 +379,13 @@ class BaseCallbackHandler(
CallbackManagerMixin, CallbackManagerMixin,
RunManagerMixin, RunManagerMixin,
): ):
"""Base callback handler that handles callbacks from LangChain.""" """Base callback handler for LangChain."""
raise_error: bool = False raise_error: bool = False
"""Whether to raise an error if an exception occurs."""
run_inline: bool = False run_inline: bool = False
"""Whether to run the callback inline."""
@property @property
def ignore_llm(self) -> bool: def ignore_llm(self) -> bool:
@ -306,7 +419,7 @@ class BaseCallbackHandler(
class AsyncCallbackHandler(BaseCallbackHandler): class AsyncCallbackHandler(BaseCallbackHandler):
"""Async callback handler that handles callbacks from LangChain.""" """Async callback handler for LangChain."""
async def on_llm_start( async def on_llm_start(
self, self,
@ -324,6 +437,15 @@ class AsyncCallbackHandler(BaseCallbackHandler):
**ATTENTION**: This method is called for non-chat models (regular LLMs). If **ATTENTION**: This method is called for non-chat models (regular LLMs). If
you're implementing a handler for a chat model, you're implementing a handler for a chat model,
you should use on_chat_model_start instead. you should use on_chat_model_start instead.
Args:
serialized (Dict[str, Any]): The serialized LLM.
prompts (List[str]): The prompts.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
tags (Optional[List[str]]): The tags.
metadata (Optional[Dict[str, Any]]): The metadata.
kwargs (Any): Additional keyword arguments.
""" """
async def on_chat_model_start( async def on_chat_model_start(
@ -341,6 +463,15 @@ class AsyncCallbackHandler(BaseCallbackHandler):
**ATTENTION**: This method is called for chat models. If you're implementing **ATTENTION**: This method is called for chat models. If you're implementing
a handler for a non-chat model, you should use on_llm_start instead. a handler for a non-chat model, you should use on_llm_start instead.
Args:
serialized (Dict[str, Any]): The serialized chat model.
messages (List[List[BaseMessage]]): The messages.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
tags (Optional[List[str]]): The tags.
metadata (Optional[Dict[str, Any]]): The metadata.
kwargs (Any): Additional keyword arguments.
""" """
# NotImplementedError is thrown intentionally # NotImplementedError is thrown intentionally
# Callback handler will fall back to on_llm_start if this is exception is thrown # Callback handler will fall back to on_llm_start if this is exception is thrown
@ -358,7 +489,17 @@ class AsyncCallbackHandler(BaseCallbackHandler):
tags: Optional[List[str]] = None, tags: Optional[List[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Run on new LLM token. Only available when streaming is enabled.""" """Run on new LLM token. Only available when streaming is enabled.
Args:
token (str): The new token.
chunk (GenerationChunk | ChatGenerationChunk): The new generated chunk,
containing content and other information.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
tags (Optional[List[str]]): The tags.
kwargs (Any): Additional keyword arguments.
"""
async def on_llm_end( async def on_llm_end(
self, self,
@ -369,7 +510,15 @@ class AsyncCallbackHandler(BaseCallbackHandler):
tags: Optional[List[str]] = None, tags: Optional[List[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Run when LLM ends running.""" """Run when LLM ends running.
Args:
response (LLMResult): The response which was generated.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
tags (Optional[List[str]]): The tags.
kwargs (Any): Additional keyword arguments.
"""
async def on_llm_error( async def on_llm_error(
self, self,
@ -384,6 +533,9 @@ class AsyncCallbackHandler(BaseCallbackHandler):
Args: Args:
error: The error that occurred. error: The error that occurred.
run_id: The run ID. This is the ID of the current run.
parent_run_id: The parent run ID. This is the ID of the parent run.
tags: The tags.
kwargs (Any): Additional keyword arguments. kwargs (Any): Additional keyword arguments.
- response (LLMResult): The response which was generated before - response (LLMResult): The response which was generated before
the error occurred. the error occurred.
@ -400,7 +552,17 @@ class AsyncCallbackHandler(BaseCallbackHandler):
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Run when chain starts running.""" """Run when a chain starts running.
Args:
serialized (Dict[str, Any]): The serialized chain.
inputs (Dict[str, Any]): The inputs.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
tags (Optional[List[str]]): The tags.
metadata (Optional[Dict[str, Any]]): The metadata.
kwargs (Any): Additional keyword arguments.
"""
async def on_chain_end( async def on_chain_end(
self, self,
@ -411,7 +573,15 @@ class AsyncCallbackHandler(BaseCallbackHandler):
tags: Optional[List[str]] = None, tags: Optional[List[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Run when chain ends running.""" """Run when a chain ends running.
Args:
outputs (Dict[str, Any]): The outputs of the chain.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
tags (Optional[List[str]]): The tags.
kwargs (Any): Additional keyword arguments.
"""
async def on_chain_error( async def on_chain_error(
self, self,
@ -422,7 +592,15 @@ class AsyncCallbackHandler(BaseCallbackHandler):
tags: Optional[List[str]] = None, tags: Optional[List[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Run when chain errors.""" """Run when chain errors.
Args:
error (BaseException): The error that occurred.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
tags (Optional[List[str]]): The tags.
kwargs (Any): Additional keyword arguments.
"""
async def on_tool_start( async def on_tool_start(
self, self,
@ -436,7 +614,18 @@ class AsyncCallbackHandler(BaseCallbackHandler):
inputs: Optional[Dict[str, Any]] = None, inputs: Optional[Dict[str, Any]] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Run when tool starts running.""" """Run when the tool starts running.
Args:
serialized (Dict[str, Any]): The serialized tool.
input_str (str): The input string.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
tags (Optional[List[str]]): The tags.
metadata (Optional[Dict[str, Any]]): The metadata.
inputs (Optional[Dict[str, Any]]): The inputs.
kwargs (Any): Additional keyword arguments.
"""
async def on_tool_end( async def on_tool_end(
self, self,
@ -447,7 +636,15 @@ class AsyncCallbackHandler(BaseCallbackHandler):
tags: Optional[List[str]] = None, tags: Optional[List[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Run when tool ends running.""" """Run when the tool ends running.
Args:
output (Any): The output of the tool.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
tags (Optional[List[str]]): The tags.
kwargs (Any): Additional keyword arguments.
"""
async def on_tool_error( async def on_tool_error(
self, self,
@ -458,7 +655,15 @@ class AsyncCallbackHandler(BaseCallbackHandler):
tags: Optional[List[str]] = None, tags: Optional[List[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Run when tool errors.""" """Run when tool errors.
Args:
error (BaseException): The error that occurred.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
tags (Optional[List[str]]): The tags.
kwargs (Any): Additional keyword arguments.
"""
async def on_text( async def on_text(
self, self,
@ -469,7 +674,15 @@ class AsyncCallbackHandler(BaseCallbackHandler):
tags: Optional[List[str]] = None, tags: Optional[List[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Run on arbitrary text.""" """Run on an arbitrary text.
Args:
text (str): The text.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
tags (Optional[List[str]]): The tags.
kwargs (Any): Additional keyword arguments.
"""
async def on_retry( async def on_retry(
self, self,
@ -479,7 +692,14 @@ class AsyncCallbackHandler(BaseCallbackHandler):
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
"""Run on a retry event.""" """Run on a retry event.
Args:
retry_state (RetryCallState): The retry state.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
kwargs (Any): Additional keyword arguments.
"""
async def on_agent_action( async def on_agent_action(
self, self,
@ -490,7 +710,15 @@ class AsyncCallbackHandler(BaseCallbackHandler):
tags: Optional[List[str]] = None, tags: Optional[List[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Run on agent action.""" """Run on agent action.
Args:
action (AgentAction): The agent action.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
tags (Optional[List[str]]): The tags.
kwargs (Any): Additional keyword arguments.
"""
async def on_agent_finish( async def on_agent_finish(
self, self,
@ -501,7 +729,15 @@ class AsyncCallbackHandler(BaseCallbackHandler):
tags: Optional[List[str]] = None, tags: Optional[List[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Run on agent end.""" """Run on the agent end.
Args:
finish (AgentFinish): The agent finish.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
tags (Optional[List[str]]): The tags.
kwargs (Any): Additional keyword arguments.
"""
async def on_retriever_start( async def on_retriever_start(
self, self,
@ -514,7 +750,17 @@ class AsyncCallbackHandler(BaseCallbackHandler):
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Run on retriever start.""" """Run on the retriever start.
Args:
serialized (Dict[str, Any]): The serialized retriever.
query (str): The query.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
tags (Optional[List[str]]): The tags.
metadata (Optional[Dict[str, Any]]): The metadata.
kwargs (Any): Additional keyword arguments.
"""
async def on_retriever_end( async def on_retriever_end(
self, self,
@ -525,7 +771,14 @@ class AsyncCallbackHandler(BaseCallbackHandler):
tags: Optional[List[str]] = None, tags: Optional[List[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Run on retriever end.""" """Run on the retriever end.
Args:
documents (Sequence[Document]): The documents retrieved.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
tags (Optional[List[str]]): The tags.
kwargs (Any): Additional keyword arguments."""
async def on_retriever_error( async def on_retriever_error(
self, self,
@ -536,14 +789,22 @@ class AsyncCallbackHandler(BaseCallbackHandler):
tags: Optional[List[str]] = None, tags: Optional[List[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Run on retriever error.""" """Run on retriever error.
Args:
error (BaseException): The error that occurred.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
tags (Optional[List[str]]): The tags.
kwargs (Any): Additional keyword arguments.
"""
T = TypeVar("T", bound="BaseCallbackManager") T = TypeVar("T", bound="BaseCallbackManager")
class BaseCallbackManager(CallbackManagerMixin): class BaseCallbackManager(CallbackManagerMixin):
"""Base callback manager that handles callbacks from LangChain.""" """Base callback manager for LangChain."""
def __init__( def __init__(
self, self,
@ -556,7 +817,18 @@ class BaseCallbackManager(CallbackManagerMixin):
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, Any]] = None,
inheritable_metadata: Optional[Dict[str, Any]] = None, inheritable_metadata: Optional[Dict[str, Any]] = None,
) -> None: ) -> None:
"""Initialize callback manager.""" """Initialize callback manager.
Args:
handlers (List[BaseCallbackHandler]): The handlers.
inheritable_handlers (Optional[List[BaseCallbackHandler]]):
The inheritable handlers. Default is None.
parent_run_id (Optional[UUID]): The parent run ID. Default is None.
tags (Optional[List[str]]): The tags. Default is None.
inheritable_tags (Optional[List[str]]): The inheritable tags.
Default is None.
metadata (Optional[Dict[str, Any]]): The metadata. Default is None.
"""
self.handlers: List[BaseCallbackHandler] = handlers self.handlers: List[BaseCallbackHandler] = handlers
self.inheritable_handlers: List[BaseCallbackHandler] = ( self.inheritable_handlers: List[BaseCallbackHandler] = (
inheritable_handlers or [] inheritable_handlers or []
@ -585,31 +857,56 @@ class BaseCallbackManager(CallbackManagerMixin):
return False return False
def add_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None: def add_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None:
"""Add a handler to the callback manager.""" """Add a handler to the callback manager.
Args:
handler (BaseCallbackHandler): The handler to add.
inherit (bool): Whether to inherit the handler. Default is True.
"""
if handler not in self.handlers: if handler not in self.handlers:
self.handlers.append(handler) self.handlers.append(handler)
if inherit and handler not in self.inheritable_handlers: if inherit and handler not in self.inheritable_handlers:
self.inheritable_handlers.append(handler) self.inheritable_handlers.append(handler)
def remove_handler(self, handler: BaseCallbackHandler) -> None: def remove_handler(self, handler: BaseCallbackHandler) -> None:
"""Remove a handler from the callback manager.""" """Remove a handler from the callback manager.
Args:
handler (BaseCallbackHandler): The handler to remove.
"""
self.handlers.remove(handler) self.handlers.remove(handler)
self.inheritable_handlers.remove(handler) self.inheritable_handlers.remove(handler)
def set_handlers( def set_handlers(
self, handlers: List[BaseCallbackHandler], inherit: bool = True self, handlers: List[BaseCallbackHandler], inherit: bool = True
) -> None: ) -> None:
"""Set handlers as the only handlers on the callback manager.""" """Set handlers as the only handlers on the callback manager.
Args:
handlers (List[BaseCallbackHandler]): The handlers to set.
inherit (bool): Whether to inherit the handlers. Default is True.
"""
self.handlers = [] self.handlers = []
self.inheritable_handlers = [] self.inheritable_handlers = []
for handler in handlers: for handler in handlers:
self.add_handler(handler, inherit=inherit) self.add_handler(handler, inherit=inherit)
def set_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None: def set_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None:
"""Set handler as the only handler on the callback manager.""" """Set handler as the only handler on the callback manager.
Args:
handler (BaseCallbackHandler): The handler to set.
inherit (bool): Whether to inherit the handler. Default is True.
"""
self.set_handlers([handler], inherit=inherit) self.set_handlers([handler], inherit=inherit)
def add_tags(self, tags: List[str], inherit: bool = True) -> None: def add_tags(self, tags: List[str], inherit: bool = True) -> None:
"""Add tags to the callback manager.
Args:
tags (List[str]): The tags to add.
inherit (bool): Whether to inherit the tags. Default is True.
"""
for tag in tags: for tag in tags:
if tag in self.tags: if tag in self.tags:
self.remove_tags([tag]) self.remove_tags([tag])
@ -618,16 +915,32 @@ class BaseCallbackManager(CallbackManagerMixin):
self.inheritable_tags.extend(tags) self.inheritable_tags.extend(tags)
def remove_tags(self, tags: List[str]) -> None: def remove_tags(self, tags: List[str]) -> None:
"""Remove tags from the callback manager.
Args:
tags (List[str]): The tags to remove.
"""
for tag in tags: for tag in tags:
self.tags.remove(tag) self.tags.remove(tag)
self.inheritable_tags.remove(tag) self.inheritable_tags.remove(tag)
def add_metadata(self, metadata: Dict[str, Any], inherit: bool = True) -> None: def add_metadata(self, metadata: Dict[str, Any], inherit: bool = True) -> None:
"""Add metadata to the callback manager.
Args:
metadata (Dict[str, Any]): The metadata to add.
inherit (bool): Whether to inherit the metadata. Default is True.
"""
self.metadata.update(metadata) self.metadata.update(metadata)
if inherit: if inherit:
self.inheritable_metadata.update(metadata) self.inheritable_metadata.update(metadata)
def remove_metadata(self, keys: List[str]) -> None: def remove_metadata(self, keys: List[str]) -> None:
"""Remove metadata from the callback manager.
Args:
keys (List[str]): The keys to remove.
"""
for key in keys: for key in keys:
self.metadata.pop(key) self.metadata.pop(key)
self.inheritable_metadata.pop(key) self.inheritable_metadata.pop(key)

View File

@ -10,12 +10,23 @@ from langchain_core.utils.input import print_text
class FileCallbackHandler(BaseCallbackHandler): class FileCallbackHandler(BaseCallbackHandler):
"""Callback Handler that writes to a file.""" """Callback Handler that writes to a file.
Parameters:
file: The file to write to.
color: The color to use for the text.
"""
def __init__( def __init__(
self, filename: str, mode: str = "a", color: Optional[str] = None self, filename: str, mode: str = "a", color: Optional[str] = None
) -> None: ) -> None:
"""Initialize callback handler.""" """Initialize callback handler.
Args:
filename: The filename to write to.
mode: The mode to open the file in. Defaults to "a".
color: The color to use for the text. Defaults to None.
"""
self.file = cast(TextIO, open(filename, mode, encoding="utf-8")) self.file = cast(TextIO, open(filename, mode, encoding="utf-8"))
self.color = color self.color = color
@ -26,7 +37,13 @@ class FileCallbackHandler(BaseCallbackHandler):
def on_chain_start( def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None: ) -> None:
"""Print out that we are entering a chain.""" """Print out that we are entering a chain.
Args:
serialized (Dict[str, Any]): The serialized chain.
inputs (Dict[str, Any]): The inputs to the chain.
**kwargs (Any): Additional keyword arguments.
"""
class_name = serialized.get("name", serialized.get("id", ["<unknown>"])[-1]) class_name = serialized.get("name", serialized.get("id", ["<unknown>"])[-1])
print_text( print_text(
f"\n\n\033[1m> Entering new {class_name} chain...\033[0m", f"\n\n\033[1m> Entering new {class_name} chain...\033[0m",
@ -35,13 +52,25 @@ class FileCallbackHandler(BaseCallbackHandler):
) )
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Print out that we finished a chain.""" """Print out that we finished a chain.
Args:
outputs (Dict[str, Any]): The outputs of the chain.
**kwargs (Any): Additional keyword arguments.
"""
print_text("\n\033[1m> Finished chain.\033[0m", end="\n", file=self.file) print_text("\n\033[1m> Finished chain.\033[0m", end="\n", file=self.file)
def on_agent_action( def on_agent_action(
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
) -> Any: ) -> Any:
"""Run on agent action.""" """Run on agent action.
Args:
action (AgentAction): The agent action.
color (Optional[str], optional): The color to use for the text.
Defaults to None.
**kwargs (Any): Additional keyword arguments.
"""
print_text(action.log, color=color or self.color, file=self.file) print_text(action.log, color=color or self.color, file=self.file)
def on_tool_end( def on_tool_end(
@ -52,7 +81,18 @@ class FileCallbackHandler(BaseCallbackHandler):
llm_prefix: Optional[str] = None, llm_prefix: Optional[str] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""If not the final action, print out observation.""" """If not the final action, print out observation.
Args:
output (str): The output to print.
color (Optional[str], optional): The color to use for the text.
Defaults to None.
observation_prefix (Optional[str], optional): The observation prefix.
Defaults to None.
llm_prefix (Optional[str], optional): The LLM prefix.
Defaults to None.
**kwargs (Any): Additional keyword arguments.
"""
if observation_prefix is not None: if observation_prefix is not None:
print_text(f"\n{observation_prefix}", file=self.file) print_text(f"\n{observation_prefix}", file=self.file)
print_text(output, color=color or self.color, file=self.file) print_text(output, color=color or self.color, file=self.file)
@ -62,11 +102,26 @@ class FileCallbackHandler(BaseCallbackHandler):
def on_text( def on_text(
self, text: str, color: Optional[str] = None, end: str = "", **kwargs: Any self, text: str, color: Optional[str] = None, end: str = "", **kwargs: Any
) -> None: ) -> None:
"""Run when agent ends.""" """Run when the agent ends.
Args:
text (str): The text to print.
color (Optional[str], optional): The color to use for the text.
Defaults to None.
end (str, optional): The end character. Defaults to "".
**kwargs (Any): Additional keyword arguments.
"""
print_text(text, color=color or self.color, end=end, file=self.file) print_text(text, color=color or self.color, end=end, file=self.file)
def on_agent_finish( def on_agent_finish(
self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
) -> None: ) -> None:
"""Run on agent end.""" """Run on the agent end.
Args:
finish (AgentFinish): The agent finish.
color (Optional[str], optional): The color to use for the text.
Defaults to None.
**kwargs (Any): Additional keyword arguments.
"""
print_text(finish.log, color=color or self.color, end="\n", file=self.file) print_text(finish.log, color=color or self.color, end="\n", file=self.file)

View File

@ -77,7 +77,9 @@ def trace_as_chain_group(
Args: Args:
group_name (str): The name of the chain group. group_name (str): The name of the chain group.
callback_manager (CallbackManager, optional): The callback manager to use. callback_manager (CallbackManager, optional): The callback manager to use.
Defaults to None.
inputs (Dict[str, Any], optional): The inputs to the chain group. inputs (Dict[str, Any], optional): The inputs to the chain group.
Defaults to None.
project_name (str, optional): The name of the project. project_name (str, optional): The name of the project.
Defaults to None. Defaults to None.
example_id (str or UUID, optional): The ID of the example. example_id (str or UUID, optional): The ID of the example.
@ -155,7 +157,9 @@ async def atrace_as_chain_group(
Args: Args:
group_name (str): The name of the chain group. group_name (str): The name of the chain group.
callback_manager (AsyncCallbackManager, optional): The async callback manager to use, callback_manager (AsyncCallbackManager, optional): The async callback manager to use,
which manages tracing and other callback behavior. which manages tracing and other callback behavior. Defaults to None.
inputs (Dict[str, Any], optional): The inputs to the chain group.
Defaults to None.
project_name (str, optional): The name of the project. project_name (str, optional): The name of the project.
Defaults to None. Defaults to None.
example_id (str or UUID, optional): The ID of the example. example_id (str or UUID, optional): The ID of the example.
@ -218,7 +222,13 @@ Func = TypeVar("Func", bound=Callable)
def shielded(func: Func) -> Func: def shielded(func: Func) -> Func:
""" """
Makes so an awaitable method is always shielded from cancellation Makes so an awaitable method is always shielded from cancellation.
Args:
func (Callable): The function to shield.
Returns:
Callable: The shielded function
""" """
@functools.wraps(func) @functools.wraps(func)
@ -237,14 +247,14 @@ def handle_event(
) -> None: ) -> None:
"""Generic event handler for CallbackManager. """Generic event handler for CallbackManager.
Note: This function is used by langserve to handle events. Note: This function is used by LangServe to handle events.
Args: Args:
handlers: The list of handlers that will handle the event handlers: The list of handlers that will handle the event.
event_name: The name of the event (e.g., "on_llm_start") event_name: The name of the event (e.g., "on_llm_start").
ignore_condition_name: Name of the attribute defined on handler ignore_condition_name: Name of the attribute defined on handler
that if True will cause the handler to be skipped for the given event that if True will cause the handler to be skipped for the given event.
*args: The arguments to pass to the event handler *args: The arguments to pass to the event handler.
**kwargs: The keyword arguments to pass to the event handler **kwargs: The keyword arguments to pass to the event handler
""" """
coros: List[Coroutine[Any, Any, Any]] = [] coros: List[Coroutine[Any, Any, Any]] = []
@ -394,17 +404,17 @@ async def ahandle_event(
*args: Any, *args: Any,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Generic event handler for AsyncCallbackManager. """Async generic event handler for AsyncCallbackManager.
Note: This function is used by langserve to handle events. Note: This function is used by LangServe to handle events.
Args: Args:
handlers: The list of handlers that will handle the event handlers: The list of handlers that will handle the event.
event_name: The name of the event (e.g., "on_llm_start") event_name: The name of the event (e.g., "on_llm_start").
ignore_condition_name: Name of the attribute defined on handler ignore_condition_name: Name of the attribute defined on handler
that if True will cause the handler to be skipped for the given event that if True will cause the handler to be skipped for the given event.
*args: The arguments to pass to the event handler *args: The arguments to pass to the event handler.
**kwargs: The keyword arguments to pass to the event handler **kwargs: The keyword arguments to pass to the event handler.
""" """
for handler in [h for h in handlers if h.run_inline]: for handler in [h for h in handlers if h.run_inline]:
await _ahandle_event_for_handler( await _ahandle_event_for_handler(
@ -452,10 +462,13 @@ class BaseRunManager(RunManagerMixin):
The list of inheritable handlers. The list of inheritable handlers.
parent_run_id (UUID, optional): The ID of the parent run. parent_run_id (UUID, optional): The ID of the parent run.
Defaults to None. Defaults to None.
tags (Optional[List[str]]): The list of tags. tags (Optional[List[str]]): The list of tags. Defaults to None.
inheritable_tags (Optional[List[str]]): The list of inheritable tags. inheritable_tags (Optional[List[str]]): The list of inheritable tags.
Defaults to None.
metadata (Optional[Dict[str, Any]]): The metadata. metadata (Optional[Dict[str, Any]]): The metadata.
Defaults to None.
inheritable_metadata (Optional[Dict[str, Any]]): The inheritable metadata. inheritable_metadata (Optional[Dict[str, Any]]): The inheritable metadata.
Defaults to None.
""" """
self.run_id = run_id self.run_id = run_id
self.handlers = handlers self.handlers = handlers
@ -492,10 +505,11 @@ class RunManager(BaseRunManager):
text: str, text: str,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
"""Run when text is received. """Run when a text is received.
Args: Args:
text (str): The received text. text (str): The received text.
**kwargs (Any): Additional keyword arguments.
Returns: Returns:
Any: The result of the callback. Any: The result of the callback.
@ -516,6 +530,12 @@ class RunManager(BaseRunManager):
retry_state: RetryCallState, retry_state: RetryCallState,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Run when a retry is received.
Args:
retry_state (RetryCallState): The retry state.
**kwargs (Any): Additional keyword arguments.
"""
handle_event( handle_event(
self.handlers, self.handlers,
"on_retry", "on_retry",
@ -566,10 +586,11 @@ class AsyncRunManager(BaseRunManager, ABC):
text: str, text: str,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
"""Run when text is received. """Run when a text is received.
Args: Args:
text (str): The received text. text (str): The received text.
**kwargs (Any): Additional keyword arguments.
Returns: Returns:
Any: The result of the callback. Any: The result of the callback.
@ -590,6 +611,12 @@ class AsyncRunManager(BaseRunManager, ABC):
retry_state: RetryCallState, retry_state: RetryCallState,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Async run when a retry is received.
Args:
retry_state (RetryCallState): The retry state.
**kwargs (Any): Additional keyword arguments.
"""
await ahandle_event( await ahandle_event(
self.handlers, self.handlers,
"on_retry", "on_retry",
@ -638,6 +665,9 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
Args: Args:
token (str): The new token. token (str): The new token.
chunk (Optional[Union[GenerationChunk, ChatGenerationChunk]], optional):
The chunk. Defaults to None.
**kwargs (Any): Additional keyword arguments.
""" """
handle_event( handle_event(
self.handlers, self.handlers,
@ -656,6 +686,7 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
Args: Args:
response (LLMResult): The LLM result. response (LLMResult): The LLM result.
**kwargs (Any): Additional keyword arguments.
""" """
handle_event( handle_event(
self.handlers, self.handlers,
@ -725,6 +756,9 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
Args: Args:
token (str): The new token. token (str): The new token.
chunk (Optional[Union[GenerationChunk, ChatGenerationChunk]], optional):
The chunk. Defaults to None.
**kwargs (Any): Additional keyword arguments.
""" """
await ahandle_event( await ahandle_event(
self.handlers, self.handlers,
@ -744,6 +778,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
Args: Args:
response (LLMResult): The LLM result. response (LLMResult): The LLM result.
**kwargs (Any): Additional keyword arguments.
""" """
await ahandle_event( await ahandle_event(
self.handlers, self.handlers,
@ -793,6 +828,7 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
Args: Args:
outputs (Union[Dict[str, Any], Any]): The outputs of the chain. outputs (Union[Dict[str, Any], Any]): The outputs of the chain.
**kwargs (Any): Additional keyword arguments.
""" """
handle_event( handle_event(
self.handlers, self.handlers,
@ -814,6 +850,7 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
Args: Args:
error (Exception or KeyboardInterrupt): The error. error (Exception or KeyboardInterrupt): The error.
**kwargs (Any): Additional keyword arguments.
""" """
handle_event( handle_event(
self.handlers, self.handlers,
@ -831,6 +868,7 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
Args: Args:
action (AgentAction): The agent action. action (AgentAction): The agent action.
**kwargs (Any): Additional keyword arguments.
Returns: Returns:
Any: The result of the callback. Any: The result of the callback.
@ -851,6 +889,7 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
Args: Args:
finish (AgentFinish): The agent finish. finish (AgentFinish): The agent finish.
**kwargs (Any): Additional keyword arguments.
Returns: Returns:
Any: The result of the callback. Any: The result of the callback.
@ -891,10 +930,11 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
async def on_chain_end( async def on_chain_end(
self, outputs: Union[Dict[str, Any], Any], **kwargs: Any self, outputs: Union[Dict[str, Any], Any], **kwargs: Any
) -> None: ) -> None:
"""Run when chain ends running. """Run when a chain ends running.
Args: Args:
outputs (Union[Dict[str, Any], Any]): The outputs of the chain. outputs (Union[Dict[str, Any], Any]): The outputs of the chain.
**kwargs (Any): Additional keyword arguments.
""" """
await ahandle_event( await ahandle_event(
self.handlers, self.handlers,
@ -917,6 +957,7 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
Args: Args:
error (Exception or KeyboardInterrupt): The error. error (Exception or KeyboardInterrupt): The error.
**kwargs (Any): Additional keyword arguments.
""" """
await ahandle_event( await ahandle_event(
self.handlers, self.handlers,
@ -935,6 +976,7 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
Args: Args:
action (AgentAction): The agent action. action (AgentAction): The agent action.
**kwargs (Any): Additional keyword arguments.
Returns: Returns:
Any: The result of the callback. Any: The result of the callback.
@ -956,6 +998,7 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
Args: Args:
finish (AgentFinish): The agent finish. finish (AgentFinish): The agent finish.
**kwargs (Any): Additional keyword arguments.
Returns: Returns:
Any: The result of the callback. Any: The result of the callback.
@ -980,10 +1023,11 @@ class CallbackManagerForToolRun(ParentRunManager, ToolManagerMixin):
output: Any, output: Any,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Run when tool ends running. """Run when the tool ends running.
Args: Args:
output (Any): The output of the tool. output (Any): The output of the tool.
**kwargs (Any): Additional keyword arguments.
""" """
handle_event( handle_event(
self.handlers, self.handlers,
@ -1005,6 +1049,7 @@ class CallbackManagerForToolRun(ParentRunManager, ToolManagerMixin):
Args: Args:
error (Exception or KeyboardInterrupt): The error. error (Exception or KeyboardInterrupt): The error.
**kwargs (Any): Additional keyword arguments.
""" """
handle_event( handle_event(
self.handlers, self.handlers,
@ -1040,10 +1085,11 @@ class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin):
@shielded @shielded
async def on_tool_end(self, output: Any, **kwargs: Any) -> None: async def on_tool_end(self, output: Any, **kwargs: Any) -> None:
"""Run when tool ends running. """Async run when the tool ends running.
Args: Args:
output (Any): The output of the tool. output (Any): The output of the tool.
**kwargs (Any): Additional keyword arguments.
""" """
await ahandle_event( await ahandle_event(
self.handlers, self.handlers,
@ -1066,6 +1112,7 @@ class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin):
Args: Args:
error (Exception or KeyboardInterrupt): The error. error (Exception or KeyboardInterrupt): The error.
**kwargs (Any): Additional keyword arguments.
""" """
await ahandle_event( await ahandle_event(
self.handlers, self.handlers,
@ -1087,7 +1134,12 @@ class CallbackManagerForRetrieverRun(ParentRunManager, RetrieverManagerMixin):
documents: Sequence[Document], documents: Sequence[Document],
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Run when retriever ends running.""" """Run when retriever ends running.
Args:
documents (Sequence[Document]): The retrieved documents.
**kwargs (Any): Additional keyword arguments.
"""
handle_event( handle_event(
self.handlers, self.handlers,
"on_retriever_end", "on_retriever_end",
@ -1104,7 +1156,12 @@ class CallbackManagerForRetrieverRun(ParentRunManager, RetrieverManagerMixin):
error: BaseException, error: BaseException,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Run when retriever errors.""" """Run when retriever errors.
Args:
error (BaseException): The error.
**kwargs (Any): Additional keyword arguments.
"""
handle_event( handle_event(
self.handlers, self.handlers,
"on_retriever_error", "on_retriever_error",
@ -1144,7 +1201,12 @@ class AsyncCallbackManagerForRetrieverRun(
async def on_retriever_end( async def on_retriever_end(
self, documents: Sequence[Document], **kwargs: Any self, documents: Sequence[Document], **kwargs: Any
) -> None: ) -> None:
"""Run when retriever ends running.""" """Run when the retriever ends running.
Args:
documents (Sequence[Document]): The retrieved documents.
**kwargs (Any): Additional keyword arguments.
"""
await ahandle_event( await ahandle_event(
self.handlers, self.handlers,
"on_retriever_end", "on_retriever_end",
@ -1162,7 +1224,12 @@ class AsyncCallbackManagerForRetrieverRun(
error: BaseException, error: BaseException,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Run when retriever errors.""" """Run when retriever errors.
Args:
error (BaseException): The error.
**kwargs (Any): Additional keyword arguments.
"""
await ahandle_event( await ahandle_event(
self.handlers, self.handlers,
"on_retriever_error", "on_retriever_error",
@ -1176,7 +1243,7 @@ class AsyncCallbackManagerForRetrieverRun(
class CallbackManager(BaseCallbackManager): class CallbackManager(BaseCallbackManager):
"""Callback manager that handles callbacks from LangChain.""" """Callback manager for LangChain."""
def on_llm_start( def on_llm_start(
self, self,
@ -1191,6 +1258,7 @@ class CallbackManager(BaseCallbackManager):
serialized (Dict[str, Any]): The serialized LLM. serialized (Dict[str, Any]): The serialized LLM.
prompts (List[str]): The list of prompts. prompts (List[str]): The list of prompts.
run_id (UUID, optional): The ID of the run. Defaults to None. run_id (UUID, optional): The ID of the run. Defaults to None.
**kwargs (Any): Additional keyword arguments.
Returns: Returns:
List[CallbackManagerForLLMRun]: A callback manager for each List[CallbackManagerForLLMRun]: A callback manager for each
@ -1241,6 +1309,7 @@ class CallbackManager(BaseCallbackManager):
serialized (Dict[str, Any]): The serialized LLM. serialized (Dict[str, Any]): The serialized LLM.
messages (List[List[BaseMessage]]): The list of messages. messages (List[List[BaseMessage]]): The list of messages.
run_id (UUID, optional): The ID of the run. Defaults to None. run_id (UUID, optional): The ID of the run. Defaults to None.
**kwargs (Any): Additional keyword arguments.
Returns: Returns:
List[CallbackManagerForLLMRun]: A callback manager for each List[CallbackManagerForLLMRun]: A callback manager for each
@ -1295,6 +1364,7 @@ class CallbackManager(BaseCallbackManager):
serialized (Dict[str, Any]): The serialized chain. serialized (Dict[str, Any]): The serialized chain.
inputs (Union[Dict[str, Any], Any]): The inputs to the chain. inputs (Union[Dict[str, Any], Any]): The inputs to the chain.
run_id (UUID, optional): The ID of the run. Defaults to None. run_id (UUID, optional): The ID of the run. Defaults to None.
**kwargs (Any): Additional keyword arguments.
Returns: Returns:
CallbackManagerForChainRun: The callback manager for the chain run. CallbackManagerForChainRun: The callback manager for the chain run.
@ -1347,6 +1417,7 @@ class CallbackManager(BaseCallbackManager):
input is needed. input is needed.
If provided, the inputs are expected to be formatted as a dict. If provided, the inputs are expected to be formatted as a dict.
The keys will correspond to the named-arguments in the tool. The keys will correspond to the named-arguments in the tool.
**kwargs (Any): Additional keyword arguments.
Returns: Returns:
CallbackManagerForToolRun: The callback manager for the tool run. CallbackManagerForToolRun: The callback manager for the tool run.
@ -1387,7 +1458,15 @@ class CallbackManager(BaseCallbackManager):
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
**kwargs: Any, **kwargs: Any,
) -> CallbackManagerForRetrieverRun: ) -> CallbackManagerForRetrieverRun:
"""Run when retriever starts running.""" """Run when the retriever starts running.
Args:
serialized (Dict[str, Any]): The serialized retriever.
query (str): The query.
run_id (UUID, optional): The ID of the run. Defaults to None.
parent_run_id (UUID, optional): The ID of the parent run. Defaults to None.
**kwargs (Any): Additional keyword arguments.
"""
if run_id is None: if run_id is None:
run_id = uuid.uuid4() run_id = uuid.uuid4()
@ -1470,6 +1549,16 @@ class CallbackManagerForChainGroup(CallbackManager):
parent_run_manager: CallbackManagerForChainRun, parent_run_manager: CallbackManagerForChainRun,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Initialize the callback manager.
Args:
handlers (List[BaseCallbackHandler]): The list of handlers.
inheritable_handlers (Optional[List[BaseCallbackHandler]]): The list of
inheritable handlers. Defaults to None.
parent_run_id (Optional[UUID]): The ID of the parent run. Defaults to None.
parent_run_manager (CallbackManagerForChainRun): The parent run manager.
**kwargs (Any): Additional keyword arguments.
"""
super().__init__( super().__init__(
handlers, handlers,
inheritable_handlers, inheritable_handlers,
@ -1480,6 +1569,7 @@ class CallbackManagerForChainGroup(CallbackManager):
self.ended = False self.ended = False
def copy(self) -> CallbackManagerForChainGroup: def copy(self) -> CallbackManagerForChainGroup:
"""Copy the callback manager."""
return self.__class__( return self.__class__(
handlers=self.handlers, handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers, inheritable_handlers=self.inheritable_handlers,
@ -1496,6 +1586,7 @@ class CallbackManagerForChainGroup(CallbackManager):
Args: Args:
outputs (Union[Dict[str, Any], Any]): The outputs of the chain. outputs (Union[Dict[str, Any], Any]): The outputs of the chain.
**kwargs (Any): Additional keyword arguments.
""" """
self.ended = True self.ended = True
return self.parent_run_manager.on_chain_end(outputs, **kwargs) return self.parent_run_manager.on_chain_end(outputs, **kwargs)
@ -1509,6 +1600,7 @@ class CallbackManagerForChainGroup(CallbackManager):
Args: Args:
error (Exception or KeyboardInterrupt): The error. error (Exception or KeyboardInterrupt): The error.
**kwargs (Any): Additional keyword arguments.
""" """
self.ended = True self.ended = True
return self.parent_run_manager.on_chain_error(error, **kwargs) return self.parent_run_manager.on_chain_error(error, **kwargs)
@ -1535,6 +1627,7 @@ class AsyncCallbackManager(BaseCallbackManager):
serialized (Dict[str, Any]): The serialized LLM. serialized (Dict[str, Any]): The serialized LLM.
prompts (List[str]): The list of prompts. prompts (List[str]): The list of prompts.
run_id (UUID, optional): The ID of the run. Defaults to None. run_id (UUID, optional): The ID of the run. Defaults to None.
**kwargs (Any): Additional keyword arguments.
Returns: Returns:
List[AsyncCallbackManagerForLLMRun]: The list of async List[AsyncCallbackManagerForLLMRun]: The list of async
@ -1591,12 +1684,13 @@ class AsyncCallbackManager(BaseCallbackManager):
run_id: Optional[UUID] = None, run_id: Optional[UUID] = None,
**kwargs: Any, **kwargs: Any,
) -> List[AsyncCallbackManagerForLLMRun]: ) -> List[AsyncCallbackManagerForLLMRun]:
"""Run when LLM starts running. """Async run when LLM starts running.
Args: Args:
serialized (Dict[str, Any]): The serialized LLM. serialized (Dict[str, Any]): The serialized LLM.
messages (List[List[BaseMessage]]): The list of messages. messages (List[List[BaseMessage]]): The list of messages.
run_id (UUID, optional): The ID of the run. Defaults to None. run_id (UUID, optional): The ID of the run. Defaults to None.
**kwargs (Any): Additional keyword arguments.
Returns: Returns:
List[AsyncCallbackManagerForLLMRun]: The list of List[AsyncCallbackManagerForLLMRun]: The list of
@ -1651,12 +1745,13 @@ class AsyncCallbackManager(BaseCallbackManager):
run_id: Optional[UUID] = None, run_id: Optional[UUID] = None,
**kwargs: Any, **kwargs: Any,
) -> AsyncCallbackManagerForChainRun: ) -> AsyncCallbackManagerForChainRun:
"""Run when chain starts running. """Async run when chain starts running.
Args: Args:
serialized (Dict[str, Any]): The serialized chain. serialized (Dict[str, Any]): The serialized chain.
inputs (Union[Dict[str, Any], Any]): The inputs to the chain. inputs (Union[Dict[str, Any], Any]): The inputs to the chain.
run_id (UUID, optional): The ID of the run. Defaults to None. run_id (UUID, optional): The ID of the run. Defaults to None.
**kwargs (Any): Additional keyword arguments.
Returns: Returns:
AsyncCallbackManagerForChainRun: The async callback manager AsyncCallbackManagerForChainRun: The async callback manager
@ -1697,7 +1792,7 @@ class AsyncCallbackManager(BaseCallbackManager):
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
**kwargs: Any, **kwargs: Any,
) -> AsyncCallbackManagerForToolRun: ) -> AsyncCallbackManagerForToolRun:
"""Run when tool starts running. """Run when the tool starts running.
Args: Args:
serialized (Dict[str, Any]): The serialized tool. serialized (Dict[str, Any]): The serialized tool.
@ -1705,6 +1800,7 @@ class AsyncCallbackManager(BaseCallbackManager):
run_id (UUID, optional): The ID of the run. Defaults to None. run_id (UUID, optional): The ID of the run. Defaults to None.
parent_run_id (UUID, optional): The ID of the parent run. parent_run_id (UUID, optional): The ID of the parent run.
Defaults to None. Defaults to None.
**kwargs (Any): Additional keyword arguments.
Returns: Returns:
AsyncCallbackManagerForToolRun: The async callback manager AsyncCallbackManagerForToolRun: The async callback manager
@ -1745,7 +1841,19 @@ class AsyncCallbackManager(BaseCallbackManager):
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
**kwargs: Any, **kwargs: Any,
) -> AsyncCallbackManagerForRetrieverRun: ) -> AsyncCallbackManagerForRetrieverRun:
"""Run when retriever starts running.""" """Run when the retriever starts running.
Args:
serialized (Dict[str, Any]): The serialized retriever.
query (str): The query.
run_id (UUID, optional): The ID of the run. Defaults to None.
parent_run_id (UUID, optional): The ID of the parent run. Defaults to None.
**kwargs (Any): Additional keyword arguments.
Returns:
AsyncCallbackManagerForRetrieverRun: The async callback manager
for the retriever run.
"""
if run_id is None: if run_id is None:
run_id = uuid.uuid4() run_id = uuid.uuid4()
@ -1828,6 +1936,17 @@ class AsyncCallbackManagerForChainGroup(AsyncCallbackManager):
parent_run_manager: AsyncCallbackManagerForChainRun, parent_run_manager: AsyncCallbackManagerForChainRun,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Initialize the async callback manager.
Args:
handlers (List[BaseCallbackHandler]): The list of handlers.
inheritable_handlers (Optional[List[BaseCallbackHandler]]): The list of
inheritable handlers. Defaults to None.
parent_run_id (Optional[UUID]): The ID of the parent run. Defaults to None.
parent_run_manager (AsyncCallbackManagerForChainRun):
The parent run manager.
**kwargs (Any): Additional keyword arguments.
"""
super().__init__( super().__init__(
handlers, handlers,
inheritable_handlers, inheritable_handlers,
@ -1838,6 +1957,7 @@ class AsyncCallbackManagerForChainGroup(AsyncCallbackManager):
self.ended = False self.ended = False
def copy(self) -> AsyncCallbackManagerForChainGroup: def copy(self) -> AsyncCallbackManagerForChainGroup:
"""Copy the async callback manager."""
return self.__class__( return self.__class__(
handlers=self.handlers, handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers, inheritable_handlers=self.inheritable_handlers,
@ -1856,6 +1976,7 @@ class AsyncCallbackManagerForChainGroup(AsyncCallbackManager):
Args: Args:
outputs (Union[Dict[str, Any], Any]): The outputs of the chain. outputs (Union[Dict[str, Any], Any]): The outputs of the chain.
**kwargs (Any): Additional keyword arguments.
""" """
self.ended = True self.ended = True
await self.parent_run_manager.on_chain_end(outputs, **kwargs) await self.parent_run_manager.on_chain_end(outputs, **kwargs)
@ -1869,6 +1990,7 @@ class AsyncCallbackManagerForChainGroup(AsyncCallbackManager):
Args: Args:
error (Exception or KeyboardInterrupt): The error. error (Exception or KeyboardInterrupt): The error.
**kwargs (Any): Additional keyword arguments.
""" """
self.ended = True self.ended = True
await self.parent_run_manager.on_chain_error(error, **kwargs) await self.parent_run_manager.on_chain_error(error, **kwargs)

View File

@ -15,24 +15,45 @@ class StdOutCallbackHandler(BaseCallbackHandler):
"""Callback Handler that prints to std out.""" """Callback Handler that prints to std out."""
def __init__(self, color: Optional[str] = None) -> None: def __init__(self, color: Optional[str] = None) -> None:
"""Initialize callback handler.""" """Initialize callback handler.
Args:
color: The color to use for the text. Defaults to None.
"""
self.color = color self.color = color
def on_chain_start( def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None: ) -> None:
"""Print out that we are entering a chain.""" """Print out that we are entering a chain.
Args:
serialized (Dict[str, Any]): The serialized chain.
inputs (Dict[str, Any]): The inputs to the chain.
**kwargs (Any): Additional keyword arguments.
"""
class_name = serialized.get("name", serialized.get("id", ["<unknown>"])[-1]) class_name = serialized.get("name", serialized.get("id", ["<unknown>"])[-1])
print(f"\n\n\033[1m> Entering new {class_name} chain...\033[0m") # noqa: T201 print(f"\n\n\033[1m> Entering new {class_name} chain...\033[0m") # noqa: T201
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Print out that we finished a chain.""" """Print out that we finished a chain.
Args:
outputs (Dict[str, Any]): The outputs of the chain.
**kwargs (Any): Additional keyword arguments.
"""
print("\n\033[1m> Finished chain.\033[0m") # noqa: T201 print("\n\033[1m> Finished chain.\033[0m") # noqa: T201
def on_agent_action( def on_agent_action(
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
) -> Any: ) -> Any:
"""Run on agent action.""" """Run on agent action.
Args:
action (AgentAction): The agent action.
color (Optional[str]): The color to use for the text. Defaults to None.
**kwargs (Any): Additional keyword arguments.
"""
print_text(action.log, color=color or self.color) print_text(action.log, color=color or self.color)
def on_tool_end( def on_tool_end(
@ -43,7 +64,16 @@ class StdOutCallbackHandler(BaseCallbackHandler):
llm_prefix: Optional[str] = None, llm_prefix: Optional[str] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""If not the final action, print out observation.""" """If not the final action, print out observation.
Args:
output (Any): The output to print.
color (Optional[str]): The color to use for the text. Defaults to None.
observation_prefix (Optional[str]): The observation prefix.
Defaults to None.
llm_prefix (Optional[str]): The LLM prefix. Defaults to None.
**kwargs (Any): Additional keyword arguments.
"""
output = str(output) output = str(output)
if observation_prefix is not None: if observation_prefix is not None:
print_text(f"\n{observation_prefix}") print_text(f"\n{observation_prefix}")
@ -58,11 +88,24 @@ class StdOutCallbackHandler(BaseCallbackHandler):
end: str = "", end: str = "",
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Run when agent ends.""" """Run when the agent ends.
Args:
text (str): The text to print.
color (Optional[str]): The color to use for the text. Defaults to None.
end (str): The end character to use. Defaults to "".
**kwargs (Any): Additional keyword arguments.
"""
print_text(text, color=color or self.color, end=end) print_text(text, color=color or self.color, end=end)
def on_agent_finish( def on_agent_finish(
self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
) -> None: ) -> None:
"""Run on agent end.""" """Run on the agent end.
Args:
finish (AgentFinish): The agent finish.
color (Optional[str]): The color to use for the text. Defaults to None.
**kwargs (Any): Additional keyword arguments.
"""
print_text(finish.log, color=color or self.color, end="\n") print_text(finish.log, color=color or self.color, end="\n")

View File

@ -1,4 +1,5 @@
"""Callback Handler streams to stdout on new llm token.""" """Callback Handler streams to stdout on new llm token."""
from __future__ import annotations from __future__ import annotations
import sys import sys
@ -18,7 +19,13 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
def on_llm_start( def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None: ) -> None:
"""Run when LLM starts running.""" """Run when LLM starts running.
Args:
serialized (Dict[str, Any]): The serialized LLM.
prompts (List[str]): The prompts to run.
**kwargs (Any): Additional keyword arguments.
"""
def on_chat_model_start( def on_chat_model_start(
self, self,
@ -26,47 +33,115 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
messages: List[List[BaseMessage]], messages: List[List[BaseMessage]],
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Run when LLM starts running.""" """Run when LLM starts running.
Args:
serialized (Dict[str, Any]): The serialized LLM.
messages (List[List[BaseMessage]]): The messages to run.
**kwargs (Any): Additional keyword arguments.
"""
def on_llm_new_token(self, token: str, **kwargs: Any) -> None: def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run on new LLM token. Only available when streaming is enabled.""" """Run on new LLM token. Only available when streaming is enabled.
Args:
token (str): The new token.
**kwargs (Any): Additional keyword arguments.
"""
sys.stdout.write(token) sys.stdout.write(token)
sys.stdout.flush() sys.stdout.flush()
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running.""" """Run when LLM ends running.
Args:
response (LLMResult): The response from the LLM.
**kwargs (Any): Additional keyword arguments.
"""
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None: def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when LLM errors.""" """Run when LLM errors.
Args:
error (BaseException): The error that occurred.
**kwargs (Any): Additional keyword arguments.
"""
def on_chain_start( def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None: ) -> None:
"""Run when chain starts running.""" """Run when a chain starts running.
Args:
serialized (Dict[str, Any]): The serialized chain.
inputs (Dict[str, Any]): The inputs to the chain.
**kwargs (Any): Additional keyword arguments.
"""
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Run when chain ends running.""" """Run when a chain ends running.
Args:
outputs (Dict[str, Any]): The outputs of the chain.
**kwargs (Any): Additional keyword arguments.
"""
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None: def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when chain errors.""" """Run when chain errors.
Args:
error (BaseException): The error that occurred.
**kwargs (Any): Additional keyword arguments.
"""
def on_tool_start( def on_tool_start(
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None: ) -> None:
"""Run when tool starts running.""" """Run when the tool starts running.
Args:
serialized (Dict[str, Any]): The serialized tool.
input_str (str): The input string.
**kwargs (Any): Additional keyword arguments.
"""
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run on agent action.""" """Run on agent action.
Args:
action (AgentAction): The agent action.
**kwargs (Any): Additional keyword arguments.
"""
pass pass
def on_tool_end(self, output: Any, **kwargs: Any) -> None: def on_tool_end(self, output: Any, **kwargs: Any) -> None:
"""Run when tool ends running.""" """Run when tool ends running.
Args:
output (Any): The output of the tool.
**kwargs (Any): Additional keyword arguments.
"""
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None: def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when tool errors.""" """Run when tool errors.
Args:
error (BaseException): The error that occurred.
**kwargs (Any): Additional keyword arguments.
"""
def on_text(self, text: str, **kwargs: Any) -> None: def on_text(self, text: str, **kwargs: Any) -> None:
"""Run on arbitrary text.""" """Run on an arbitrary text.
Args:
text (str): The text to print.
**kwargs (Any): Additional keyword arguments.
"""
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None: def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run on agent end.""" """Run on the agent end.
Args:
finish (AgentFinish): The agent finish.
**kwargs (Any): Additional keyword arguments.
"""