From 2a5d59b3d7ccf45ee548c285d3f3ceab1ef2bc56 Mon Sep 17 00:00:00 2001 From: Leonid Ganeline Date: Wed, 26 Jun 2024 14:11:06 -0700 Subject: [PATCH] core[patch]: `callbacks` docstrings (#23375) Added missed docstrings. Formatted docstrings to the consistent form. --- .../core/langchain_core/callbacks/__init__.py | 1 + libs/core/langchain_core/callbacks/base.py | 391 ++++++++++++++++-- libs/core/langchain_core/callbacks/file.py | 71 +++- libs/core/langchain_core/callbacks/manager.py | 182 ++++++-- libs/core/langchain_core/callbacks/stdout.py | 57 ++- .../callbacks/streaming_stdout.py | 103 ++++- 6 files changed, 707 insertions(+), 98 deletions(-) diff --git a/libs/core/langchain_core/callbacks/__init__.py b/libs/core/langchain_core/callbacks/__init__.py index 65df88d69e5..43e2178a77b 100644 --- a/libs/core/langchain_core/callbacks/__init__.py +++ b/libs/core/langchain_core/callbacks/__init__.py @@ -6,6 +6,7 @@ BaseCallbackHandler --> CallbackHandler # Example: AimCallbackHandler """ + from langchain_core.callbacks.base import ( AsyncCallbackHandler, BaseCallbackHandler, diff --git a/libs/core/langchain_core/callbacks/base.py b/libs/core/langchain_core/callbacks/base.py index 900cb2fcffc..82059d2a477 100644 --- a/libs/core/langchain_core/callbacks/base.py +++ b/libs/core/langchain_core/callbacks/base.py @@ -1,4 +1,5 @@ -"""Base callback handler that can be used to handle callbacks in langchain.""" +"""Base callback handler for LangChain.""" + from __future__ import annotations from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, TypeVar, Union @@ -54,7 +55,10 @@ class LLMManagerMixin: Args: token (str): The new token. chunk (GenerationChunk | ChatGenerationChunk): The new generated chunk, - containing content and other information. + containing content and other information. + run_id (UUID): The run ID. This is the ID of the current run. + parent_run_id (UUID): The parent run ID. This is the ID of the parent run. + kwargs (Any): Additional keyword arguments. """ def on_llm_end( @@ -65,7 +69,14 @@ class LLMManagerMixin: parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> Any: - """Run when LLM ends running.""" + """Run when LLM ends running. + + Args: + response (LLMResult): The response which was generated. + run_id (UUID): The run ID. This is the ID of the current run. + parent_run_id (UUID): The parent run ID. This is the ID of the parent run. + kwargs (Any): Additional keyword arguments. + """ def on_llm_error( self, @@ -76,11 +87,12 @@ class LLMManagerMixin: **kwargs: Any, ) -> Any: """Run when LLM errors. + Args: error (BaseException): The error that occurred. + run_id (UUID): The run ID. This is the ID of the current run. + parent_run_id (UUID): The parent run ID. This is the ID of the parent run. kwargs (Any): Additional keyword arguments. - - response (LLMResult): The response which was generated before - the error occurred. """ @@ -95,7 +107,13 @@ class ChainManagerMixin: parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> Any: - """Run when chain ends running.""" + """Run when chain ends running. + + Args: + outputs (Dict[str, Any]): The outputs of the chain. + run_id (UUID): The run ID. This is the ID of the current run. + parent_run_id (UUID): The parent run ID. This is the ID of the parent run. + kwargs (Any): Additional keyword arguments.""" def on_chain_error( self, @@ -105,7 +123,13 @@ class ChainManagerMixin: parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> Any: - """Run when chain errors.""" + """Run when chain errors. + + Args: + error (BaseException): The error that occurred. + run_id (UUID): The run ID. This is the ID of the current run. + parent_run_id (UUID): The parent run ID. This is the ID of the parent run. + kwargs (Any): Additional keyword arguments.""" def on_agent_action( self, @@ -115,7 +139,13 @@ class ChainManagerMixin: parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> Any: - """Run on agent action.""" + """Run on agent action. + + Args: + action (AgentAction): The agent action. + run_id (UUID): The run ID. This is the ID of the current run. + parent_run_id (UUID): The parent run ID. This is the ID of the parent run. + kwargs (Any): Additional keyword arguments.""" def on_agent_finish( self, @@ -125,7 +155,13 @@ class ChainManagerMixin: parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> Any: - """Run on agent end.""" + """Run on the agent end. + + Args: + finish (AgentFinish): The agent finish. + run_id (UUID): The run ID. This is the ID of the current run. + parent_run_id (UUID): The parent run ID. This is the ID of the parent run. + kwargs (Any): Additional keyword arguments.""" class ToolManagerMixin: @@ -139,7 +175,13 @@ class ToolManagerMixin: parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> Any: - """Run when tool ends running.""" + """Run when the tool ends running. + + Args: + output (Any): The output of the tool. + run_id (UUID): The run ID. This is the ID of the current run. + parent_run_id (UUID): The parent run ID. This is the ID of the parent run. + kwargs (Any): Additional keyword arguments.""" def on_tool_error( self, @@ -149,7 +191,13 @@ class ToolManagerMixin: parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> Any: - """Run when tool errors.""" + """Run when tool errors. + + Args: + error (BaseException): The error that occurred. + run_id (UUID): The run ID. This is the ID of the current run. + parent_run_id (UUID): The parent run ID. This is the ID of the parent run. + kwargs (Any): Additional keyword arguments.""" class CallbackManagerMixin: @@ -171,6 +219,15 @@ class CallbackManagerMixin: **ATTENTION**: This method is called for non-chat models (regular LLMs). If you're implementing a handler for a chat model, you should use on_chat_model_start instead. + + Args: + serialized (Dict[str, Any]): The serialized LLM. + prompts (List[str]): The prompts. + run_id (UUID): The run ID. This is the ID of the current run. + parent_run_id (UUID): The parent run ID. This is the ID of the parent run. + tags (Optional[List[str]]): The tags. + metadata (Optional[Dict[str, Any]]): The metadata. + kwargs (Any): Additional keyword arguments. """ def on_chat_model_start( @@ -188,6 +245,15 @@ class CallbackManagerMixin: **ATTENTION**: This method is called for chat models. If you're implementing a handler for a non-chat model, you should use on_llm_start instead. + + Args: + serialized (Dict[str, Any]): The serialized chat model. + messages (List[List[BaseMessage]]): The messages. + run_id (UUID): The run ID. This is the ID of the current run. + parent_run_id (UUID): The parent run ID. This is the ID of the parent run. + tags (Optional[List[str]]): The tags. + metadata (Optional[Dict[str, Any]]): The metadata. + kwargs (Any): Additional keyword arguments. """ # NotImplementedError is thrown intentionally # Callback handler will fall back to on_llm_start if this is exception is thrown @@ -206,7 +272,17 @@ class CallbackManagerMixin: metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> Any: - """Run when Retriever starts running.""" + """Run when the Retriever starts running. + + Args: + serialized (Dict[str, Any]): The serialized Retriever. + query (str): The query. + run_id (UUID): The run ID. This is the ID of the current run. + parent_run_id (UUID): The parent run ID. This is the ID of the parent run. + tags (Optional[List[str]]): The tags. + metadata (Optional[Dict[str, Any]]): The metadata. + kwargs (Any): Additional keyword arguments. + """ def on_chain_start( self, @@ -219,7 +295,17 @@ class CallbackManagerMixin: metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> Any: - """Run when chain starts running.""" + """Run when a chain starts running. + + Args: + serialized (Dict[str, Any]): The serialized chain. + inputs (Dict[str, Any]): The inputs. + run_id (UUID): The run ID. This is the ID of the current run. + parent_run_id (UUID): The parent run ID. This is the ID of the parent run. + tags (Optional[List[str]]): The tags. + metadata (Optional[Dict[str, Any]]): The metadata. + kwargs (Any): Additional keyword arguments. + """ def on_tool_start( self, @@ -233,7 +319,18 @@ class CallbackManagerMixin: inputs: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> Any: - """Run when tool starts running.""" + """Run when the tool starts running. + + Args: + serialized (Dict[str, Any]): The serialized tool. + input_str (str): The input string. + run_id (UUID): The run ID. This is the ID of the current run. + parent_run_id (UUID): The parent run ID. This is the ID of the parent run. + tags (Optional[List[str]]): The tags. + metadata (Optional[Dict[str, Any]]): The metadata. + inputs (Optional[Dict[str, Any]]): The inputs. + kwargs (Any): Additional keyword arguments. + """ class RunManagerMixin: @@ -247,7 +344,14 @@ class RunManagerMixin: parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> Any: - """Run on arbitrary text.""" + """Run on an arbitrary text. + + Args: + text (str): The text. + run_id (UUID): The run ID. This is the ID of the current run. + parent_run_id (UUID): The parent run ID. This is the ID of the parent run. + kwargs (Any): Additional keyword arguments. + """ def on_retry( self, @@ -257,7 +361,14 @@ class RunManagerMixin: parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> Any: - """Run on a retry event.""" + """Run on a retry event. + + Args: + retry_state (RetryCallState): The retry state. + run_id (UUID): The run ID. This is the ID of the current run. + parent_run_id (UUID): The parent run ID. This is the ID of the parent run. + kwargs (Any): Additional keyword arguments. + """ class BaseCallbackHandler( @@ -268,11 +379,13 @@ class BaseCallbackHandler( CallbackManagerMixin, RunManagerMixin, ): - """Base callback handler that handles callbacks from LangChain.""" + """Base callback handler for LangChain.""" raise_error: bool = False + """Whether to raise an error if an exception occurs.""" run_inline: bool = False + """Whether to run the callback inline.""" @property def ignore_llm(self) -> bool: @@ -306,7 +419,7 @@ class BaseCallbackHandler( class AsyncCallbackHandler(BaseCallbackHandler): - """Async callback handler that handles callbacks from LangChain.""" + """Async callback handler for LangChain.""" async def on_llm_start( self, @@ -324,6 +437,15 @@ class AsyncCallbackHandler(BaseCallbackHandler): **ATTENTION**: This method is called for non-chat models (regular LLMs). If you're implementing a handler for a chat model, you should use on_chat_model_start instead. + + Args: + serialized (Dict[str, Any]): The serialized LLM. + prompts (List[str]): The prompts. + run_id (UUID): The run ID. This is the ID of the current run. + parent_run_id (UUID): The parent run ID. This is the ID of the parent run. + tags (Optional[List[str]]): The tags. + metadata (Optional[Dict[str, Any]]): The metadata. + kwargs (Any): Additional keyword arguments. """ async def on_chat_model_start( @@ -341,6 +463,15 @@ class AsyncCallbackHandler(BaseCallbackHandler): **ATTENTION**: This method is called for chat models. If you're implementing a handler for a non-chat model, you should use on_llm_start instead. + + Args: + serialized (Dict[str, Any]): The serialized chat model. + messages (List[List[BaseMessage]]): The messages. + run_id (UUID): The run ID. This is the ID of the current run. + parent_run_id (UUID): The parent run ID. This is the ID of the parent run. + tags (Optional[List[str]]): The tags. + metadata (Optional[Dict[str, Any]]): The metadata. + kwargs (Any): Additional keyword arguments. """ # NotImplementedError is thrown intentionally # Callback handler will fall back to on_llm_start if this is exception is thrown @@ -358,7 +489,17 @@ class AsyncCallbackHandler(BaseCallbackHandler): tags: Optional[List[str]] = None, **kwargs: Any, ) -> None: - """Run on new LLM token. Only available when streaming is enabled.""" + """Run on new LLM token. Only available when streaming is enabled. + + Args: + token (str): The new token. + chunk (GenerationChunk | ChatGenerationChunk): The new generated chunk, + containing content and other information. + run_id (UUID): The run ID. This is the ID of the current run. + parent_run_id (UUID): The parent run ID. This is the ID of the parent run. + tags (Optional[List[str]]): The tags. + kwargs (Any): Additional keyword arguments. + """ async def on_llm_end( self, @@ -369,7 +510,15 @@ class AsyncCallbackHandler(BaseCallbackHandler): tags: Optional[List[str]] = None, **kwargs: Any, ) -> None: - """Run when LLM ends running.""" + """Run when LLM ends running. + + Args: + response (LLMResult): The response which was generated. + run_id (UUID): The run ID. This is the ID of the current run. + parent_run_id (UUID): The parent run ID. This is the ID of the parent run. + tags (Optional[List[str]]): The tags. + kwargs (Any): Additional keyword arguments. + """ async def on_llm_error( self, @@ -384,6 +533,9 @@ class AsyncCallbackHandler(BaseCallbackHandler): Args: error: The error that occurred. + run_id: The run ID. This is the ID of the current run. + parent_run_id: The parent run ID. This is the ID of the parent run. + tags: The tags. kwargs (Any): Additional keyword arguments. - response (LLMResult): The response which was generated before the error occurred. @@ -400,7 +552,17 @@ class AsyncCallbackHandler(BaseCallbackHandler): metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> None: - """Run when chain starts running.""" + """Run when a chain starts running. + + Args: + serialized (Dict[str, Any]): The serialized chain. + inputs (Dict[str, Any]): The inputs. + run_id (UUID): The run ID. This is the ID of the current run. + parent_run_id (UUID): The parent run ID. This is the ID of the parent run. + tags (Optional[List[str]]): The tags. + metadata (Optional[Dict[str, Any]]): The metadata. + kwargs (Any): Additional keyword arguments. + """ async def on_chain_end( self, @@ -411,7 +573,15 @@ class AsyncCallbackHandler(BaseCallbackHandler): tags: Optional[List[str]] = None, **kwargs: Any, ) -> None: - """Run when chain ends running.""" + """Run when a chain ends running. + + Args: + outputs (Dict[str, Any]): The outputs of the chain. + run_id (UUID): The run ID. This is the ID of the current run. + parent_run_id (UUID): The parent run ID. This is the ID of the parent run. + tags (Optional[List[str]]): The tags. + kwargs (Any): Additional keyword arguments. + """ async def on_chain_error( self, @@ -422,7 +592,15 @@ class AsyncCallbackHandler(BaseCallbackHandler): tags: Optional[List[str]] = None, **kwargs: Any, ) -> None: - """Run when chain errors.""" + """Run when chain errors. + + Args: + error (BaseException): The error that occurred. + run_id (UUID): The run ID. This is the ID of the current run. + parent_run_id (UUID): The parent run ID. This is the ID of the parent run. + tags (Optional[List[str]]): The tags. + kwargs (Any): Additional keyword arguments. + """ async def on_tool_start( self, @@ -436,7 +614,18 @@ class AsyncCallbackHandler(BaseCallbackHandler): inputs: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> None: - """Run when tool starts running.""" + """Run when the tool starts running. + + Args: + serialized (Dict[str, Any]): The serialized tool. + input_str (str): The input string. + run_id (UUID): The run ID. This is the ID of the current run. + parent_run_id (UUID): The parent run ID. This is the ID of the parent run. + tags (Optional[List[str]]): The tags. + metadata (Optional[Dict[str, Any]]): The metadata. + inputs (Optional[Dict[str, Any]]): The inputs. + kwargs (Any): Additional keyword arguments. + """ async def on_tool_end( self, @@ -447,7 +636,15 @@ class AsyncCallbackHandler(BaseCallbackHandler): tags: Optional[List[str]] = None, **kwargs: Any, ) -> None: - """Run when tool ends running.""" + """Run when the tool ends running. + + Args: + output (Any): The output of the tool. + run_id (UUID): The run ID. This is the ID of the current run. + parent_run_id (UUID): The parent run ID. This is the ID of the parent run. + tags (Optional[List[str]]): The tags. + kwargs (Any): Additional keyword arguments. + """ async def on_tool_error( self, @@ -458,7 +655,15 @@ class AsyncCallbackHandler(BaseCallbackHandler): tags: Optional[List[str]] = None, **kwargs: Any, ) -> None: - """Run when tool errors.""" + """Run when tool errors. + + Args: + error (BaseException): The error that occurred. + run_id (UUID): The run ID. This is the ID of the current run. + parent_run_id (UUID): The parent run ID. This is the ID of the parent run. + tags (Optional[List[str]]): The tags. + kwargs (Any): Additional keyword arguments. + """ async def on_text( self, @@ -469,7 +674,15 @@ class AsyncCallbackHandler(BaseCallbackHandler): tags: Optional[List[str]] = None, **kwargs: Any, ) -> None: - """Run on arbitrary text.""" + """Run on an arbitrary text. + + Args: + text (str): The text. + run_id (UUID): The run ID. This is the ID of the current run. + parent_run_id (UUID): The parent run ID. This is the ID of the parent run. + tags (Optional[List[str]]): The tags. + kwargs (Any): Additional keyword arguments. + """ async def on_retry( self, @@ -479,7 +692,14 @@ class AsyncCallbackHandler(BaseCallbackHandler): parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> Any: - """Run on a retry event.""" + """Run on a retry event. + + Args: + retry_state (RetryCallState): The retry state. + run_id (UUID): The run ID. This is the ID of the current run. + parent_run_id (UUID): The parent run ID. This is the ID of the parent run. + kwargs (Any): Additional keyword arguments. + """ async def on_agent_action( self, @@ -490,7 +710,15 @@ class AsyncCallbackHandler(BaseCallbackHandler): tags: Optional[List[str]] = None, **kwargs: Any, ) -> None: - """Run on agent action.""" + """Run on agent action. + + Args: + action (AgentAction): The agent action. + run_id (UUID): The run ID. This is the ID of the current run. + parent_run_id (UUID): The parent run ID. This is the ID of the parent run. + tags (Optional[List[str]]): The tags. + kwargs (Any): Additional keyword arguments. + """ async def on_agent_finish( self, @@ -501,7 +729,15 @@ class AsyncCallbackHandler(BaseCallbackHandler): tags: Optional[List[str]] = None, **kwargs: Any, ) -> None: - """Run on agent end.""" + """Run on the agent end. + + Args: + finish (AgentFinish): The agent finish. + run_id (UUID): The run ID. This is the ID of the current run. + parent_run_id (UUID): The parent run ID. This is the ID of the parent run. + tags (Optional[List[str]]): The tags. + kwargs (Any): Additional keyword arguments. + """ async def on_retriever_start( self, @@ -514,7 +750,17 @@ class AsyncCallbackHandler(BaseCallbackHandler): metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> None: - """Run on retriever start.""" + """Run on the retriever start. + + Args: + serialized (Dict[str, Any]): The serialized retriever. + query (str): The query. + run_id (UUID): The run ID. This is the ID of the current run. + parent_run_id (UUID): The parent run ID. This is the ID of the parent run. + tags (Optional[List[str]]): The tags. + metadata (Optional[Dict[str, Any]]): The metadata. + kwargs (Any): Additional keyword arguments. + """ async def on_retriever_end( self, @@ -525,7 +771,14 @@ class AsyncCallbackHandler(BaseCallbackHandler): tags: Optional[List[str]] = None, **kwargs: Any, ) -> None: - """Run on retriever end.""" + """Run on the retriever end. + + Args: + documents (Sequence[Document]): The documents retrieved. + run_id (UUID): The run ID. This is the ID of the current run. + parent_run_id (UUID): The parent run ID. This is the ID of the parent run. + tags (Optional[List[str]]): The tags. + kwargs (Any): Additional keyword arguments.""" async def on_retriever_error( self, @@ -536,14 +789,22 @@ class AsyncCallbackHandler(BaseCallbackHandler): tags: Optional[List[str]] = None, **kwargs: Any, ) -> None: - """Run on retriever error.""" + """Run on retriever error. + + Args: + error (BaseException): The error that occurred. + run_id (UUID): The run ID. This is the ID of the current run. + parent_run_id (UUID): The parent run ID. This is the ID of the parent run. + tags (Optional[List[str]]): The tags. + kwargs (Any): Additional keyword arguments. + """ T = TypeVar("T", bound="BaseCallbackManager") class BaseCallbackManager(CallbackManagerMixin): - """Base callback manager that handles callbacks from LangChain.""" + """Base callback manager for LangChain.""" def __init__( self, @@ -556,7 +817,18 @@ class BaseCallbackManager(CallbackManagerMixin): metadata: Optional[Dict[str, Any]] = None, inheritable_metadata: Optional[Dict[str, Any]] = None, ) -> None: - """Initialize callback manager.""" + """Initialize callback manager. + + Args: + handlers (List[BaseCallbackHandler]): The handlers. + inheritable_handlers (Optional[List[BaseCallbackHandler]]): + The inheritable handlers. Default is None. + parent_run_id (Optional[UUID]): The parent run ID. Default is None. + tags (Optional[List[str]]): The tags. Default is None. + inheritable_tags (Optional[List[str]]): The inheritable tags. + Default is None. + metadata (Optional[Dict[str, Any]]): The metadata. Default is None. + """ self.handlers: List[BaseCallbackHandler] = handlers self.inheritable_handlers: List[BaseCallbackHandler] = ( inheritable_handlers or [] @@ -585,31 +857,56 @@ class BaseCallbackManager(CallbackManagerMixin): return False def add_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None: - """Add a handler to the callback manager.""" + """Add a handler to the callback manager. + + Args: + handler (BaseCallbackHandler): The handler to add. + inherit (bool): Whether to inherit the handler. Default is True. + """ if handler not in self.handlers: self.handlers.append(handler) if inherit and handler not in self.inheritable_handlers: self.inheritable_handlers.append(handler) def remove_handler(self, handler: BaseCallbackHandler) -> None: - """Remove a handler from the callback manager.""" + """Remove a handler from the callback manager. + + Args: + handler (BaseCallbackHandler): The handler to remove. + """ self.handlers.remove(handler) self.inheritable_handlers.remove(handler) def set_handlers( self, handlers: List[BaseCallbackHandler], inherit: bool = True ) -> None: - """Set handlers as the only handlers on the callback manager.""" + """Set handlers as the only handlers on the callback manager. + + Args: + handlers (List[BaseCallbackHandler]): The handlers to set. + inherit (bool): Whether to inherit the handlers. Default is True. + """ self.handlers = [] self.inheritable_handlers = [] for handler in handlers: self.add_handler(handler, inherit=inherit) def set_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None: - """Set handler as the only handler on the callback manager.""" + """Set handler as the only handler on the callback manager. + + Args: + handler (BaseCallbackHandler): The handler to set. + inherit (bool): Whether to inherit the handler. Default is True. + """ self.set_handlers([handler], inherit=inherit) def add_tags(self, tags: List[str], inherit: bool = True) -> None: + """Add tags to the callback manager. + + Args: + tags (List[str]): The tags to add. + inherit (bool): Whether to inherit the tags. Default is True. + """ for tag in tags: if tag in self.tags: self.remove_tags([tag]) @@ -618,16 +915,32 @@ class BaseCallbackManager(CallbackManagerMixin): self.inheritable_tags.extend(tags) def remove_tags(self, tags: List[str]) -> None: + """Remove tags from the callback manager. + + Args: + tags (List[str]): The tags to remove. + """ for tag in tags: self.tags.remove(tag) self.inheritable_tags.remove(tag) def add_metadata(self, metadata: Dict[str, Any], inherit: bool = True) -> None: + """Add metadata to the callback manager. + + Args: + metadata (Dict[str, Any]): The metadata to add. + inherit (bool): Whether to inherit the metadata. Default is True. + """ self.metadata.update(metadata) if inherit: self.inheritable_metadata.update(metadata) def remove_metadata(self, keys: List[str]) -> None: + """Remove metadata from the callback manager. + + Args: + keys (List[str]): The keys to remove. + """ for key in keys: self.metadata.pop(key) self.inheritable_metadata.pop(key) diff --git a/libs/core/langchain_core/callbacks/file.py b/libs/core/langchain_core/callbacks/file.py index daef5294504..c33bd4a441c 100644 --- a/libs/core/langchain_core/callbacks/file.py +++ b/libs/core/langchain_core/callbacks/file.py @@ -10,12 +10,23 @@ from langchain_core.utils.input import print_text class FileCallbackHandler(BaseCallbackHandler): - """Callback Handler that writes to a file.""" + """Callback Handler that writes to a file. + + Parameters: + file: The file to write to. + color: The color to use for the text. + """ def __init__( self, filename: str, mode: str = "a", color: Optional[str] = None ) -> None: - """Initialize callback handler.""" + """Initialize callback handler. + + Args: + filename: The filename to write to. + mode: The mode to open the file in. Defaults to "a". + color: The color to use for the text. Defaults to None. + """ self.file = cast(TextIO, open(filename, mode, encoding="utf-8")) self.color = color @@ -26,7 +37,13 @@ class FileCallbackHandler(BaseCallbackHandler): def on_chain_start( self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any ) -> None: - """Print out that we are entering a chain.""" + """Print out that we are entering a chain. + + Args: + serialized (Dict[str, Any]): The serialized chain. + inputs (Dict[str, Any]): The inputs to the chain. + **kwargs (Any): Additional keyword arguments. + """ class_name = serialized.get("name", serialized.get("id", [""])[-1]) print_text( f"\n\n\033[1m> Entering new {class_name} chain...\033[0m", @@ -35,13 +52,25 @@ class FileCallbackHandler(BaseCallbackHandler): ) def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: - """Print out that we finished a chain.""" + """Print out that we finished a chain. + + Args: + outputs (Dict[str, Any]): The outputs of the chain. + **kwargs (Any): Additional keyword arguments. + """ print_text("\n\033[1m> Finished chain.\033[0m", end="\n", file=self.file) def on_agent_action( self, action: AgentAction, color: Optional[str] = None, **kwargs: Any ) -> Any: - """Run on agent action.""" + """Run on agent action. + + Args: + action (AgentAction): The agent action. + color (Optional[str], optional): The color to use for the text. + Defaults to None. + **kwargs (Any): Additional keyword arguments. + """ print_text(action.log, color=color or self.color, file=self.file) def on_tool_end( @@ -52,7 +81,18 @@ class FileCallbackHandler(BaseCallbackHandler): llm_prefix: Optional[str] = None, **kwargs: Any, ) -> None: - """If not the final action, print out observation.""" + """If not the final action, print out observation. + + Args: + output (str): The output to print. + color (Optional[str], optional): The color to use for the text. + Defaults to None. + observation_prefix (Optional[str], optional): The observation prefix. + Defaults to None. + llm_prefix (Optional[str], optional): The LLM prefix. + Defaults to None. + **kwargs (Any): Additional keyword arguments. + """ if observation_prefix is not None: print_text(f"\n{observation_prefix}", file=self.file) print_text(output, color=color or self.color, file=self.file) @@ -62,11 +102,26 @@ class FileCallbackHandler(BaseCallbackHandler): def on_text( self, text: str, color: Optional[str] = None, end: str = "", **kwargs: Any ) -> None: - """Run when agent ends.""" + """Run when the agent ends. + + Args: + text (str): The text to print. + color (Optional[str], optional): The color to use for the text. + Defaults to None. + end (str, optional): The end character. Defaults to "". + **kwargs (Any): Additional keyword arguments. + """ print_text(text, color=color or self.color, end=end, file=self.file) def on_agent_finish( self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any ) -> None: - """Run on agent end.""" + """Run on the agent end. + + Args: + finish (AgentFinish): The agent finish. + color (Optional[str], optional): The color to use for the text. + Defaults to None. + **kwargs (Any): Additional keyword arguments. + """ print_text(finish.log, color=color or self.color, end="\n", file=self.file) diff --git a/libs/core/langchain_core/callbacks/manager.py b/libs/core/langchain_core/callbacks/manager.py index 8d3ecb6c059..80d785ed120 100644 --- a/libs/core/langchain_core/callbacks/manager.py +++ b/libs/core/langchain_core/callbacks/manager.py @@ -77,7 +77,9 @@ def trace_as_chain_group( Args: group_name (str): The name of the chain group. callback_manager (CallbackManager, optional): The callback manager to use. + Defaults to None. inputs (Dict[str, Any], optional): The inputs to the chain group. + Defaults to None. project_name (str, optional): The name of the project. Defaults to None. example_id (str or UUID, optional): The ID of the example. @@ -155,7 +157,9 @@ async def atrace_as_chain_group( Args: group_name (str): The name of the chain group. callback_manager (AsyncCallbackManager, optional): The async callback manager to use, - which manages tracing and other callback behavior. + which manages tracing and other callback behavior. Defaults to None. + inputs (Dict[str, Any], optional): The inputs to the chain group. + Defaults to None. project_name (str, optional): The name of the project. Defaults to None. example_id (str or UUID, optional): The ID of the example. @@ -218,7 +222,13 @@ Func = TypeVar("Func", bound=Callable) def shielded(func: Func) -> Func: """ - Makes so an awaitable method is always shielded from cancellation + Makes so an awaitable method is always shielded from cancellation. + + Args: + func (Callable): The function to shield. + + Returns: + Callable: The shielded function """ @functools.wraps(func) @@ -237,14 +247,14 @@ def handle_event( ) -> None: """Generic event handler for CallbackManager. - Note: This function is used by langserve to handle events. + Note: This function is used by LangServe to handle events. Args: - handlers: The list of handlers that will handle the event - event_name: The name of the event (e.g., "on_llm_start") + handlers: The list of handlers that will handle the event. + event_name: The name of the event (e.g., "on_llm_start"). ignore_condition_name: Name of the attribute defined on handler - that if True will cause the handler to be skipped for the given event - *args: The arguments to pass to the event handler + that if True will cause the handler to be skipped for the given event. + *args: The arguments to pass to the event handler. **kwargs: The keyword arguments to pass to the event handler """ coros: List[Coroutine[Any, Any, Any]] = [] @@ -394,17 +404,17 @@ async def ahandle_event( *args: Any, **kwargs: Any, ) -> None: - """Generic event handler for AsyncCallbackManager. + """Async generic event handler for AsyncCallbackManager. - Note: This function is used by langserve to handle events. + Note: This function is used by LangServe to handle events. Args: - handlers: The list of handlers that will handle the event - event_name: The name of the event (e.g., "on_llm_start") + handlers: The list of handlers that will handle the event. + event_name: The name of the event (e.g., "on_llm_start"). ignore_condition_name: Name of the attribute defined on handler - that if True will cause the handler to be skipped for the given event - *args: The arguments to pass to the event handler - **kwargs: The keyword arguments to pass to the event handler + that if True will cause the handler to be skipped for the given event. + *args: The arguments to pass to the event handler. + **kwargs: The keyword arguments to pass to the event handler. """ for handler in [h for h in handlers if h.run_inline]: await _ahandle_event_for_handler( @@ -452,10 +462,13 @@ class BaseRunManager(RunManagerMixin): The list of inheritable handlers. parent_run_id (UUID, optional): The ID of the parent run. Defaults to None. - tags (Optional[List[str]]): The list of tags. + tags (Optional[List[str]]): The list of tags. Defaults to None. inheritable_tags (Optional[List[str]]): The list of inheritable tags. + Defaults to None. metadata (Optional[Dict[str, Any]]): The metadata. + Defaults to None. inheritable_metadata (Optional[Dict[str, Any]]): The inheritable metadata. + Defaults to None. """ self.run_id = run_id self.handlers = handlers @@ -492,10 +505,11 @@ class RunManager(BaseRunManager): text: str, **kwargs: Any, ) -> Any: - """Run when text is received. + """Run when a text is received. Args: text (str): The received text. + **kwargs (Any): Additional keyword arguments. Returns: Any: The result of the callback. @@ -516,6 +530,12 @@ class RunManager(BaseRunManager): retry_state: RetryCallState, **kwargs: Any, ) -> None: + """Run when a retry is received. + + Args: + retry_state (RetryCallState): The retry state. + **kwargs (Any): Additional keyword arguments. + """ handle_event( self.handlers, "on_retry", @@ -566,10 +586,11 @@ class AsyncRunManager(BaseRunManager, ABC): text: str, **kwargs: Any, ) -> Any: - """Run when text is received. + """Run when a text is received. Args: text (str): The received text. + **kwargs (Any): Additional keyword arguments. Returns: Any: The result of the callback. @@ -590,6 +611,12 @@ class AsyncRunManager(BaseRunManager, ABC): retry_state: RetryCallState, **kwargs: Any, ) -> None: + """Async run when a retry is received. + + Args: + retry_state (RetryCallState): The retry state. + **kwargs (Any): Additional keyword arguments. + """ await ahandle_event( self.handlers, "on_retry", @@ -638,6 +665,9 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin): Args: token (str): The new token. + chunk (Optional[Union[GenerationChunk, ChatGenerationChunk]], optional): + The chunk. Defaults to None. + **kwargs (Any): Additional keyword arguments. """ handle_event( self.handlers, @@ -656,6 +686,7 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin): Args: response (LLMResult): The LLM result. + **kwargs (Any): Additional keyword arguments. """ handle_event( self.handlers, @@ -725,6 +756,9 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin): Args: token (str): The new token. + chunk (Optional[Union[GenerationChunk, ChatGenerationChunk]], optional): + The chunk. Defaults to None. + **kwargs (Any): Additional keyword arguments. """ await ahandle_event( self.handlers, @@ -744,6 +778,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin): Args: response (LLMResult): The LLM result. + **kwargs (Any): Additional keyword arguments. """ await ahandle_event( self.handlers, @@ -793,6 +828,7 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin): Args: outputs (Union[Dict[str, Any], Any]): The outputs of the chain. + **kwargs (Any): Additional keyword arguments. """ handle_event( self.handlers, @@ -814,6 +850,7 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin): Args: error (Exception or KeyboardInterrupt): The error. + **kwargs (Any): Additional keyword arguments. """ handle_event( self.handlers, @@ -831,6 +868,7 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin): Args: action (AgentAction): The agent action. + **kwargs (Any): Additional keyword arguments. Returns: Any: The result of the callback. @@ -851,6 +889,7 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin): Args: finish (AgentFinish): The agent finish. + **kwargs (Any): Additional keyword arguments. Returns: Any: The result of the callback. @@ -891,10 +930,11 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin): async def on_chain_end( self, outputs: Union[Dict[str, Any], Any], **kwargs: Any ) -> None: - """Run when chain ends running. + """Run when a chain ends running. Args: outputs (Union[Dict[str, Any], Any]): The outputs of the chain. + **kwargs (Any): Additional keyword arguments. """ await ahandle_event( self.handlers, @@ -917,6 +957,7 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin): Args: error (Exception or KeyboardInterrupt): The error. + **kwargs (Any): Additional keyword arguments. """ await ahandle_event( self.handlers, @@ -935,6 +976,7 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin): Args: action (AgentAction): The agent action. + **kwargs (Any): Additional keyword arguments. Returns: Any: The result of the callback. @@ -956,6 +998,7 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin): Args: finish (AgentFinish): The agent finish. + **kwargs (Any): Additional keyword arguments. Returns: Any: The result of the callback. @@ -980,10 +1023,11 @@ class CallbackManagerForToolRun(ParentRunManager, ToolManagerMixin): output: Any, **kwargs: Any, ) -> None: - """Run when tool ends running. + """Run when the tool ends running. Args: output (Any): The output of the tool. + **kwargs (Any): Additional keyword arguments. """ handle_event( self.handlers, @@ -1005,6 +1049,7 @@ class CallbackManagerForToolRun(ParentRunManager, ToolManagerMixin): Args: error (Exception or KeyboardInterrupt): The error. + **kwargs (Any): Additional keyword arguments. """ handle_event( self.handlers, @@ -1040,10 +1085,11 @@ class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin): @shielded async def on_tool_end(self, output: Any, **kwargs: Any) -> None: - """Run when tool ends running. + """Async run when the tool ends running. Args: output (Any): The output of the tool. + **kwargs (Any): Additional keyword arguments. """ await ahandle_event( self.handlers, @@ -1066,6 +1112,7 @@ class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin): Args: error (Exception or KeyboardInterrupt): The error. + **kwargs (Any): Additional keyword arguments. """ await ahandle_event( self.handlers, @@ -1087,7 +1134,12 @@ class CallbackManagerForRetrieverRun(ParentRunManager, RetrieverManagerMixin): documents: Sequence[Document], **kwargs: Any, ) -> None: - """Run when retriever ends running.""" + """Run when retriever ends running. + + Args: + documents (Sequence[Document]): The retrieved documents. + **kwargs (Any): Additional keyword arguments. + """ handle_event( self.handlers, "on_retriever_end", @@ -1104,7 +1156,12 @@ class CallbackManagerForRetrieverRun(ParentRunManager, RetrieverManagerMixin): error: BaseException, **kwargs: Any, ) -> None: - """Run when retriever errors.""" + """Run when retriever errors. + + Args: + error (BaseException): The error. + **kwargs (Any): Additional keyword arguments. + """ handle_event( self.handlers, "on_retriever_error", @@ -1144,7 +1201,12 @@ class AsyncCallbackManagerForRetrieverRun( async def on_retriever_end( self, documents: Sequence[Document], **kwargs: Any ) -> None: - """Run when retriever ends running.""" + """Run when the retriever ends running. + + Args: + documents (Sequence[Document]): The retrieved documents. + **kwargs (Any): Additional keyword arguments. + """ await ahandle_event( self.handlers, "on_retriever_end", @@ -1162,7 +1224,12 @@ class AsyncCallbackManagerForRetrieverRun( error: BaseException, **kwargs: Any, ) -> None: - """Run when retriever errors.""" + """Run when retriever errors. + + Args: + error (BaseException): The error. + **kwargs (Any): Additional keyword arguments. + """ await ahandle_event( self.handlers, "on_retriever_error", @@ -1176,7 +1243,7 @@ class AsyncCallbackManagerForRetrieverRun( class CallbackManager(BaseCallbackManager): - """Callback manager that handles callbacks from LangChain.""" + """Callback manager for LangChain.""" def on_llm_start( self, @@ -1191,6 +1258,7 @@ class CallbackManager(BaseCallbackManager): serialized (Dict[str, Any]): The serialized LLM. prompts (List[str]): The list of prompts. run_id (UUID, optional): The ID of the run. Defaults to None. + **kwargs (Any): Additional keyword arguments. Returns: List[CallbackManagerForLLMRun]: A callback manager for each @@ -1241,6 +1309,7 @@ class CallbackManager(BaseCallbackManager): serialized (Dict[str, Any]): The serialized LLM. messages (List[List[BaseMessage]]): The list of messages. run_id (UUID, optional): The ID of the run. Defaults to None. + **kwargs (Any): Additional keyword arguments. Returns: List[CallbackManagerForLLMRun]: A callback manager for each @@ -1295,6 +1364,7 @@ class CallbackManager(BaseCallbackManager): serialized (Dict[str, Any]): The serialized chain. inputs (Union[Dict[str, Any], Any]): The inputs to the chain. run_id (UUID, optional): The ID of the run. Defaults to None. + **kwargs (Any): Additional keyword arguments. Returns: CallbackManagerForChainRun: The callback manager for the chain run. @@ -1347,6 +1417,7 @@ class CallbackManager(BaseCallbackManager): input is needed. If provided, the inputs are expected to be formatted as a dict. The keys will correspond to the named-arguments in the tool. + **kwargs (Any): Additional keyword arguments. Returns: CallbackManagerForToolRun: The callback manager for the tool run. @@ -1387,7 +1458,15 @@ class CallbackManager(BaseCallbackManager): parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> CallbackManagerForRetrieverRun: - """Run when retriever starts running.""" + """Run when the retriever starts running. + + Args: + serialized (Dict[str, Any]): The serialized retriever. + query (str): The query. + run_id (UUID, optional): The ID of the run. Defaults to None. + parent_run_id (UUID, optional): The ID of the parent run. Defaults to None. + **kwargs (Any): Additional keyword arguments. + """ if run_id is None: run_id = uuid.uuid4() @@ -1470,6 +1549,16 @@ class CallbackManagerForChainGroup(CallbackManager): parent_run_manager: CallbackManagerForChainRun, **kwargs: Any, ) -> None: + """Initialize the callback manager. + + Args: + handlers (List[BaseCallbackHandler]): The list of handlers. + inheritable_handlers (Optional[List[BaseCallbackHandler]]): The list of + inheritable handlers. Defaults to None. + parent_run_id (Optional[UUID]): The ID of the parent run. Defaults to None. + parent_run_manager (CallbackManagerForChainRun): The parent run manager. + **kwargs (Any): Additional keyword arguments. + """ super().__init__( handlers, inheritable_handlers, @@ -1480,6 +1569,7 @@ class CallbackManagerForChainGroup(CallbackManager): self.ended = False def copy(self) -> CallbackManagerForChainGroup: + """Copy the callback manager.""" return self.__class__( handlers=self.handlers, inheritable_handlers=self.inheritable_handlers, @@ -1496,6 +1586,7 @@ class CallbackManagerForChainGroup(CallbackManager): Args: outputs (Union[Dict[str, Any], Any]): The outputs of the chain. + **kwargs (Any): Additional keyword arguments. """ self.ended = True return self.parent_run_manager.on_chain_end(outputs, **kwargs) @@ -1509,6 +1600,7 @@ class CallbackManagerForChainGroup(CallbackManager): Args: error (Exception or KeyboardInterrupt): The error. + **kwargs (Any): Additional keyword arguments. """ self.ended = True return self.parent_run_manager.on_chain_error(error, **kwargs) @@ -1535,6 +1627,7 @@ class AsyncCallbackManager(BaseCallbackManager): serialized (Dict[str, Any]): The serialized LLM. prompts (List[str]): The list of prompts. run_id (UUID, optional): The ID of the run. Defaults to None. + **kwargs (Any): Additional keyword arguments. Returns: List[AsyncCallbackManagerForLLMRun]: The list of async @@ -1591,12 +1684,13 @@ class AsyncCallbackManager(BaseCallbackManager): run_id: Optional[UUID] = None, **kwargs: Any, ) -> List[AsyncCallbackManagerForLLMRun]: - """Run when LLM starts running. + """Async run when LLM starts running. Args: serialized (Dict[str, Any]): The serialized LLM. messages (List[List[BaseMessage]]): The list of messages. run_id (UUID, optional): The ID of the run. Defaults to None. + **kwargs (Any): Additional keyword arguments. Returns: List[AsyncCallbackManagerForLLMRun]: The list of @@ -1651,12 +1745,13 @@ class AsyncCallbackManager(BaseCallbackManager): run_id: Optional[UUID] = None, **kwargs: Any, ) -> AsyncCallbackManagerForChainRun: - """Run when chain starts running. + """Async run when chain starts running. Args: serialized (Dict[str, Any]): The serialized chain. inputs (Union[Dict[str, Any], Any]): The inputs to the chain. run_id (UUID, optional): The ID of the run. Defaults to None. + **kwargs (Any): Additional keyword arguments. Returns: AsyncCallbackManagerForChainRun: The async callback manager @@ -1697,7 +1792,7 @@ class AsyncCallbackManager(BaseCallbackManager): parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> AsyncCallbackManagerForToolRun: - """Run when tool starts running. + """Run when the tool starts running. Args: serialized (Dict[str, Any]): The serialized tool. @@ -1705,6 +1800,7 @@ class AsyncCallbackManager(BaseCallbackManager): run_id (UUID, optional): The ID of the run. Defaults to None. parent_run_id (UUID, optional): The ID of the parent run. Defaults to None. + **kwargs (Any): Additional keyword arguments. Returns: AsyncCallbackManagerForToolRun: The async callback manager @@ -1745,7 +1841,19 @@ class AsyncCallbackManager(BaseCallbackManager): parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> AsyncCallbackManagerForRetrieverRun: - """Run when retriever starts running.""" + """Run when the retriever starts running. + + Args: + serialized (Dict[str, Any]): The serialized retriever. + query (str): The query. + run_id (UUID, optional): The ID of the run. Defaults to None. + parent_run_id (UUID, optional): The ID of the parent run. Defaults to None. + **kwargs (Any): Additional keyword arguments. + + Returns: + AsyncCallbackManagerForRetrieverRun: The async callback manager + for the retriever run. + """ if run_id is None: run_id = uuid.uuid4() @@ -1828,6 +1936,17 @@ class AsyncCallbackManagerForChainGroup(AsyncCallbackManager): parent_run_manager: AsyncCallbackManagerForChainRun, **kwargs: Any, ) -> None: + """Initialize the async callback manager. + + Args: + handlers (List[BaseCallbackHandler]): The list of handlers. + inheritable_handlers (Optional[List[BaseCallbackHandler]]): The list of + inheritable handlers. Defaults to None. + parent_run_id (Optional[UUID]): The ID of the parent run. Defaults to None. + parent_run_manager (AsyncCallbackManagerForChainRun): + The parent run manager. + **kwargs (Any): Additional keyword arguments. + """ super().__init__( handlers, inheritable_handlers, @@ -1838,6 +1957,7 @@ class AsyncCallbackManagerForChainGroup(AsyncCallbackManager): self.ended = False def copy(self) -> AsyncCallbackManagerForChainGroup: + """Copy the async callback manager.""" return self.__class__( handlers=self.handlers, inheritable_handlers=self.inheritable_handlers, @@ -1856,6 +1976,7 @@ class AsyncCallbackManagerForChainGroup(AsyncCallbackManager): Args: outputs (Union[Dict[str, Any], Any]): The outputs of the chain. + **kwargs (Any): Additional keyword arguments. """ self.ended = True await self.parent_run_manager.on_chain_end(outputs, **kwargs) @@ -1869,6 +1990,7 @@ class AsyncCallbackManagerForChainGroup(AsyncCallbackManager): Args: error (Exception or KeyboardInterrupt): The error. + **kwargs (Any): Additional keyword arguments. """ self.ended = True await self.parent_run_manager.on_chain_error(error, **kwargs) diff --git a/libs/core/langchain_core/callbacks/stdout.py b/libs/core/langchain_core/callbacks/stdout.py index 0408eb19845..011dc83fcb9 100644 --- a/libs/core/langchain_core/callbacks/stdout.py +++ b/libs/core/langchain_core/callbacks/stdout.py @@ -15,24 +15,45 @@ class StdOutCallbackHandler(BaseCallbackHandler): """Callback Handler that prints to std out.""" def __init__(self, color: Optional[str] = None) -> None: - """Initialize callback handler.""" + """Initialize callback handler. + + Args: + color: The color to use for the text. Defaults to None. + """ self.color = color def on_chain_start( self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any ) -> None: - """Print out that we are entering a chain.""" + """Print out that we are entering a chain. + + Args: + serialized (Dict[str, Any]): The serialized chain. + inputs (Dict[str, Any]): The inputs to the chain. + **kwargs (Any): Additional keyword arguments. + """ class_name = serialized.get("name", serialized.get("id", [""])[-1]) print(f"\n\n\033[1m> Entering new {class_name} chain...\033[0m") # noqa: T201 def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: - """Print out that we finished a chain.""" + """Print out that we finished a chain. + + Args: + outputs (Dict[str, Any]): The outputs of the chain. + **kwargs (Any): Additional keyword arguments. + """ print("\n\033[1m> Finished chain.\033[0m") # noqa: T201 def on_agent_action( self, action: AgentAction, color: Optional[str] = None, **kwargs: Any ) -> Any: - """Run on agent action.""" + """Run on agent action. + + Args: + action (AgentAction): The agent action. + color (Optional[str]): The color to use for the text. Defaults to None. + **kwargs (Any): Additional keyword arguments. + """ print_text(action.log, color=color or self.color) def on_tool_end( @@ -43,7 +64,16 @@ class StdOutCallbackHandler(BaseCallbackHandler): llm_prefix: Optional[str] = None, **kwargs: Any, ) -> None: - """If not the final action, print out observation.""" + """If not the final action, print out observation. + + Args: + output (Any): The output to print. + color (Optional[str]): The color to use for the text. Defaults to None. + observation_prefix (Optional[str]): The observation prefix. + Defaults to None. + llm_prefix (Optional[str]): The LLM prefix. Defaults to None. + **kwargs (Any): Additional keyword arguments. + """ output = str(output) if observation_prefix is not None: print_text(f"\n{observation_prefix}") @@ -58,11 +88,24 @@ class StdOutCallbackHandler(BaseCallbackHandler): end: str = "", **kwargs: Any, ) -> None: - """Run when agent ends.""" + """Run when the agent ends. + + Args: + text (str): The text to print. + color (Optional[str]): The color to use for the text. Defaults to None. + end (str): The end character to use. Defaults to "". + **kwargs (Any): Additional keyword arguments. + """ print_text(text, color=color or self.color, end=end) def on_agent_finish( self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any ) -> None: - """Run on agent end.""" + """Run on the agent end. + + Args: + finish (AgentFinish): The agent finish. + color (Optional[str]): The color to use for the text. Defaults to None. + **kwargs (Any): Additional keyword arguments. + """ print_text(finish.log, color=color or self.color, end="\n") diff --git a/libs/core/langchain_core/callbacks/streaming_stdout.py b/libs/core/langchain_core/callbacks/streaming_stdout.py index 4966aa2e11c..06973035a90 100644 --- a/libs/core/langchain_core/callbacks/streaming_stdout.py +++ b/libs/core/langchain_core/callbacks/streaming_stdout.py @@ -1,4 +1,5 @@ """Callback Handler streams to stdout on new llm token.""" + from __future__ import annotations import sys @@ -18,7 +19,13 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler): def on_llm_start( self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any ) -> None: - """Run when LLM starts running.""" + """Run when LLM starts running. + + Args: + serialized (Dict[str, Any]): The serialized LLM. + prompts (List[str]): The prompts to run. + **kwargs (Any): Additional keyword arguments. + """ def on_chat_model_start( self, @@ -26,47 +33,115 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler): messages: List[List[BaseMessage]], **kwargs: Any, ) -> None: - """Run when LLM starts running.""" + """Run when LLM starts running. + + Args: + serialized (Dict[str, Any]): The serialized LLM. + messages (List[List[BaseMessage]]): The messages to run. + **kwargs (Any): Additional keyword arguments. + """ def on_llm_new_token(self, token: str, **kwargs: Any) -> None: - """Run on new LLM token. Only available when streaming is enabled.""" + """Run on new LLM token. Only available when streaming is enabled. + + Args: + token (str): The new token. + **kwargs (Any): Additional keyword arguments. + """ sys.stdout.write(token) sys.stdout.flush() def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: - """Run when LLM ends running.""" + """Run when LLM ends running. + + Args: + response (LLMResult): The response from the LLM. + **kwargs (Any): Additional keyword arguments. + """ def on_llm_error(self, error: BaseException, **kwargs: Any) -> None: - """Run when LLM errors.""" + """Run when LLM errors. + + Args: + error (BaseException): The error that occurred. + **kwargs (Any): Additional keyword arguments. + """ def on_chain_start( self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any ) -> None: - """Run when chain starts running.""" + """Run when a chain starts running. + + Args: + serialized (Dict[str, Any]): The serialized chain. + inputs (Dict[str, Any]): The inputs to the chain. + **kwargs (Any): Additional keyword arguments. + """ def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: - """Run when chain ends running.""" + """Run when a chain ends running. + + Args: + outputs (Dict[str, Any]): The outputs of the chain. + **kwargs (Any): Additional keyword arguments. + """ def on_chain_error(self, error: BaseException, **kwargs: Any) -> None: - """Run when chain errors.""" + """Run when chain errors. + + Args: + error (BaseException): The error that occurred. + **kwargs (Any): Additional keyword arguments. + """ def on_tool_start( self, serialized: Dict[str, Any], input_str: str, **kwargs: Any ) -> None: - """Run when tool starts running.""" + """Run when the tool starts running. + + Args: + serialized (Dict[str, Any]): The serialized tool. + input_str (str): The input string. + **kwargs (Any): Additional keyword arguments. + """ def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: - """Run on agent action.""" + """Run on agent action. + + Args: + action (AgentAction): The agent action. + **kwargs (Any): Additional keyword arguments. + """ pass def on_tool_end(self, output: Any, **kwargs: Any) -> None: - """Run when tool ends running.""" + """Run when tool ends running. + + Args: + output (Any): The output of the tool. + **kwargs (Any): Additional keyword arguments. + """ def on_tool_error(self, error: BaseException, **kwargs: Any) -> None: - """Run when tool errors.""" + """Run when tool errors. + + Args: + error (BaseException): The error that occurred. + **kwargs (Any): Additional keyword arguments. + """ def on_text(self, text: str, **kwargs: Any) -> None: - """Run on arbitrary text.""" + """Run on an arbitrary text. + + Args: + text (str): The text to print. + **kwargs (Any): Additional keyword arguments. + """ def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None: - """Run on agent end.""" + """Run on the agent end. + + Args: + finish (AgentFinish): The agent finish. + **kwargs (Any): Additional keyword arguments. + """