diff --git a/langchain/callbacks/tracers/__init__.py b/langchain/callbacks/tracers/__init__.py index 6e7ce4110a2..65a5a846ca2 100644 --- a/langchain/callbacks/tracers/__init__.py +++ b/langchain/callbacks/tracers/__init__.py @@ -2,12 +2,16 @@ from langchain.callbacks.tracers.langchain import LangChainTracer from langchain.callbacks.tracers.langchain_v1 import LangChainTracerV1 -from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler +from langchain.callbacks.tracers.stdout import ( + ConsoleCallbackHandler, + FunctionCallbackHandler, +) from langchain.callbacks.tracers.wandb import WandbTracer __all__ = [ "LangChainTracer", "LangChainTracerV1", + "FunctionCallbackHandler", "ConsoleCallbackHandler", "WandbTracer", ] diff --git a/langchain/callbacks/tracers/stdout.py b/langchain/callbacks/tracers/stdout.py index bfdaca4f465..b7121336a63 100644 --- a/langchain/callbacks/tracers/stdout.py +++ b/langchain/callbacks/tracers/stdout.py @@ -1,5 +1,5 @@ import json -from typing import Any, List +from typing import Any, Callable, List from langchain.callbacks.tracers.base import BaseTracer from langchain.callbacks.tracers.schemas import Run @@ -41,10 +41,14 @@ def elapsed(run: Any) -> str: return f"{(milliseconds / 1000):.2f}s" -class ConsoleCallbackHandler(BaseTracer): - """Tracer that prints to the console.""" +class FunctionCallbackHandler(BaseTracer): + """Tracer that calls a function with a single str parameter.""" - name = "console_callback_handler" + name = "function_callback_handler" + + def __init__(self, function: Callable[[str], None], **kwargs: Any) -> None: + super().__init__(**kwargs) + self.function_callback = function def _persist_run(self, run: Run) -> None: pass @@ -74,7 +78,7 @@ class ConsoleCallbackHandler(BaseTracer): # logging methods def _on_chain_start(self, run: Run) -> None: crumbs = self.get_breadcrumbs(run) - print( + self.function_callback( f"{get_colored_text('[chain/start]', color='green')} " + get_bolded_text(f"[{crumbs}] Entering Chain run with input:\n") + f"{try_json_stringify(run.inputs, '[inputs]')}" @@ -82,7 +86,7 @@ class ConsoleCallbackHandler(BaseTracer): def _on_chain_end(self, run: Run) -> None: crumbs = self.get_breadcrumbs(run) - print( + self.function_callback( f"{get_colored_text('[chain/end]', color='blue')} " + get_bolded_text( f"[{crumbs}] [{elapsed(run)}] Exiting Chain run with output:\n" @@ -92,7 +96,7 @@ class ConsoleCallbackHandler(BaseTracer): def _on_chain_error(self, run: Run) -> None: crumbs = self.get_breadcrumbs(run) - print( + self.function_callback( f"{get_colored_text('[chain/error]', color='red')} " + get_bolded_text( f"[{crumbs}] [{elapsed(run)}] Chain run errored with error:\n" @@ -107,7 +111,7 @@ class ConsoleCallbackHandler(BaseTracer): if "prompts" in run.inputs else run.inputs ) - print( + self.function_callback( f"{get_colored_text('[llm/start]', color='green')} " + get_bolded_text(f"[{crumbs}] Entering LLM run with input:\n") + f"{try_json_stringify(inputs, '[inputs]')}" @@ -115,7 +119,7 @@ class ConsoleCallbackHandler(BaseTracer): def _on_llm_end(self, run: Run) -> None: crumbs = self.get_breadcrumbs(run) - print( + self.function_callback( f"{get_colored_text('[llm/end]', color='blue')} " + get_bolded_text( f"[{crumbs}] [{elapsed(run)}] Exiting LLM run with output:\n" @@ -125,7 +129,7 @@ class ConsoleCallbackHandler(BaseTracer): def _on_llm_error(self, run: Run) -> None: crumbs = self.get_breadcrumbs(run) - print( + self.function_callback( f"{get_colored_text('[llm/error]', color='red')} " + get_bolded_text( f"[{crumbs}] [{elapsed(run)}] LLM run errored with error:\n" @@ -135,7 +139,7 @@ class ConsoleCallbackHandler(BaseTracer): def _on_tool_start(self, run: Run) -> None: crumbs = self.get_breadcrumbs(run) - print( + self.function_callback( f'{get_colored_text("[tool/start]", color="green")} ' + get_bolded_text(f"[{crumbs}] Entering Tool run with input:\n") + f'"{run.inputs["input"].strip()}"' @@ -144,7 +148,7 @@ class ConsoleCallbackHandler(BaseTracer): def _on_tool_end(self, run: Run) -> None: crumbs = self.get_breadcrumbs(run) if run.outputs: - print( + self.function_callback( f'{get_colored_text("[tool/end]", color="blue")} ' + get_bolded_text( f"[{crumbs}] [{elapsed(run)}] Exiting Tool run with output:\n" @@ -154,9 +158,18 @@ class ConsoleCallbackHandler(BaseTracer): def _on_tool_error(self, run: Run) -> None: crumbs = self.get_breadcrumbs(run) - print( + self.function_callback( f"{get_colored_text('[tool/error]', color='red')} " + get_bolded_text(f"[{crumbs}] [{elapsed(run)}] ") + f"Tool run errored with error:\n" f"{run.error}" ) + + +class ConsoleCallbackHandler(FunctionCallbackHandler): + """Tracer that prints to the console.""" + + name = "console_callback_handler" + + def __init__(self, **kwargs: Any) -> None: + super().__init__(function=print, **kwargs)