diff --git a/libs/core/langchain_core/callbacks/__init__.py b/libs/core/langchain_core/callbacks/__init__.py index b2af179fa00..65df88d69e5 100644 --- a/libs/core/langchain_core/callbacks/__init__.py +++ b/libs/core/langchain_core/callbacks/__init__.py @@ -18,6 +18,7 @@ from langchain_core.callbacks.base import ( RunManagerMixin, ToolManagerMixin, ) +from langchain_core.callbacks.file import FileCallbackHandler from langchain_core.callbacks.manager import ( AsyncCallbackManager, AsyncCallbackManagerForChainGroup, @@ -70,4 +71,5 @@ __all__ = [ "AsyncCallbackManagerForChainGroup", "StdOutCallbackHandler", "StreamingStdOutCallbackHandler", + "FileCallbackHandler", ] diff --git a/libs/core/langchain_core/callbacks/file.py b/libs/core/langchain_core/callbacks/file.py new file mode 100644 index 00000000000..daef5294504 --- /dev/null +++ b/libs/core/langchain_core/callbacks/file.py @@ -0,0 +1,72 @@ +"""Callback Handler that writes to a file.""" + +from __future__ import annotations + +from typing import Any, Dict, Optional, TextIO, cast + +from langchain_core.agents import AgentAction, AgentFinish +from langchain_core.callbacks import BaseCallbackHandler +from langchain_core.utils.input import print_text + + +class FileCallbackHandler(BaseCallbackHandler): + """Callback Handler that writes to a file.""" + + def __init__( + self, filename: str, mode: str = "a", color: Optional[str] = None + ) -> None: + """Initialize callback handler.""" + self.file = cast(TextIO, open(filename, mode, encoding="utf-8")) + self.color = color + + def __del__(self) -> None: + """Destructor to cleanup when done.""" + self.file.close() + + def on_chain_start( + self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any + ) -> None: + """Print out that we are entering a chain.""" + class_name = serialized.get("name", serialized.get("id", [""])[-1]) + print_text( + f"\n\n\033[1m> Entering new {class_name} chain...\033[0m", + end="\n", + file=self.file, + ) + + def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: + """Print out that we finished a chain.""" + print_text("\n\033[1m> Finished chain.\033[0m", end="\n", file=self.file) + + def on_agent_action( + self, action: AgentAction, color: Optional[str] = None, **kwargs: Any + ) -> Any: + """Run on agent action.""" + print_text(action.log, color=color or self.color, file=self.file) + + def on_tool_end( + self, + output: str, + color: Optional[str] = None, + observation_prefix: Optional[str] = None, + llm_prefix: Optional[str] = None, + **kwargs: Any, + ) -> None: + """If not the final action, print out observation.""" + if observation_prefix is not None: + print_text(f"\n{observation_prefix}", file=self.file) + print_text(output, color=color or self.color, file=self.file) + if llm_prefix is not None: + print_text(f"\n{llm_prefix}", file=self.file) + + def on_text( + self, text: str, color: Optional[str] = None, end: str = "", **kwargs: Any + ) -> None: + """Run when agent ends.""" + print_text(text, color=color or self.color, end=end, file=self.file) + + def on_agent_finish( + self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any + ) -> None: + """Run on agent end.""" + print_text(finish.log, color=color or self.color, end="\n", file=self.file) diff --git a/libs/core/tests/unit_tests/callbacks/test_imports.py b/libs/core/tests/unit_tests/callbacks/test_imports.py index 9a9ee0eb9b8..1baca2a2fe3 100644 --- a/libs/core/tests/unit_tests/callbacks/test_imports.py +++ b/libs/core/tests/unit_tests/callbacks/test_imports.py @@ -30,6 +30,7 @@ EXPECTED_ALL = [ "AsyncCallbackManagerForChainGroup", "StdOutCallbackHandler", "StreamingStdOutCallbackHandler", + "FileCallbackHandler", ] diff --git a/libs/langchain/langchain/callbacks/__init__.py b/libs/langchain/langchain/callbacks/__init__.py index 80d9671fde3..e6068d08488 100644 --- a/libs/langchain/langchain/callbacks/__init__.py +++ b/libs/langchain/langchain/callbacks/__init__.py @@ -11,6 +11,7 @@ from typing import Any from langchain_core._api import LangChainDeprecationWarning from langchain_core.callbacks import ( + FileCallbackHandler, StdOutCallbackHandler, StreamingStdOutCallbackHandler, ) @@ -21,7 +22,6 @@ from langchain_core.tracers.context import ( ) from langchain_core.tracers.langchain import LangChainTracer -from langchain.callbacks.file import FileCallbackHandler from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler from langchain.callbacks.streaming_stdout_final_only import ( FinalStreamingStdOutCallbackHandler, diff --git a/libs/langchain/langchain/callbacks/file.py b/libs/langchain/langchain/callbacks/file.py index 06bcecb027d..15fa4101883 100644 --- a/libs/langchain/langchain/callbacks/file.py +++ b/libs/langchain/langchain/callbacks/file.py @@ -1,69 +1,3 @@ -"""Callback Handler that writes to a file.""" -from typing import Any, Dict, Optional, TextIO, cast +from langchain_core.callbacks.file import FileCallbackHandler -from langchain_core.agents import AgentAction, AgentFinish -from langchain_core.callbacks import BaseCallbackHandler -from langchain_core.utils.input import print_text - - -class FileCallbackHandler(BaseCallbackHandler): - """Callback Handler that writes to a file.""" - - def __init__( - self, filename: str, mode: str = "a", color: Optional[str] = None - ) -> None: - """Initialize callback handler.""" - self.file = cast(TextIO, open(filename, mode, encoding="utf-8")) - self.color = color - - def __del__(self) -> None: - """Destructor to cleanup when done.""" - self.file.close() - - def on_chain_start( - self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any - ) -> None: - """Print out that we are entering a chain.""" - class_name = serialized.get("name", serialized.get("id", [""])[-1]) - print_text( - f"\n\n\033[1m> Entering new {class_name} chain...\033[0m", - end="\n", - file=self.file, - ) - - def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: - """Print out that we finished a chain.""" - print_text("\n\033[1m> Finished chain.\033[0m", end="\n", file=self.file) - - def on_agent_action( - self, action: AgentAction, color: Optional[str] = None, **kwargs: Any - ) -> Any: - """Run on agent action.""" - print_text(action.log, color=color or self.color, file=self.file) - - def on_tool_end( - self, - output: str, - color: Optional[str] = None, - observation_prefix: Optional[str] = None, - llm_prefix: Optional[str] = None, - **kwargs: Any, - ) -> None: - """If not the final action, print out observation.""" - if observation_prefix is not None: - print_text(f"\n{observation_prefix}", file=self.file) - print_text(output, color=color or self.color, file=self.file) - if llm_prefix is not None: - print_text(f"\n{llm_prefix}", file=self.file) - - def on_text( - self, text: str, color: Optional[str] = None, end: str = "", **kwargs: Any - ) -> None: - """Run when agent ends.""" - print_text(text, color=color or self.color, end=end, file=self.file) - - def on_agent_finish( - self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any - ) -> None: - """Run on agent end.""" - print_text(finish.log, color=color or self.color, end="\n", file=self.file) +__all__ = ["FileCallbackHandler"]