From f90249305ae38468c157ffac17fc05c09a05045a Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Thu, 16 Nov 2023 08:25:09 -0800 Subject: [PATCH] callback refactor (#13372) Co-authored-by: Nuno Campos --- libs/langchain/langchain/callbacks/base.py | 615 +---- libs/langchain/langchain/callbacks/manager.py | 2086 +---------------- libs/langchain/langchain/callbacks/stdout.py | 98 +- .../langchain/callbacks/tracers/__init__.py | 8 +- .../langchain/callbacks/tracers/base.py | 536 +---- .../langchain/callbacks/tracers/evaluation.py | 225 +- .../langchain/callbacks/tracers/langchain.py | 262 +-- .../callbacks/tracers/langchain_v1.py | 186 +- .../langchain/callbacks/tracers/log_stream.py | 316 +-- .../callbacks/tracers/root_listeners.py | 55 +- .../callbacks/tracers/run_collector.py | 53 +- .../langchain/callbacks/tracers/schemas.py | 139 +- .../langchain/callbacks/tracers/stdout.py | 182 +- .../langchain/schema/callbacks/__init__.py | 0 .../langchain/schema/callbacks/base.py | 598 +++++ .../langchain/schema/callbacks/manager.py | 2075 ++++++++++++++++ .../langchain/schema/callbacks/stdout.py | 97 + .../schema/callbacks/tracers/__init__.py | 0 .../schema/callbacks/tracers/base.py | 537 +++++ .../schema/callbacks/tracers/evaluation.py | 222 ++ .../schema/callbacks/tracers/langchain.py | 262 +++ .../schema/callbacks/tracers/langchain_v1.py | 185 ++ .../schema/callbacks/tracers/log_stream.py | 311 +++ .../callbacks/tracers/root_listeners.py | 54 + .../schema/callbacks/tracers/run_collector.py | 52 + .../schema/callbacks/tracers/schemas.py | 140 ++ .../schema/callbacks/tracers/stdout.py | 178 ++ .../langchain/schema/runnable/base.py | 6 +- libs/langchain/scripts/check_imports.sh | 2 +- .../callbacks/tracers/test_langchain_v1.py | 6 +- .../tests/unit_tests/test_globals.py | 4 +- 31 files changed, 4848 insertions(+), 4642 deletions(-) create mode 100644 libs/langchain/langchain/schema/callbacks/__init__.py create mode 100644 libs/langchain/langchain/schema/callbacks/base.py create mode 100644 libs/langchain/langchain/schema/callbacks/manager.py create mode 100644 libs/langchain/langchain/schema/callbacks/stdout.py create mode 100644 libs/langchain/langchain/schema/callbacks/tracers/__init__.py create mode 100644 libs/langchain/langchain/schema/callbacks/tracers/base.py create mode 100644 libs/langchain/langchain/schema/callbacks/tracers/evaluation.py create mode 100644 libs/langchain/langchain/schema/callbacks/tracers/langchain.py create mode 100644 libs/langchain/langchain/schema/callbacks/tracers/langchain_v1.py create mode 100644 libs/langchain/langchain/schema/callbacks/tracers/log_stream.py create mode 100644 libs/langchain/langchain/schema/callbacks/tracers/root_listeners.py create mode 100644 libs/langchain/langchain/schema/callbacks/tracers/run_collector.py create mode 100644 libs/langchain/langchain/schema/callbacks/tracers/schemas.py create mode 100644 libs/langchain/langchain/schema/callbacks/tracers/stdout.py diff --git a/libs/langchain/langchain/callbacks/base.py b/libs/langchain/langchain/callbacks/base.py index 12c7a29f13b..151235af2ef 100644 --- a/libs/langchain/langchain/callbacks/base.py +++ b/libs/langchain/langchain/callbacks/base.py @@ -1,599 +1,28 @@ """Base callback handler that can be used to handle callbacks in langchain.""" from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, TypeVar, Union -from uuid import UUID - -from tenacity import RetryCallState - -if TYPE_CHECKING: - from langchain.schema.agent import AgentAction, AgentFinish - from langchain.schema.document import Document - from langchain.schema.messages import BaseMessage - from langchain.schema.output import ChatGenerationChunk, GenerationChunk, LLMResult - - -class RetrieverManagerMixin: - """Mixin for Retriever callbacks.""" - - def on_retriever_error( - self, - error: BaseException, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - **kwargs: Any, - ) -> Any: - """Run when Retriever errors.""" - - def on_retriever_end( - self, - documents: Sequence[Document], - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - **kwargs: Any, - ) -> Any: - """Run when Retriever ends running.""" - - -class LLMManagerMixin: - """Mixin for LLM callbacks.""" - - def on_llm_new_token( - self, - token: str, - *, - chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - **kwargs: Any, - ) -> Any: - """Run on new LLM token. Only available when streaming is enabled. - - Args: - token (str): The new token. - chunk (GenerationChunk | ChatGenerationChunk): The new generated chunk, - containing content and other information. - """ - - def on_llm_end( - self, - response: LLMResult, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - **kwargs: Any, - ) -> Any: - """Run when LLM ends running.""" - - def on_llm_error( - self, - error: BaseException, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - **kwargs: Any, - ) -> Any: - """Run when LLM errors.""" - - -class ChainManagerMixin: - """Mixin for chain callbacks.""" - - def on_chain_end( - self, - outputs: Dict[str, Any], - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - **kwargs: Any, - ) -> Any: - """Run when chain ends running.""" - - def on_chain_error( - self, - error: BaseException, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - **kwargs: Any, - ) -> Any: - """Run when chain errors.""" - - def on_agent_action( - self, - action: AgentAction, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - **kwargs: Any, - ) -> Any: - """Run on agent action.""" - - def on_agent_finish( - self, - finish: AgentFinish, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - **kwargs: Any, - ) -> Any: - """Run on agent end.""" - - -class ToolManagerMixin: - """Mixin for tool callbacks.""" - - def on_tool_end( - self, - output: str, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - **kwargs: Any, - ) -> Any: - """Run when tool ends running.""" - - def on_tool_error( - self, - error: BaseException, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - **kwargs: Any, - ) -> Any: - """Run when tool errors.""" - - -class CallbackManagerMixin: - """Mixin for callback manager.""" - - def on_llm_start( - self, - serialized: Dict[str, Any], - prompts: List[str], - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> Any: - """Run when LLM starts running.""" - - def on_chat_model_start( - self, - serialized: Dict[str, Any], - messages: List[List[BaseMessage]], - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> Any: - """Run when a chat model starts running.""" - raise NotImplementedError( - f"{self.__class__.__name__} does not implement `on_chat_model_start`" - ) - - def on_retriever_start( - self, - serialized: Dict[str, Any], - query: str, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> Any: - """Run when Retriever starts running.""" - - def on_chain_start( - self, - serialized: Dict[str, Any], - inputs: Dict[str, Any], - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> Any: - """Run when chain starts running.""" - - def on_tool_start( - self, - serialized: Dict[str, Any], - input_str: str, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> Any: - """Run when tool starts running.""" - - -class RunManagerMixin: - """Mixin for run manager.""" - - def on_text( - self, - text: str, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - **kwargs: Any, - ) -> Any: - """Run on arbitrary text.""" - - def on_retry( - self, - retry_state: RetryCallState, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - **kwargs: Any, - ) -> Any: - """Run on a retry event.""" - - -class BaseCallbackHandler( - LLMManagerMixin, - ChainManagerMixin, - ToolManagerMixin, - RetrieverManagerMixin, +from langchain.schema.callbacks.base import ( + AsyncCallbackHandler, + BaseCallbackHandler, + BaseCallbackManager, CallbackManagerMixin, + Callbacks, + ChainManagerMixin, + LLMManagerMixin, + RetrieverManagerMixin, RunManagerMixin, -): - """Base callback handler that handles callbacks from LangChain.""" + ToolManagerMixin, +) - raise_error: bool = False - - run_inline: bool = False - - @property - def ignore_llm(self) -> bool: - """Whether to ignore LLM callbacks.""" - return False - - @property - def ignore_retry(self) -> bool: - """Whether to ignore retry callbacks.""" - return False - - @property - def ignore_chain(self) -> bool: - """Whether to ignore chain callbacks.""" - return False - - @property - def ignore_agent(self) -> bool: - """Whether to ignore agent callbacks.""" - return False - - @property - def ignore_retriever(self) -> bool: - """Whether to ignore retriever callbacks.""" - return False - - @property - def ignore_chat_model(self) -> bool: - """Whether to ignore chat model callbacks.""" - return False - - -class AsyncCallbackHandler(BaseCallbackHandler): - """Async callback handler that handles callbacks from LangChain.""" - - async def on_llm_start( - self, - serialized: Dict[str, Any], - prompts: List[str], - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> None: - """Run when LLM starts running.""" - - async def on_chat_model_start( - self, - serialized: Dict[str, Any], - messages: List[List[BaseMessage]], - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> Any: - """Run when a chat model starts running.""" - raise NotImplementedError( - f"{self.__class__.__name__} does not implement `on_chat_model_start`" - ) - - async def on_llm_new_token( - self, - token: str, - *, - chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - **kwargs: Any, - ) -> None: - """Run on new LLM token. Only available when streaming is enabled.""" - - async def on_llm_end( - self, - response: LLMResult, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - **kwargs: Any, - ) -> None: - """Run when LLM ends running.""" - - async def on_llm_error( - self, - error: BaseException, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - **kwargs: Any, - ) -> None: - """Run when LLM errors.""" - - async def on_chain_start( - self, - serialized: Dict[str, Any], - inputs: Dict[str, Any], - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> None: - """Run when chain starts running.""" - - async def on_chain_end( - self, - outputs: Dict[str, Any], - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - **kwargs: Any, - ) -> None: - """Run when chain ends running.""" - - async def on_chain_error( - self, - error: BaseException, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - **kwargs: Any, - ) -> None: - """Run when chain errors.""" - - async def on_tool_start( - self, - serialized: Dict[str, Any], - input_str: str, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> None: - """Run when tool starts running.""" - - async def on_tool_end( - self, - output: str, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - **kwargs: Any, - ) -> None: - """Run when tool ends running.""" - - async def on_tool_error( - self, - error: BaseException, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - **kwargs: Any, - ) -> None: - """Run when tool errors.""" - - async def on_text( - self, - text: str, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - **kwargs: Any, - ) -> None: - """Run on arbitrary text.""" - - async def on_retry( - self, - retry_state: RetryCallState, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - **kwargs: Any, - ) -> Any: - """Run on a retry event.""" - - async def on_agent_action( - self, - action: AgentAction, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - **kwargs: Any, - ) -> None: - """Run on agent action.""" - - async def on_agent_finish( - self, - finish: AgentFinish, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - **kwargs: Any, - ) -> None: - """Run on agent end.""" - - async def on_retriever_start( - self, - serialized: Dict[str, Any], - query: str, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> None: - """Run on retriever start.""" - - async def on_retriever_end( - self, - documents: Sequence[Document], - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - **kwargs: Any, - ) -> None: - """Run on retriever end.""" - - async def on_retriever_error( - self, - error: BaseException, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - **kwargs: Any, - ) -> None: - """Run on retriever error.""" - - -T = TypeVar("T", bound="BaseCallbackManager") - - -class BaseCallbackManager(CallbackManagerMixin): - """Base callback manager that handles callbacks from LangChain.""" - - def __init__( - self, - handlers: List[BaseCallbackHandler], - inheritable_handlers: Optional[List[BaseCallbackHandler]] = None, - parent_run_id: Optional[UUID] = None, - *, - tags: Optional[List[str]] = None, - inheritable_tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, - inheritable_metadata: Optional[Dict[str, Any]] = None, - ) -> None: - """Initialize callback manager.""" - self.handlers: List[BaseCallbackHandler] = handlers - self.inheritable_handlers: List[BaseCallbackHandler] = ( - inheritable_handlers or [] - ) - self.parent_run_id: Optional[UUID] = parent_run_id - self.tags = tags or [] - self.inheritable_tags = inheritable_tags or [] - self.metadata = metadata or {} - self.inheritable_metadata = inheritable_metadata or {} - - def copy(self: T) -> T: - """Copy the callback manager.""" - return self.__class__( - handlers=self.handlers, - inheritable_handlers=self.inheritable_handlers, - parent_run_id=self.parent_run_id, - tags=self.tags, - inheritable_tags=self.inheritable_tags, - metadata=self.metadata, - inheritable_metadata=self.inheritable_metadata, - ) - - @property - def is_async(self) -> bool: - """Whether the callback manager is async.""" - return False - - def add_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None: - """Add a handler to the callback manager.""" - if handler not in self.handlers: - self.handlers.append(handler) - if inherit and handler not in self.inheritable_handlers: - self.inheritable_handlers.append(handler) - - def remove_handler(self, handler: BaseCallbackHandler) -> None: - """Remove a handler from the callback manager.""" - self.handlers.remove(handler) - self.inheritable_handlers.remove(handler) - - def set_handlers( - self, handlers: List[BaseCallbackHandler], inherit: bool = True - ) -> None: - """Set handlers as the only handlers on the callback manager.""" - self.handlers = [] - self.inheritable_handlers = [] - for handler in handlers: - self.add_handler(handler, inherit=inherit) - - def set_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None: - """Set handler as the only handler on the callback manager.""" - self.set_handlers([handler], inherit=inherit) - - def add_tags(self, tags: List[str], inherit: bool = True) -> None: - for tag in tags: - if tag in self.tags: - self.remove_tags([tag]) - self.tags.extend(tags) - if inherit: - self.inheritable_tags.extend(tags) - - def remove_tags(self, tags: List[str]) -> None: - for tag in tags: - self.tags.remove(tag) - self.inheritable_tags.remove(tag) - - def add_metadata(self, metadata: Dict[str, Any], inherit: bool = True) -> None: - self.metadata.update(metadata) - if inherit: - self.inheritable_metadata.update(metadata) - - def remove_metadata(self, keys: List[str]) -> None: - for key in keys: - self.metadata.pop(key) - self.inheritable_metadata.pop(key) - - -Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]] +__all__ = [ + "RetrieverManagerMixin", + "LLMManagerMixin", + "ChainManagerMixin", + "ToolManagerMixin", + "CallbackManagerMixin", + "RunManagerMixin", + "BaseCallbackHandler", + "AsyncCallbackHandler", + "BaseCallbackManager", + "Callbacks", +] diff --git a/libs/langchain/langchain/callbacks/manager.py b/libs/langchain/langchain/callbacks/manager.py index 5abcb55dfad..cfaab32cfaf 100644 --- a/libs/langchain/langchain/callbacks/manager.py +++ b/libs/langchain/langchain/callbacks/manager.py @@ -1,91 +1,57 @@ from __future__ import annotations -import asyncio -import functools import logging -import os -import uuid -from concurrent.futures import ThreadPoolExecutor -from contextlib import asynccontextmanager, contextmanager +from contextlib import contextmanager from contextvars import ContextVar from typing import ( - TYPE_CHECKING, - Any, - AsyncGenerator, - Coroutine, - Dict, Generator, - List, Optional, - Sequence, - Type, - TypeVar, - Union, - cast, ) -from uuid import UUID -from langsmith import utils as ls_utils -from langsmith.run_helpers import get_run_tree_context -from tenacity import RetryCallState - -from langchain.callbacks.base import ( - BaseCallbackHandler, - BaseCallbackManager, - Callbacks, - ChainManagerMixin, - LLMManagerMixin, - RetrieverManagerMixin, - RunManagerMixin, - ToolManagerMixin, -) from langchain.callbacks.openai_info import OpenAICallbackHandler -from langchain.callbacks.stdout import StdOutCallbackHandler -from langchain.callbacks.tracers import run_collector -from langchain.callbacks.tracers.langchain import ( - LangChainTracer, -) -from langchain.callbacks.tracers.langchain_v1 import LangChainTracerV1, TracerSessionV1 -from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler from langchain.callbacks.tracers.wandb import WandbTracer -from langchain.schema import ( - AgentAction, - AgentFinish, - Document, - LLMResult, +from langchain.schema.callbacks.manager import ( + AsyncCallbackManager, + AsyncCallbackManagerForChainGroup, + AsyncCallbackManagerForChainRun, + AsyncCallbackManagerForLLMRun, + AsyncCallbackManagerForRetrieverRun, + AsyncCallbackManagerForToolRun, + AsyncParentRunManager, + AsyncRunManager, + BaseRunManager, + CallbackManager, + CallbackManagerForChainGroup, + CallbackManagerForChainRun, + CallbackManagerForLLMRun, + CallbackManagerForRetrieverRun, + CallbackManagerForToolRun, + Callbacks, + ParentRunManager, + RunManager, + atrace_as_chain_group, + collect_runs, + env_var_is_set, + handle_event, + register_configure_hook, + trace_as_chain_group, + tracing_enabled, + tracing_v2_enabled, ) -from langchain.schema.messages import BaseMessage, get_buffer_string -from langchain.schema.output import ChatGenerationChunk, GenerationChunk - -if TYPE_CHECKING: - from langsmith import Client as LangSmithClient logger = logging.getLogger(__name__) openai_callback_var: ContextVar[Optional[OpenAICallbackHandler]] = ContextVar( "openai_callback", default=None ) -tracing_callback_var: ContextVar[Optional[LangChainTracerV1]] = ContextVar( # noqa: E501 - "tracing_callback", default=None -) wandb_tracing_callback_var: ContextVar[Optional[WandbTracer]] = ContextVar( # noqa: E501 "tracing_wandb_callback", default=None ) -tracing_v2_callback_var: ContextVar[Optional[LangChainTracer]] = ContextVar( # noqa: E501 - "tracing_callback_v2", default=None +register_configure_hook(openai_callback_var, True) +register_configure_hook( + wandb_tracing_callback_var, True, WandbTracer, "LANGCHAIN_WANDB_TRACING" ) -run_collector_var: ContextVar[ - Optional[run_collector.RunCollectorCallbackHandler] -] = ContextVar( # noqa: E501 - "run_collector", default=None -) - - -def _get_debug() -> bool: - from langchain.globals import get_debug - - return get_debug() @contextmanager @@ -106,32 +72,6 @@ def get_openai_callback() -> Generator[OpenAICallbackHandler, None, None]: openai_callback_var.set(None) -@contextmanager -def tracing_enabled( - session_name: str = "default", -) -> Generator[TracerSessionV1, None, None]: - """Get the Deprecated LangChainTracer in a context manager. - - Args: - session_name (str, optional): The name of the session. - Defaults to "default". - - Returns: - TracerSessionV1: The LangChainTracer session. - - Example: - >>> with tracing_enabled() as session: - ... # Use the LangChainTracer session - """ - cb = LangChainTracerV1() - session = cast(TracerSessionV1, cb.load_session(session_name)) - try: - tracing_callback_var.set(cb) - yield session - finally: - tracing_callback_var.set(None) - - @contextmanager def wandb_tracing_enabled( session_name: str = "default", @@ -155,1940 +95,30 @@ def wandb_tracing_enabled( wandb_tracing_callback_var.set(None) -@contextmanager -def tracing_v2_enabled( - project_name: Optional[str] = None, - *, - example_id: Optional[Union[str, UUID]] = None, - tags: Optional[List[str]] = None, - client: Optional[LangSmithClient] = None, -) -> Generator[LangChainTracer, None, None]: - """Instruct LangChain to log all runs in context to LangSmith. - - Args: - project_name (str, optional): The name of the project. - Defaults to "default". - example_id (str or UUID, optional): The ID of the example. - Defaults to None. - tags (List[str], optional): The tags to add to the run. - Defaults to None. - - Returns: - None - - Example: - >>> with tracing_v2_enabled(): - ... # LangChain code will automatically be traced - - You can use this to fetch the LangSmith run URL: - - >>> with tracing_v2_enabled() as cb: - ... chain.invoke("foo") - ... run_url = cb.get_run_url() - """ - if isinstance(example_id, str): - example_id = UUID(example_id) - cb = LangChainTracer( - example_id=example_id, - project_name=project_name, - tags=tags, - client=client, - ) - try: - tracing_v2_callback_var.set(cb) - yield cb - finally: - tracing_v2_callback_var.set(None) - - -@contextmanager -def collect_runs() -> Generator[run_collector.RunCollectorCallbackHandler, None, None]: - """Collect all run traces in context. - - Returns: - run_collector.RunCollectorCallbackHandler: The run collector callback handler. - - Example: - >>> with collect_runs() as runs_cb: - chain.invoke("foo") - run_id = runs_cb.traced_runs[0].id - """ - cb = run_collector.RunCollectorCallbackHandler() - run_collector_var.set(cb) - yield cb - run_collector_var.set(None) - - -def _get_trace_callbacks( - project_name: Optional[str] = None, - example_id: Optional[Union[str, UUID]] = None, - callback_manager: Optional[Union[CallbackManager, AsyncCallbackManager]] = None, -) -> Callbacks: - if _tracing_v2_is_enabled(): - project_name_ = project_name or _get_tracer_project() - tracer = tracing_v2_callback_var.get() or LangChainTracer( - project_name=project_name_, - example_id=example_id, - ) - if callback_manager is None: - cb = cast(Callbacks, [tracer]) - else: - if not any( - isinstance(handler, LangChainTracer) - for handler in callback_manager.handlers - ): - callback_manager.add_handler(tracer, True) - # If it already has a LangChainTracer, we don't need to add another one. - # this would likely mess up the trace hierarchy. - cb = callback_manager - else: - cb = None - return cb - - -@contextmanager -def trace_as_chain_group( - group_name: str, - callback_manager: Optional[CallbackManager] = None, - *, - inputs: Optional[Dict[str, Any]] = None, - project_name: Optional[str] = None, - example_id: Optional[Union[str, UUID]] = None, - run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, -) -> Generator[CallbackManagerForChainGroup, None, None]: - """Get a callback manager for a chain group in a context manager. - Useful for grouping different calls together as a single run even if - they aren't composed in a single chain. - - Args: - group_name (str): The name of the chain group. - callback_manager (CallbackManager, optional): The callback manager to use. - inputs (Dict[str, Any], optional): The inputs to the chain group. - project_name (str, optional): The name of the project. - Defaults to None. - example_id (str or UUID, optional): The ID of the example. - Defaults to None. - run_id (UUID, optional): The ID of the run. - tags (List[str], optional): The inheritable tags to apply to all runs. - Defaults to None. - - Note: must have LANGCHAIN_TRACING_V2 env var set to true to see the trace in LangSmith. - - Returns: - CallbackManagerForChainGroup: The callback manager for the chain group. - - Example: - .. code-block:: python - - llm_input = "Foo" - with trace_as_chain_group("group_name", inputs={"input": llm_input}) as manager: - # Use the callback manager for the chain group - res = llm.predict(llm_input, callbacks=manager) - manager.on_chain_end({"output": res}) - """ # noqa: E501 - cb = _get_trace_callbacks( - project_name, example_id, callback_manager=callback_manager - ) - cm = CallbackManager.configure( - inheritable_callbacks=cb, - inheritable_tags=tags, - ) - - run_manager = cm.on_chain_start({"name": group_name}, inputs or {}, run_id=run_id) - child_cm = run_manager.get_child() - group_cm = CallbackManagerForChainGroup( - child_cm.handlers, - child_cm.inheritable_handlers, - child_cm.parent_run_id, - parent_run_manager=run_manager, - tags=child_cm.tags, - inheritable_tags=child_cm.inheritable_tags, - metadata=child_cm.metadata, - inheritable_metadata=child_cm.inheritable_metadata, - ) - try: - yield group_cm - except Exception as e: - if not group_cm.ended: - run_manager.on_chain_error(e) - raise e - else: - if not group_cm.ended: - run_manager.on_chain_end({}) - - -@asynccontextmanager -async def atrace_as_chain_group( - group_name: str, - callback_manager: Optional[AsyncCallbackManager] = None, - *, - inputs: Optional[Dict[str, Any]] = None, - project_name: Optional[str] = None, - example_id: Optional[Union[str, UUID]] = None, - run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, -) -> AsyncGenerator[AsyncCallbackManagerForChainGroup, None]: - """Get an async callback manager for a chain group in a context manager. - Useful for grouping different async calls together as a single run even if - they aren't composed in a single chain. - - Args: - group_name (str): The name of the chain group. - callback_manager (AsyncCallbackManager, optional): The async callback manager to use, - which manages tracing and other callback behavior. - project_name (str, optional): The name of the project. - Defaults to None. - example_id (str or UUID, optional): The ID of the example. - Defaults to None. - run_id (UUID, optional): The ID of the run. - tags (List[str], optional): The inheritable tags to apply to all runs. - Defaults to None. - Returns: - AsyncCallbackManager: The async callback manager for the chain group. - - Note: must have LANGCHAIN_TRACING_V2 env var set to true to see the trace in LangSmith. - - Example: - .. code-block:: python - - llm_input = "Foo" - async with atrace_as_chain_group("group_name", inputs={"input": llm_input}) as manager: - # Use the async callback manager for the chain group - res = await llm.apredict(llm_input, callbacks=manager) - await manager.on_chain_end({"output": res}) - """ # noqa: E501 - cb = _get_trace_callbacks( - project_name, example_id, callback_manager=callback_manager - ) - cm = AsyncCallbackManager.configure(inheritable_callbacks=cb, inheritable_tags=tags) - - run_manager = await cm.on_chain_start( - {"name": group_name}, inputs or {}, run_id=run_id - ) - child_cm = run_manager.get_child() - group_cm = AsyncCallbackManagerForChainGroup( - child_cm.handlers, - child_cm.inheritable_handlers, - child_cm.parent_run_id, - parent_run_manager=run_manager, - tags=child_cm.tags, - inheritable_tags=child_cm.inheritable_tags, - metadata=child_cm.metadata, - inheritable_metadata=child_cm.inheritable_metadata, - ) - try: - yield group_cm - except Exception as e: - if not group_cm.ended: - await run_manager.on_chain_error(e) - raise e - else: - if not group_cm.ended: - await run_manager.on_chain_end({}) - - -def handle_event( - handlers: List[BaseCallbackHandler], - event_name: str, - ignore_condition_name: Optional[str], - *args: Any, - **kwargs: Any, -) -> None: - """Generic event handler for CallbackManager. - - Note: This function is used by langserve to handle events. - - Args: - handlers: The list of handlers that will handle the event - event_name: The name of the event (e.g., "on_llm_start") - ignore_condition_name: Name of the attribute defined on handler - that if True will cause the handler to be skipped for the given event - *args: The arguments to pass to the event handler - **kwargs: The keyword arguments to pass to the event handler - """ - coros: List[Coroutine[Any, Any, Any]] = [] - - try: - message_strings: Optional[List[str]] = None - for handler in handlers: - try: - if ignore_condition_name is None or not getattr( - handler, ignore_condition_name - ): - event = getattr(handler, event_name)(*args, **kwargs) - if asyncio.iscoroutine(event): - coros.append(event) - except NotImplementedError as e: - if event_name == "on_chat_model_start": - if message_strings is None: - message_strings = [get_buffer_string(m) for m in args[1]] - handle_event( - [handler], - "on_llm_start", - "ignore_llm", - args[0], - message_strings, - *args[2:], - **kwargs, - ) - else: - handler_name = handler.__class__.__name__ - logger.warning( - f"NotImplementedError in {handler_name}.{event_name}" - f" callback: {repr(e)}" - ) - except Exception as e: - logger.warning( - f"Error in {handler.__class__.__name__}.{event_name} callback:" - f" {repr(e)}" - ) - if handler.raise_error: - raise e - finally: - if coros: - try: - # Raises RuntimeError if there is no current event loop. - asyncio.get_running_loop() - loop_running = True - except RuntimeError: - loop_running = False - - if loop_running: - # If we try to submit this coroutine to the running loop - # we end up in a deadlock, as we'd have gotten here from a - # running coroutine, which we cannot interrupt to run this one. - # The solution is to create a new loop in a new thread. - with ThreadPoolExecutor(1) as executor: - executor.submit(_run_coros, coros).result() - else: - _run_coros(coros) - - -def _run_coros(coros: List[Coroutine[Any, Any, Any]]) -> None: - if hasattr(asyncio, "Runner"): - # Python 3.11+ - # Run the coroutines in a new event loop, taking care to - # - install signal handlers - # - run pending tasks scheduled by `coros` - # - close asyncgens and executors - # - close the loop - with asyncio.Runner() as runner: - # Run the coroutine, get the result - for coro in coros: - runner.run(coro) - - # Run pending tasks scheduled by coros until they are all done - while pending := asyncio.all_tasks(runner.get_loop()): - runner.run(asyncio.wait(pending)) - else: - # Before Python 3.11 we need to run each coroutine in a new event loop - # as the Runner api is not available. - for coro in coros: - asyncio.run(coro) - - -async def _ahandle_event_for_handler( - handler: BaseCallbackHandler, - event_name: str, - ignore_condition_name: Optional[str], - *args: Any, - **kwargs: Any, -) -> None: - try: - if ignore_condition_name is None or not getattr(handler, ignore_condition_name): - event = getattr(handler, event_name) - if asyncio.iscoroutinefunction(event): - await event(*args, **kwargs) - else: - if handler.run_inline: - event(*args, **kwargs) - else: - await asyncio.get_event_loop().run_in_executor( - None, functools.partial(event, *args, **kwargs) - ) - except NotImplementedError as e: - if event_name == "on_chat_model_start": - message_strings = [get_buffer_string(m) for m in args[1]] - await _ahandle_event_for_handler( - handler, - "on_llm_start", - "ignore_llm", - args[0], - message_strings, - *args[2:], - **kwargs, - ) - else: - logger.warning( - f"NotImplementedError in {handler.__class__.__name__}.{event_name}" - f" callback: {repr(e)}" - ) - except Exception as e: - logger.warning( - f"Error in {handler.__class__.__name__}.{event_name} callback:" - f" {repr(e)}" - ) - if handler.raise_error: - raise e - - -async def ahandle_event( - handlers: List[BaseCallbackHandler], - event_name: str, - ignore_condition_name: Optional[str], - *args: Any, - **kwargs: Any, -) -> None: - """Generic event handler for AsyncCallbackManager. - - Note: This function is used by langserve to handle events. - - Args: - handlers: The list of handlers that will handle the event - event_name: The name of the event (e.g., "on_llm_start") - ignore_condition_name: Name of the attribute defined on handler - that if True will cause the handler to be skipped for the given event - *args: The arguments to pass to the event handler - **kwargs: The keyword arguments to pass to the event handler - """ - for handler in [h for h in handlers if h.run_inline]: - await _ahandle_event_for_handler( - handler, event_name, ignore_condition_name, *args, **kwargs - ) - await asyncio.gather( - *( - _ahandle_event_for_handler( - handler, event_name, ignore_condition_name, *args, **kwargs - ) - for handler in handlers - if not handler.run_inline - ) - ) - - -BRM = TypeVar("BRM", bound="BaseRunManager") - - -class BaseRunManager(RunManagerMixin): - """Base class for run manager (a bound callback manager).""" - - def __init__( - self, - *, - run_id: UUID, - handlers: List[BaseCallbackHandler], - inheritable_handlers: List[BaseCallbackHandler], - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - inheritable_tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, - inheritable_metadata: Optional[Dict[str, Any]] = None, - ) -> None: - """Initialize the run manager. - - Args: - run_id (UUID): The ID of the run. - handlers (List[BaseCallbackHandler]): The list of handlers. - inheritable_handlers (List[BaseCallbackHandler]): - The list of inheritable handlers. - parent_run_id (UUID, optional): The ID of the parent run. - Defaults to None. - tags (Optional[List[str]]): The list of tags. - inheritable_tags (Optional[List[str]]): The list of inheritable tags. - metadata (Optional[Dict[str, Any]]): The metadata. - inheritable_metadata (Optional[Dict[str, Any]]): The inheritable metadata. - """ - self.run_id = run_id - self.handlers = handlers - self.inheritable_handlers = inheritable_handlers - self.parent_run_id = parent_run_id - self.tags = tags or [] - self.inheritable_tags = inheritable_tags or [] - self.metadata = metadata or {} - self.inheritable_metadata = inheritable_metadata or {} - - @classmethod - def get_noop_manager(cls: Type[BRM]) -> BRM: - """Return a manager that doesn't perform any operations. - - Returns: - BaseRunManager: The noop manager. - """ - return cls( - run_id=uuid.uuid4(), - handlers=[], - inheritable_handlers=[], - tags=[], - inheritable_tags=[], - metadata={}, - inheritable_metadata={}, - ) - - -class RunManager(BaseRunManager): - """Sync Run Manager.""" - - def on_text( - self, - text: str, - **kwargs: Any, - ) -> Any: - """Run when text is received. - - Args: - text (str): The received text. - - Returns: - Any: The result of the callback. - """ - handle_event( - self.handlers, - "on_text", - None, - text, - run_id=self.run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) - - def on_retry( - self, - retry_state: RetryCallState, - **kwargs: Any, - ) -> None: - handle_event( - self.handlers, - "on_retry", - "ignore_retry", - retry_state, - run_id=self.run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) - - -class ParentRunManager(RunManager): - """Sync Parent Run Manager.""" - - def get_child(self, tag: Optional[str] = None) -> CallbackManager: - """Get a child callback manager. - - Args: - tag (str, optional): The tag for the child callback manager. - Defaults to None. - - Returns: - CallbackManager: The child callback manager. - """ - manager = CallbackManager(handlers=[], parent_run_id=self.run_id) - manager.set_handlers(self.inheritable_handlers) - manager.add_tags(self.inheritable_tags) - manager.add_metadata(self.inheritable_metadata) - if tag is not None: - manager.add_tags([tag], False) - return manager - - -class AsyncRunManager(BaseRunManager): - """Async Run Manager.""" - - async def on_text( - self, - text: str, - **kwargs: Any, - ) -> Any: - """Run when text is received. - - Args: - text (str): The received text. - - Returns: - Any: The result of the callback. - """ - await ahandle_event( - self.handlers, - "on_text", - None, - text, - run_id=self.run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) - - async def on_retry( - self, - retry_state: RetryCallState, - **kwargs: Any, - ) -> None: - await ahandle_event( - self.handlers, - "on_retry", - "ignore_retry", - retry_state, - run_id=self.run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) - - -class AsyncParentRunManager(AsyncRunManager): - """Async Parent Run Manager.""" - - def get_child(self, tag: Optional[str] = None) -> AsyncCallbackManager: - """Get a child callback manager. - - Args: - tag (str, optional): The tag for the child callback manager. - Defaults to None. - - Returns: - AsyncCallbackManager: The child callback manager. - """ - manager = AsyncCallbackManager(handlers=[], parent_run_id=self.run_id) - manager.set_handlers(self.inheritable_handlers) - manager.add_tags(self.inheritable_tags) - manager.add_metadata(self.inheritable_metadata) - if tag is not None: - manager.add_tags([tag], False) - return manager - - -class CallbackManagerForLLMRun(RunManager, LLMManagerMixin): - """Callback manager for LLM run.""" - - def on_llm_new_token( - self, - token: str, - *, - chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, - **kwargs: Any, - ) -> None: - """Run when LLM generates a new token. - - Args: - token (str): The new token. - """ - handle_event( - self.handlers, - "on_llm_new_token", - "ignore_llm", - token=token, - run_id=self.run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - chunk=chunk, - **kwargs, - ) - - def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: - """Run when LLM ends running. - - Args: - response (LLMResult): The LLM result. - """ - handle_event( - self.handlers, - "on_llm_end", - "ignore_llm", - response, - run_id=self.run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) - - def on_llm_error( - self, - error: BaseException, - **kwargs: Any, - ) -> None: - """Run when LLM errors. - - Args: - error (Exception or KeyboardInterrupt): The error. - """ - handle_event( - self.handlers, - "on_llm_error", - "ignore_llm", - error, - run_id=self.run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) - - -class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin): - """Async callback manager for LLM run.""" - - async def on_llm_new_token( - self, - token: str, - *, - chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, - **kwargs: Any, - ) -> None: - """Run when LLM generates a new token. - - Args: - token (str): The new token. - """ - await ahandle_event( - self.handlers, - "on_llm_new_token", - "ignore_llm", - token, - chunk=chunk, - run_id=self.run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) - - async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: - """Run when LLM ends running. - - Args: - response (LLMResult): The LLM result. - """ - await ahandle_event( - self.handlers, - "on_llm_end", - "ignore_llm", - response, - run_id=self.run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) - - async def on_llm_error( - self, - error: BaseException, - **kwargs: Any, - ) -> None: - """Run when LLM errors. - - Args: - error (Exception or KeyboardInterrupt): The error. - """ - await ahandle_event( - self.handlers, - "on_llm_error", - "ignore_llm", - error, - run_id=self.run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) - - -class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin): - """Callback manager for chain run.""" - - def on_chain_end(self, outputs: Union[Dict[str, Any], Any], **kwargs: Any) -> None: - """Run when chain ends running. - - Args: - outputs (Union[Dict[str, Any], Any]): The outputs of the chain. - """ - handle_event( - self.handlers, - "on_chain_end", - "ignore_chain", - outputs, - run_id=self.run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) - - def on_chain_error( - self, - error: BaseException, - **kwargs: Any, - ) -> None: - """Run when chain errors. - - Args: - error (Exception or KeyboardInterrupt): The error. - """ - handle_event( - self.handlers, - "on_chain_error", - "ignore_chain", - error, - run_id=self.run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) - - def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: - """Run when agent action is received. - - Args: - action (AgentAction): The agent action. - - Returns: - Any: The result of the callback. - """ - handle_event( - self.handlers, - "on_agent_action", - "ignore_agent", - action, - run_id=self.run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) - - def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any: - """Run when agent finish is received. - - Args: - finish (AgentFinish): The agent finish. - - Returns: - Any: The result of the callback. - """ - handle_event( - self.handlers, - "on_agent_finish", - "ignore_agent", - finish, - run_id=self.run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) - - -class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin): - """Async callback manager for chain run.""" - - async def on_chain_end( - self, outputs: Union[Dict[str, Any], Any], **kwargs: Any - ) -> None: - """Run when chain ends running. - - Args: - outputs (Union[Dict[str, Any], Any]): The outputs of the chain. - """ - await ahandle_event( - self.handlers, - "on_chain_end", - "ignore_chain", - outputs, - run_id=self.run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) - - async def on_chain_error( - self, - error: BaseException, - **kwargs: Any, - ) -> None: - """Run when chain errors. - - Args: - error (Exception or KeyboardInterrupt): The error. - """ - await ahandle_event( - self.handlers, - "on_chain_error", - "ignore_chain", - error, - run_id=self.run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) - - async def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: - """Run when agent action is received. - - Args: - action (AgentAction): The agent action. - - Returns: - Any: The result of the callback. - """ - await ahandle_event( - self.handlers, - "on_agent_action", - "ignore_agent", - action, - run_id=self.run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) - - async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any: - """Run when agent finish is received. - - Args: - finish (AgentFinish): The agent finish. - - Returns: - Any: The result of the callback. - """ - await ahandle_event( - self.handlers, - "on_agent_finish", - "ignore_agent", - finish, - run_id=self.run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) - - -class CallbackManagerForToolRun(ParentRunManager, ToolManagerMixin): - """Callback manager for tool run.""" - - def on_tool_end( - self, - output: str, - **kwargs: Any, - ) -> None: - """Run when tool ends running. - - Args: - output (str): The output of the tool. - """ - handle_event( - self.handlers, - "on_tool_end", - "ignore_agent", - output, - run_id=self.run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) - - def on_tool_error( - self, - error: BaseException, - **kwargs: Any, - ) -> None: - """Run when tool errors. - - Args: - error (Exception or KeyboardInterrupt): The error. - """ - handle_event( - self.handlers, - "on_tool_error", - "ignore_agent", - error, - run_id=self.run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) - - -class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin): - """Async callback manager for tool run.""" - - async def on_tool_end(self, output: str, **kwargs: Any) -> None: - """Run when tool ends running. - - Args: - output (str): The output of the tool. - """ - await ahandle_event( - self.handlers, - "on_tool_end", - "ignore_agent", - output, - run_id=self.run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) - - async def on_tool_error( - self, - error: BaseException, - **kwargs: Any, - ) -> None: - """Run when tool errors. - - Args: - error (Exception or KeyboardInterrupt): The error. - """ - await ahandle_event( - self.handlers, - "on_tool_error", - "ignore_agent", - error, - run_id=self.run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) - - -class CallbackManagerForRetrieverRun(ParentRunManager, RetrieverManagerMixin): - """Callback manager for retriever run.""" - - def on_retriever_end( - self, - documents: Sequence[Document], - **kwargs: Any, - ) -> None: - """Run when retriever ends running.""" - handle_event( - self.handlers, - "on_retriever_end", - "ignore_retriever", - documents, - run_id=self.run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) - - def on_retriever_error( - self, - error: BaseException, - **kwargs: Any, - ) -> None: - """Run when retriever errors.""" - handle_event( - self.handlers, - "on_retriever_error", - "ignore_retriever", - error, - run_id=self.run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) - - -class AsyncCallbackManagerForRetrieverRun( - AsyncParentRunManager, - RetrieverManagerMixin, -): - """Async callback manager for retriever run.""" - - async def on_retriever_end( - self, documents: Sequence[Document], **kwargs: Any - ) -> None: - """Run when retriever ends running.""" - await ahandle_event( - self.handlers, - "on_retriever_end", - "ignore_retriever", - documents, - run_id=self.run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) - - async def on_retriever_error( - self, - error: BaseException, - **kwargs: Any, - ) -> None: - """Run when retriever errors.""" - await ahandle_event( - self.handlers, - "on_retriever_error", - "ignore_retriever", - error, - run_id=self.run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) - - -class CallbackManager(BaseCallbackManager): - """Callback manager that handles callbacks from LangChain.""" - - def on_llm_start( - self, - serialized: Dict[str, Any], - prompts: List[str], - **kwargs: Any, - ) -> List[CallbackManagerForLLMRun]: - """Run when LLM starts running. - - Args: - serialized (Dict[str, Any]): The serialized LLM. - prompts (List[str]): The list of prompts. - run_id (UUID, optional): The ID of the run. Defaults to None. - - Returns: - List[CallbackManagerForLLMRun]: A callback manager for each - prompt as an LLM run. - """ - managers = [] - for prompt in prompts: - run_id_ = uuid.uuid4() - handle_event( - self.handlers, - "on_llm_start", - "ignore_llm", - serialized, - [prompt], - run_id=run_id_, - parent_run_id=self.parent_run_id, - tags=self.tags, - metadata=self.metadata, - **kwargs, - ) - - managers.append( - CallbackManagerForLLMRun( - run_id=run_id_, - handlers=self.handlers, - inheritable_handlers=self.inheritable_handlers, - parent_run_id=self.parent_run_id, - tags=self.tags, - inheritable_tags=self.inheritable_tags, - metadata=self.metadata, - inheritable_metadata=self.inheritable_metadata, - ) - ) - - return managers - - def on_chat_model_start( - self, - serialized: Dict[str, Any], - messages: List[List[BaseMessage]], - **kwargs: Any, - ) -> List[CallbackManagerForLLMRun]: - """Run when LLM starts running. - - Args: - serialized (Dict[str, Any]): The serialized LLM. - messages (List[List[BaseMessage]]): The list of messages. - run_id (UUID, optional): The ID of the run. Defaults to None. - - Returns: - List[CallbackManagerForLLMRun]: A callback manager for each - list of messages as an LLM run. - """ - - managers = [] - for message_list in messages: - run_id_ = uuid.uuid4() - handle_event( - self.handlers, - "on_chat_model_start", - "ignore_chat_model", - serialized, - [message_list], - run_id=run_id_, - parent_run_id=self.parent_run_id, - tags=self.tags, - metadata=self.metadata, - **kwargs, - ) - - managers.append( - CallbackManagerForLLMRun( - run_id=run_id_, - handlers=self.handlers, - inheritable_handlers=self.inheritable_handlers, - parent_run_id=self.parent_run_id, - tags=self.tags, - inheritable_tags=self.inheritable_tags, - metadata=self.metadata, - inheritable_metadata=self.inheritable_metadata, - ) - ) - - return managers - - def on_chain_start( - self, - serialized: Dict[str, Any], - inputs: Union[Dict[str, Any], Any], - run_id: Optional[UUID] = None, - **kwargs: Any, - ) -> CallbackManagerForChainRun: - """Run when chain starts running. - - Args: - serialized (Dict[str, Any]): The serialized chain. - inputs (Union[Dict[str, Any], Any]): The inputs to the chain. - run_id (UUID, optional): The ID of the run. Defaults to None. - - Returns: - CallbackManagerForChainRun: The callback manager for the chain run. - """ - if run_id is None: - run_id = uuid.uuid4() - handle_event( - self.handlers, - "on_chain_start", - "ignore_chain", - serialized, - inputs, - run_id=run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - metadata=self.metadata, - **kwargs, - ) - - return CallbackManagerForChainRun( - run_id=run_id, - handlers=self.handlers, - inheritable_handlers=self.inheritable_handlers, - parent_run_id=self.parent_run_id, - tags=self.tags, - inheritable_tags=self.inheritable_tags, - metadata=self.metadata, - inheritable_metadata=self.inheritable_metadata, - ) - - def on_tool_start( - self, - serialized: Dict[str, Any], - input_str: str, - run_id: Optional[UUID] = None, - parent_run_id: Optional[UUID] = None, - **kwargs: Any, - ) -> CallbackManagerForToolRun: - """Run when tool starts running. - - Args: - serialized (Dict[str, Any]): The serialized tool. - input_str (str): The input to the tool. - run_id (UUID, optional): The ID of the run. Defaults to None. - parent_run_id (UUID, optional): The ID of the parent run. Defaults to None. - - Returns: - CallbackManagerForToolRun: The callback manager for the tool run. - """ - if run_id is None: - run_id = uuid.uuid4() - - handle_event( - self.handlers, - "on_tool_start", - "ignore_agent", - serialized, - input_str, - run_id=run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - metadata=self.metadata, - **kwargs, - ) - - return CallbackManagerForToolRun( - run_id=run_id, - handlers=self.handlers, - inheritable_handlers=self.inheritable_handlers, - parent_run_id=self.parent_run_id, - tags=self.tags, - inheritable_tags=self.inheritable_tags, - metadata=self.metadata, - inheritable_metadata=self.inheritable_metadata, - ) - - def on_retriever_start( - self, - serialized: Dict[str, Any], - query: str, - run_id: Optional[UUID] = None, - parent_run_id: Optional[UUID] = None, - **kwargs: Any, - ) -> CallbackManagerForRetrieverRun: - """Run when retriever starts running.""" - if run_id is None: - run_id = uuid.uuid4() - - handle_event( - self.handlers, - "on_retriever_start", - "ignore_retriever", - serialized, - query, - run_id=run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - metadata=self.metadata, - **kwargs, - ) - - return CallbackManagerForRetrieverRun( - run_id=run_id, - handlers=self.handlers, - inheritable_handlers=self.inheritable_handlers, - parent_run_id=self.parent_run_id, - tags=self.tags, - inheritable_tags=self.inheritable_tags, - metadata=self.metadata, - inheritable_metadata=self.inheritable_metadata, - ) - - @classmethod - def configure( - cls, - inheritable_callbacks: Callbacks = None, - local_callbacks: Callbacks = None, - verbose: bool = False, - inheritable_tags: Optional[List[str]] = None, - local_tags: Optional[List[str]] = None, - inheritable_metadata: Optional[Dict[str, Any]] = None, - local_metadata: Optional[Dict[str, Any]] = None, - ) -> CallbackManager: - """Configure the callback manager. - - Args: - inheritable_callbacks (Optional[Callbacks], optional): The inheritable - callbacks. Defaults to None. - local_callbacks (Optional[Callbacks], optional): The local callbacks. - Defaults to None. - verbose (bool, optional): Whether to enable verbose mode. Defaults to False. - inheritable_tags (Optional[List[str]], optional): The inheritable tags. - Defaults to None. - local_tags (Optional[List[str]], optional): The local tags. - Defaults to None. - inheritable_metadata (Optional[Dict[str, Any]], optional): The inheritable - metadata. Defaults to None. - local_metadata (Optional[Dict[str, Any]], optional): The local metadata. - Defaults to None. - - Returns: - CallbackManager: The configured callback manager. - """ - return _configure( - cls, - inheritable_callbacks, - local_callbacks, - verbose, - inheritable_tags, - local_tags, - inheritable_metadata, - local_metadata, - ) - - -class CallbackManagerForChainGroup(CallbackManager): - """Callback manager for the chain group.""" - - def __init__( - self, - handlers: List[BaseCallbackHandler], - inheritable_handlers: Optional[List[BaseCallbackHandler]] = None, - parent_run_id: Optional[UUID] = None, - *, - parent_run_manager: CallbackManagerForChainRun, - **kwargs: Any, - ) -> None: - super().__init__( - handlers, - inheritable_handlers, - parent_run_id, - **kwargs, - ) - self.parent_run_manager = parent_run_manager - self.ended = False - - def copy(self) -> CallbackManagerForChainGroup: - return self.__class__( - handlers=self.handlers, - inheritable_handlers=self.inheritable_handlers, - parent_run_id=self.parent_run_id, - tags=self.tags, - inheritable_tags=self.inheritable_tags, - metadata=self.metadata, - inheritable_metadata=self.inheritable_metadata, - parent_run_manager=self.parent_run_manager, - ) - - def on_chain_end(self, outputs: Union[Dict[str, Any], Any], **kwargs: Any) -> None: - """Run when traced chain group ends. - - Args: - outputs (Union[Dict[str, Any], Any]): The outputs of the chain. - """ - self.ended = True - return self.parent_run_manager.on_chain_end(outputs, **kwargs) - - def on_chain_error( - self, - error: BaseException, - **kwargs: Any, - ) -> None: - """Run when chain errors. - - Args: - error (Exception or KeyboardInterrupt): The error. - """ - self.ended = True - return self.parent_run_manager.on_chain_error(error, **kwargs) - - -class AsyncCallbackManager(BaseCallbackManager): - """Async callback manager that handles callbacks from LangChain.""" - - @property - def is_async(self) -> bool: - """Return whether the handler is async.""" - return True - - async def on_llm_start( - self, - serialized: Dict[str, Any], - prompts: List[str], - **kwargs: Any, - ) -> List[AsyncCallbackManagerForLLMRun]: - """Run when LLM starts running. - - Args: - serialized (Dict[str, Any]): The serialized LLM. - prompts (List[str]): The list of prompts. - run_id (UUID, optional): The ID of the run. Defaults to None. - - Returns: - List[AsyncCallbackManagerForLLMRun]: The list of async - callback managers, one for each LLM Run corresponding - to each prompt. - """ - - tasks = [] - managers = [] - - for prompt in prompts: - run_id_ = uuid.uuid4() - - tasks.append( - ahandle_event( - self.handlers, - "on_llm_start", - "ignore_llm", - serialized, - [prompt], - run_id=run_id_, - parent_run_id=self.parent_run_id, - tags=self.tags, - metadata=self.metadata, - **kwargs, - ) - ) - - managers.append( - AsyncCallbackManagerForLLMRun( - run_id=run_id_, - handlers=self.handlers, - inheritable_handlers=self.inheritable_handlers, - parent_run_id=self.parent_run_id, - tags=self.tags, - inheritable_tags=self.inheritable_tags, - metadata=self.metadata, - inheritable_metadata=self.inheritable_metadata, - ) - ) - - await asyncio.gather(*tasks) - - return managers - - async def on_chat_model_start( - self, - serialized: Dict[str, Any], - messages: List[List[BaseMessage]], - **kwargs: Any, - ) -> List[AsyncCallbackManagerForLLMRun]: - """Run when LLM starts running. - - Args: - serialized (Dict[str, Any]): The serialized LLM. - messages (List[List[BaseMessage]]): The list of messages. - run_id (UUID, optional): The ID of the run. Defaults to None. - - Returns: - List[AsyncCallbackManagerForLLMRun]: The list of - async callback managers, one for each LLM Run - corresponding to each inner message list. - """ - tasks = [] - managers = [] - - for message_list in messages: - run_id_ = uuid.uuid4() - - tasks.append( - ahandle_event( - self.handlers, - "on_chat_model_start", - "ignore_chat_model", - serialized, - [message_list], - run_id=run_id_, - parent_run_id=self.parent_run_id, - tags=self.tags, - metadata=self.metadata, - **kwargs, - ) - ) - - managers.append( - AsyncCallbackManagerForLLMRun( - run_id=run_id_, - handlers=self.handlers, - inheritable_handlers=self.inheritable_handlers, - parent_run_id=self.parent_run_id, - tags=self.tags, - inheritable_tags=self.inheritable_tags, - metadata=self.metadata, - inheritable_metadata=self.inheritable_metadata, - ) - ) - - await asyncio.gather(*tasks) - return managers - - async def on_chain_start( - self, - serialized: Dict[str, Any], - inputs: Union[Dict[str, Any], Any], - run_id: Optional[UUID] = None, - **kwargs: Any, - ) -> AsyncCallbackManagerForChainRun: - """Run when chain starts running. - - Args: - serialized (Dict[str, Any]): The serialized chain. - inputs (Union[Dict[str, Any], Any]): The inputs to the chain. - run_id (UUID, optional): The ID of the run. Defaults to None. - - Returns: - AsyncCallbackManagerForChainRun: The async callback manager - for the chain run. - """ - if run_id is None: - run_id = uuid.uuid4() - - await ahandle_event( - self.handlers, - "on_chain_start", - "ignore_chain", - serialized, - inputs, - run_id=run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - metadata=self.metadata, - **kwargs, - ) - - return AsyncCallbackManagerForChainRun( - run_id=run_id, - handlers=self.handlers, - inheritable_handlers=self.inheritable_handlers, - parent_run_id=self.parent_run_id, - tags=self.tags, - inheritable_tags=self.inheritable_tags, - metadata=self.metadata, - inheritable_metadata=self.inheritable_metadata, - ) - - async def on_tool_start( - self, - serialized: Dict[str, Any], - input_str: str, - run_id: Optional[UUID] = None, - parent_run_id: Optional[UUID] = None, - **kwargs: Any, - ) -> AsyncCallbackManagerForToolRun: - """Run when tool starts running. - - Args: - serialized (Dict[str, Any]): The serialized tool. - input_str (str): The input to the tool. - run_id (UUID, optional): The ID of the run. Defaults to None. - parent_run_id (UUID, optional): The ID of the parent run. - Defaults to None. - - Returns: - AsyncCallbackManagerForToolRun: The async callback manager - for the tool run. - """ - if run_id is None: - run_id = uuid.uuid4() - - await ahandle_event( - self.handlers, - "on_tool_start", - "ignore_agent", - serialized, - input_str, - run_id=run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - metadata=self.metadata, - **kwargs, - ) - - return AsyncCallbackManagerForToolRun( - run_id=run_id, - handlers=self.handlers, - inheritable_handlers=self.inheritable_handlers, - parent_run_id=self.parent_run_id, - tags=self.tags, - inheritable_tags=self.inheritable_tags, - metadata=self.metadata, - inheritable_metadata=self.inheritable_metadata, - ) - - async def on_retriever_start( - self, - serialized: Dict[str, Any], - query: str, - run_id: Optional[UUID] = None, - parent_run_id: Optional[UUID] = None, - **kwargs: Any, - ) -> AsyncCallbackManagerForRetrieverRun: - """Run when retriever starts running.""" - if run_id is None: - run_id = uuid.uuid4() - - await ahandle_event( - self.handlers, - "on_retriever_start", - "ignore_retriever", - serialized, - query, - run_id=run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - metadata=self.metadata, - **kwargs, - ) - - return AsyncCallbackManagerForRetrieverRun( - run_id=run_id, - handlers=self.handlers, - inheritable_handlers=self.inheritable_handlers, - parent_run_id=self.parent_run_id, - tags=self.tags, - inheritable_tags=self.inheritable_tags, - metadata=self.metadata, - inheritable_metadata=self.inheritable_metadata, - ) - - @classmethod - def configure( - cls, - inheritable_callbacks: Callbacks = None, - local_callbacks: Callbacks = None, - verbose: bool = False, - inheritable_tags: Optional[List[str]] = None, - local_tags: Optional[List[str]] = None, - inheritable_metadata: Optional[Dict[str, Any]] = None, - local_metadata: Optional[Dict[str, Any]] = None, - ) -> AsyncCallbackManager: - """Configure the async callback manager. - - Args: - inheritable_callbacks (Optional[Callbacks], optional): The inheritable - callbacks. Defaults to None. - local_callbacks (Optional[Callbacks], optional): The local callbacks. - Defaults to None. - verbose (bool, optional): Whether to enable verbose mode. Defaults to False. - inheritable_tags (Optional[List[str]], optional): The inheritable tags. - Defaults to None. - local_tags (Optional[List[str]], optional): The local tags. - Defaults to None. - inheritable_metadata (Optional[Dict[str, Any]], optional): The inheritable - metadata. Defaults to None. - local_metadata (Optional[Dict[str, Any]], optional): The local metadata. - Defaults to None. - - Returns: - AsyncCallbackManager: The configured async callback manager. - """ - return _configure( - cls, - inheritable_callbacks, - local_callbacks, - verbose, - inheritable_tags, - local_tags, - inheritable_metadata, - local_metadata, - ) - - -class AsyncCallbackManagerForChainGroup(AsyncCallbackManager): - """Async callback manager for the chain group.""" - - def __init__( - self, - handlers: List[BaseCallbackHandler], - inheritable_handlers: Optional[List[BaseCallbackHandler]] = None, - parent_run_id: Optional[UUID] = None, - *, - parent_run_manager: AsyncCallbackManagerForChainRun, - **kwargs: Any, - ) -> None: - super().__init__( - handlers, - inheritable_handlers, - parent_run_id, - **kwargs, - ) - self.parent_run_manager = parent_run_manager - self.ended = False - - def copy(self) -> AsyncCallbackManagerForChainGroup: - return self.__class__( - handlers=self.handlers, - inheritable_handlers=self.inheritable_handlers, - parent_run_id=self.parent_run_id, - tags=self.tags, - inheritable_tags=self.inheritable_tags, - metadata=self.metadata, - inheritable_metadata=self.inheritable_metadata, - parent_run_manager=self.parent_run_manager, - ) - - async def on_chain_end( - self, outputs: Union[Dict[str, Any], Any], **kwargs: Any - ) -> None: - """Run when traced chain group ends. - - Args: - outputs (Union[Dict[str, Any], Any]): The outputs of the chain. - """ - self.ended = True - await self.parent_run_manager.on_chain_end(outputs, **kwargs) - - async def on_chain_error( - self, - error: BaseException, - **kwargs: Any, - ) -> None: - """Run when chain errors. - - Args: - error (Exception or KeyboardInterrupt): The error. - """ - self.ended = True - await self.parent_run_manager.on_chain_error(error, **kwargs) - - -T = TypeVar("T", CallbackManager, AsyncCallbackManager) - - -def env_var_is_set(env_var: str) -> bool: - """Check if an environment variable is set. - - Args: - env_var (str): The name of the environment variable. - - Returns: - bool: True if the environment variable is set, False otherwise. - """ - return env_var in os.environ and os.environ[env_var] not in ( - "", - "0", - "false", - "False", - ) - - -def _tracing_v2_is_enabled() -> bool: - return ( - env_var_is_set("LANGCHAIN_TRACING_V2") - or tracing_v2_callback_var.get() is not None - or get_run_tree_context() is not None - ) - - -def _get_tracer_project() -> str: - run_tree = get_run_tree_context() - return getattr( - run_tree, - "session_name", - getattr( - # Note, if people are trying to nest @traceable functions and the - # tracing_v2_enabled context manager, this will likely mess up the - # tree structure. - tracing_v2_callback_var.get(), - "project", - # Have to set this to a string even though it always will return - # a string because `get_tracer_project` technically can return - # None, but only when a specific argument is supplied. - # Therefore, this just tricks the mypy type checker - str(ls_utils.get_tracer_project()), - ), - ) - - -def _configure( - callback_manager_cls: Type[T], - inheritable_callbacks: Callbacks = None, - local_callbacks: Callbacks = None, - verbose: bool = False, - inheritable_tags: Optional[List[str]] = None, - local_tags: Optional[List[str]] = None, - inheritable_metadata: Optional[Dict[str, Any]] = None, - local_metadata: Optional[Dict[str, Any]] = None, -) -> T: - """Configure the callback manager. - - Args: - callback_manager_cls (Type[T]): The callback manager class. - inheritable_callbacks (Optional[Callbacks], optional): The inheritable - callbacks. Defaults to None. - local_callbacks (Optional[Callbacks], optional): The local callbacks. - Defaults to None. - verbose (bool, optional): Whether to enable verbose mode. Defaults to False. - inheritable_tags (Optional[List[str]], optional): The inheritable tags. - Defaults to None. - local_tags (Optional[List[str]], optional): The local tags. Defaults to None. - inheritable_metadata (Optional[Dict[str, Any]], optional): The inheritable - metadata. Defaults to None. - local_metadata (Optional[Dict[str, Any]], optional): The local metadata. - Defaults to None. - - Returns: - T: The configured callback manager. - """ - run_tree = get_run_tree_context() - parent_run_id = None if run_tree is None else getattr(run_tree, "id") - callback_manager = callback_manager_cls(handlers=[], parent_run_id=parent_run_id) - if inheritable_callbacks or local_callbacks: - if isinstance(inheritable_callbacks, list) or inheritable_callbacks is None: - inheritable_callbacks_ = inheritable_callbacks or [] - callback_manager = callback_manager_cls( - handlers=inheritable_callbacks_.copy(), - inheritable_handlers=inheritable_callbacks_.copy(), - parent_run_id=parent_run_id, - ) - else: - callback_manager = callback_manager_cls( - handlers=inheritable_callbacks.handlers.copy(), - inheritable_handlers=inheritable_callbacks.inheritable_handlers.copy(), - parent_run_id=inheritable_callbacks.parent_run_id, - tags=inheritable_callbacks.tags.copy(), - inheritable_tags=inheritable_callbacks.inheritable_tags.copy(), - metadata=inheritable_callbacks.metadata.copy(), - inheritable_metadata=inheritable_callbacks.inheritable_metadata.copy(), - ) - local_handlers_ = ( - local_callbacks - if isinstance(local_callbacks, list) - else (local_callbacks.handlers if local_callbacks else []) - ) - for handler in local_handlers_: - callback_manager.add_handler(handler, False) - if inheritable_tags or local_tags: - callback_manager.add_tags(inheritable_tags or []) - callback_manager.add_tags(local_tags or [], False) - if inheritable_metadata or local_metadata: - callback_manager.add_metadata(inheritable_metadata or {}) - callback_manager.add_metadata(local_metadata or {}, False) - - tracer = tracing_callback_var.get() - wandb_tracer = wandb_tracing_callback_var.get() - open_ai = openai_callback_var.get() - tracing_enabled_ = ( - env_var_is_set("LANGCHAIN_TRACING") - or tracer is not None - or env_var_is_set("LANGCHAIN_HANDLER") - ) - wandb_tracing_enabled_ = ( - env_var_is_set("LANGCHAIN_WANDB_TRACING") or wandb_tracer is not None - ) - - tracer_v2 = tracing_v2_callback_var.get() - tracing_v2_enabled_ = _tracing_v2_is_enabled() - tracer_project = _get_tracer_project() - run_collector_ = run_collector_var.get() - debug = _get_debug() - if ( - verbose - or debug - or tracing_enabled_ - or tracing_v2_enabled_ - or wandb_tracing_enabled_ - or open_ai is not None - ): - if verbose and not any( - isinstance(handler, StdOutCallbackHandler) - for handler in callback_manager.handlers - ): - if debug: - pass - else: - callback_manager.add_handler(StdOutCallbackHandler(), False) - if debug and not any( - isinstance(handler, ConsoleCallbackHandler) - for handler in callback_manager.handlers - ): - callback_manager.add_handler(ConsoleCallbackHandler(), True) - if tracing_enabled_ and not any( - isinstance(handler, LangChainTracerV1) - for handler in callback_manager.handlers - ): - if tracer: - callback_manager.add_handler(tracer, True) - else: - handler = LangChainTracerV1() - handler.load_session(tracer_project) - callback_manager.add_handler(handler, True) - if wandb_tracing_enabled_ and not any( - isinstance(handler, WandbTracer) for handler in callback_manager.handlers - ): - if wandb_tracer: - callback_manager.add_handler(wandb_tracer, True) - else: - handler = WandbTracer() - callback_manager.add_handler(handler, True) - if tracing_v2_enabled_ and not any( - isinstance(handler, LangChainTracer) - for handler in callback_manager.handlers - ): - if tracer_v2: - callback_manager.add_handler(tracer_v2, True) - else: - try: - handler = LangChainTracer(project_name=tracer_project) - callback_manager.add_handler(handler, True) - except Exception as e: - logger.warning( - "Unable to load requested LangChainTracer." - " To disable this warning," - " unset the LANGCHAIN_TRACING_V2 environment variables.", - e, - ) - if open_ai is not None and not any( - handler is open_ai # direct pointer comparison - for handler in callback_manager.handlers - ): - callback_manager.add_handler(open_ai, True) - if run_collector_ is not None and not any( - handler is run_collector_ # direct pointer comparison - for handler in callback_manager.handlers - ): - callback_manager.add_handler(run_collector_, False) - return callback_manager +__all__ = [ + "BaseRunManager", + "RunManager", + "ParentRunManager", + "AsyncRunManager", + "AsyncParentRunManager", + "CallbackManagerForLLMRun", + "AsyncCallbackManagerForLLMRun", + "CallbackManagerForChainRun", + "AsyncCallbackManagerForChainRun", + "CallbackManagerForToolRun", + "AsyncCallbackManagerForToolRun", + "CallbackManagerForRetrieverRun", + "AsyncCallbackManagerForRetrieverRun", + "CallbackManager", + "CallbackManagerForChainGroup", + "AsyncCallbackManager", + "AsyncCallbackManagerForChainGroup", + "tracing_enabled", + "tracing_v2_enabled", + "collect_runs", + "atrace_as_chain_group", + "trace_as_chain_group", + "handle_event", + "env_var_is_set", + "Callbacks", +] diff --git a/libs/langchain/langchain/callbacks/stdout.py b/libs/langchain/langchain/callbacks/stdout.py index a9738c9bf94..ef3bab0618e 100644 --- a/libs/langchain/langchain/callbacks/stdout.py +++ b/libs/langchain/langchain/callbacks/stdout.py @@ -1,97 +1,3 @@ -"""Callback Handler that prints to std out.""" -from typing import Any, Dict, List, Optional +from langchain.schema.callbacks.stdout import StdOutCallbackHandler -from langchain.callbacks.base import BaseCallbackHandler -from langchain.schema import AgentAction, AgentFinish, LLMResult -from langchain.utils.input import print_text - - -class StdOutCallbackHandler(BaseCallbackHandler): - """Callback Handler that prints to std out.""" - - def __init__(self, color: Optional[str] = None) -> None: - """Initialize callback handler.""" - self.color = color - - def on_llm_start( - self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any - ) -> None: - """Print out the prompts.""" - pass - - def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: - """Do nothing.""" - pass - - def on_llm_new_token(self, token: str, **kwargs: Any) -> None: - """Do nothing.""" - pass - - def on_llm_error(self, error: BaseException, **kwargs: Any) -> None: - """Do nothing.""" - pass - - def on_chain_start( - self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any - ) -> None: - """Print out that we are entering a chain.""" - class_name = serialized.get("name", serialized.get("id", [""])[-1]) - print(f"\n\n\033[1m> Entering new {class_name} chain...\033[0m") - - def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: - """Print out that we finished a chain.""" - print("\n\033[1m> Finished chain.\033[0m") - - def on_chain_error(self, error: BaseException, **kwargs: Any) -> None: - """Do nothing.""" - pass - - def on_tool_start( - self, - serialized: Dict[str, Any], - input_str: str, - **kwargs: Any, - ) -> None: - """Do nothing.""" - pass - - def on_agent_action( - self, action: AgentAction, color: Optional[str] = None, **kwargs: Any - ) -> Any: - """Run on agent action.""" - print_text(action.log, color=color or self.color) - - def on_tool_end( - self, - output: str, - color: Optional[str] = None, - observation_prefix: Optional[str] = None, - llm_prefix: Optional[str] = None, - **kwargs: Any, - ) -> None: - """If not the final action, print out observation.""" - if observation_prefix is not None: - print_text(f"\n{observation_prefix}") - print_text(output, color=color or self.color) - if llm_prefix is not None: - print_text(f"\n{llm_prefix}") - - def on_tool_error(self, error: BaseException, **kwargs: Any) -> None: - """Do nothing.""" - pass - - def on_text( - self, - text: str, - color: Optional[str] = None, - end: str = "", - **kwargs: Any, - ) -> None: - """Run when agent ends.""" - print_text(text, color=color or self.color, end=end) - - def on_agent_finish( - self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any - ) -> None: - """Run on agent end.""" - print_text(finish.log, color=color or self.color, end="\n") +__all__ = ["StdOutCallbackHandler"] diff --git a/libs/langchain/langchain/callbacks/tracers/__init__.py b/libs/langchain/langchain/callbacks/tracers/__init__.py index 65a5a846ca2..e33002bb306 100644 --- a/libs/langchain/langchain/callbacks/tracers/__init__.py +++ b/libs/langchain/langchain/callbacks/tracers/__init__.py @@ -1,12 +1,12 @@ """Tracers that record execution of LangChain runs.""" -from langchain.callbacks.tracers.langchain import LangChainTracer -from langchain.callbacks.tracers.langchain_v1 import LangChainTracerV1 -from langchain.callbacks.tracers.stdout import ( +from langchain.callbacks.tracers.wandb import WandbTracer +from langchain.schema.callbacks.tracers.langchain import LangChainTracer +from langchain.schema.callbacks.tracers.langchain_v1 import LangChainTracerV1 +from langchain.schema.callbacks.tracers.stdout import ( ConsoleCallbackHandler, FunctionCallbackHandler, ) -from langchain.callbacks.tracers.wandb import WandbTracer __all__ = [ "LangChainTracer", diff --git a/libs/langchain/langchain/callbacks/tracers/base.py b/libs/langchain/langchain/callbacks/tracers/base.py index 10fd1d701fa..34946ff374b 100644 --- a/libs/langchain/langchain/callbacks/tracers/base.py +++ b/libs/langchain/langchain/callbacks/tracers/base.py @@ -1,537 +1,5 @@ """Base interfaces for tracing runs.""" -from __future__ import annotations -import logging -from abc import ABC, abstractmethod -from datetime import datetime -from typing import Any, Dict, List, Optional, Sequence, Union, cast -from uuid import UUID +from langchain.schema.callbacks.tracers.base import BaseTracer, TracerException -from tenacity import RetryCallState - -from langchain.callbacks.base import BaseCallbackHandler -from langchain.callbacks.tracers.schemas import Run -from langchain.load.dump import dumpd -from langchain.schema.document import Document -from langchain.schema.output import ( - ChatGeneration, - ChatGenerationChunk, - GenerationChunk, - LLMResult, -) - -logger = logging.getLogger(__name__) - - -class TracerException(Exception): - """Base class for exceptions in tracers module.""" - - -class BaseTracer(BaseCallbackHandler, ABC): - """Base interface for tracers.""" - - def __init__(self, **kwargs: Any) -> None: - super().__init__(**kwargs) - self.run_map: Dict[str, Run] = {} - - @staticmethod - def _add_child_run( - parent_run: Run, - child_run: Run, - ) -> None: - """Add child run to a chain run or tool run.""" - parent_run.child_runs.append(child_run) - - @abstractmethod - def _persist_run(self, run: Run) -> None: - """Persist a run.""" - - def _start_trace(self, run: Run) -> None: - """Start a trace for a run.""" - if run.parent_run_id: - parent_run = self.run_map.get(str(run.parent_run_id)) - if parent_run: - self._add_child_run(parent_run, run) - parent_run.child_execution_order = max( - parent_run.child_execution_order, run.child_execution_order - ) - else: - logger.debug(f"Parent run with UUID {run.parent_run_id} not found.") - self.run_map[str(run.id)] = run - self._on_run_create(run) - - def _end_trace(self, run: Run) -> None: - """End a trace for a run.""" - if not run.parent_run_id: - self._persist_run(run) - else: - parent_run = self.run_map.get(str(run.parent_run_id)) - if parent_run is None: - logger.debug(f"Parent run with UUID {run.parent_run_id} not found.") - elif ( - run.child_execution_order is not None - and parent_run.child_execution_order is not None - and run.child_execution_order > parent_run.child_execution_order - ): - parent_run.child_execution_order = run.child_execution_order - self.run_map.pop(str(run.id)) - self._on_run_update(run) - - def _get_execution_order(self, parent_run_id: Optional[str] = None) -> int: - """Get the execution order for a run.""" - if parent_run_id is None: - return 1 - - parent_run = self.run_map.get(parent_run_id) - if parent_run is None: - logger.debug(f"Parent run with UUID {parent_run_id} not found.") - return 1 - if parent_run.child_execution_order is None: - raise TracerException( - f"Parent run with UUID {parent_run_id} has no child execution order." - ) - - return parent_run.child_execution_order + 1 - - def on_llm_start( - self, - serialized: Dict[str, Any], - prompts: List[str], - *, - run_id: UUID, - tags: Optional[List[str]] = None, - parent_run_id: Optional[UUID] = None, - metadata: Optional[Dict[str, Any]] = None, - name: Optional[str] = None, - **kwargs: Any, - ) -> Run: - """Start a trace for an LLM run.""" - parent_run_id_ = str(parent_run_id) if parent_run_id else None - execution_order = self._get_execution_order(parent_run_id_) - start_time = datetime.utcnow() - if metadata: - kwargs.update({"metadata": metadata}) - llm_run = Run( - id=run_id, - parent_run_id=parent_run_id, - serialized=serialized, - inputs={"prompts": prompts}, - extra=kwargs, - events=[{"name": "start", "time": start_time}], - start_time=start_time, - execution_order=execution_order, - child_execution_order=execution_order, - run_type="llm", - tags=tags or [], - name=name, - ) - self._start_trace(llm_run) - self._on_llm_start(llm_run) - return llm_run - - def on_llm_new_token( - self, - token: str, - *, - chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - **kwargs: Any, - ) -> Run: - """Run on new LLM token. Only available when streaming is enabled.""" - if not run_id: - raise TracerException("No run_id provided for on_llm_new_token callback.") - - run_id_ = str(run_id) - llm_run = self.run_map.get(run_id_) - if llm_run is None or llm_run.run_type != "llm": - raise TracerException(f"No LLM Run found to be traced for {run_id}") - event_kwargs: Dict[str, Any] = {"token": token} - if chunk: - event_kwargs["chunk"] = chunk - llm_run.events.append( - { - "name": "new_token", - "time": datetime.utcnow(), - "kwargs": event_kwargs, - }, - ) - self._on_llm_new_token(llm_run, token, chunk) - return llm_run - - def on_retry( - self, - retry_state: RetryCallState, - *, - run_id: UUID, - **kwargs: Any, - ) -> Run: - if not run_id: - raise TracerException("No run_id provided for on_retry callback.") - run_id_ = str(run_id) - llm_run = self.run_map.get(run_id_) - if llm_run is None: - raise TracerException("No Run found to be traced for on_retry") - retry_d: Dict[str, Any] = { - "slept": retry_state.idle_for, - "attempt": retry_state.attempt_number, - } - if retry_state.outcome is None: - retry_d["outcome"] = "N/A" - elif retry_state.outcome.failed: - retry_d["outcome"] = "failed" - exception = retry_state.outcome.exception() - retry_d["exception"] = str(exception) - retry_d["exception_type"] = exception.__class__.__name__ - else: - retry_d["outcome"] = "success" - retry_d["result"] = str(retry_state.outcome.result()) - llm_run.events.append( - { - "name": "retry", - "time": datetime.utcnow(), - "kwargs": retry_d, - }, - ) - return llm_run - - def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> Run: - """End a trace for an LLM run.""" - if not run_id: - raise TracerException("No run_id provided for on_llm_end callback.") - - run_id_ = str(run_id) - llm_run = self.run_map.get(run_id_) - if llm_run is None or llm_run.run_type != "llm": - raise TracerException(f"No LLM Run found to be traced for {run_id}") - llm_run.outputs = response.dict() - for i, generations in enumerate(response.generations): - for j, generation in enumerate(generations): - output_generation = llm_run.outputs["generations"][i][j] - if "message" in output_generation: - output_generation["message"] = dumpd( - cast(ChatGeneration, generation).message - ) - llm_run.end_time = datetime.utcnow() - llm_run.events.append({"name": "end", "time": llm_run.end_time}) - self._end_trace(llm_run) - self._on_llm_end(llm_run) - return llm_run - - def on_llm_error( - self, - error: BaseException, - *, - run_id: UUID, - **kwargs: Any, - ) -> Run: - """Handle an error for an LLM run.""" - if not run_id: - raise TracerException("No run_id provided for on_llm_error callback.") - - run_id_ = str(run_id) - llm_run = self.run_map.get(run_id_) - if llm_run is None or llm_run.run_type != "llm": - raise TracerException(f"No LLM Run found to be traced for {run_id}") - llm_run.error = repr(error) - llm_run.end_time = datetime.utcnow() - llm_run.events.append({"name": "error", "time": llm_run.end_time}) - self._end_trace(llm_run) - self._on_chain_error(llm_run) - return llm_run - - def on_chain_start( - self, - serialized: Dict[str, Any], - inputs: Dict[str, Any], - *, - run_id: UUID, - tags: Optional[List[str]] = None, - parent_run_id: Optional[UUID] = None, - metadata: Optional[Dict[str, Any]] = None, - run_type: Optional[str] = None, - name: Optional[str] = None, - **kwargs: Any, - ) -> Run: - """Start a trace for a chain run.""" - parent_run_id_ = str(parent_run_id) if parent_run_id else None - execution_order = self._get_execution_order(parent_run_id_) - start_time = datetime.utcnow() - if metadata: - kwargs.update({"metadata": metadata}) - chain_run = Run( - id=run_id, - parent_run_id=parent_run_id, - serialized=serialized, - inputs=inputs if isinstance(inputs, dict) else {"input": inputs}, - extra=kwargs, - events=[{"name": "start", "time": start_time}], - start_time=start_time, - execution_order=execution_order, - child_execution_order=execution_order, - child_runs=[], - run_type=run_type or "chain", - name=name, - tags=tags or [], - ) - self._start_trace(chain_run) - self._on_chain_start(chain_run) - return chain_run - - def on_chain_end( - self, - outputs: Dict[str, Any], - *, - run_id: UUID, - inputs: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> Run: - """End a trace for a chain run.""" - if not run_id: - raise TracerException("No run_id provided for on_chain_end callback.") - chain_run = self.run_map.get(str(run_id)) - if chain_run is None: - raise TracerException(f"No chain Run found to be traced for {run_id}") - - chain_run.outputs = ( - outputs if isinstance(outputs, dict) else {"output": outputs} - ) - chain_run.end_time = datetime.utcnow() - chain_run.events.append({"name": "end", "time": chain_run.end_time}) - if inputs is not None: - chain_run.inputs = inputs if isinstance(inputs, dict) else {"input": inputs} - self._end_trace(chain_run) - self._on_chain_end(chain_run) - return chain_run - - def on_chain_error( - self, - error: BaseException, - *, - inputs: Optional[Dict[str, Any]] = None, - run_id: UUID, - **kwargs: Any, - ) -> Run: - """Handle an error for a chain run.""" - if not run_id: - raise TracerException("No run_id provided for on_chain_error callback.") - chain_run = self.run_map.get(str(run_id)) - if chain_run is None: - raise TracerException(f"No chain Run found to be traced for {run_id}") - - chain_run.error = repr(error) - chain_run.end_time = datetime.utcnow() - chain_run.events.append({"name": "error", "time": chain_run.end_time}) - if inputs is not None: - chain_run.inputs = inputs if isinstance(inputs, dict) else {"input": inputs} - self._end_trace(chain_run) - self._on_chain_error(chain_run) - return chain_run - - def on_tool_start( - self, - serialized: Dict[str, Any], - input_str: str, - *, - run_id: UUID, - tags: Optional[List[str]] = None, - parent_run_id: Optional[UUID] = None, - metadata: Optional[Dict[str, Any]] = None, - name: Optional[str] = None, - **kwargs: Any, - ) -> Run: - """Start a trace for a tool run.""" - parent_run_id_ = str(parent_run_id) if parent_run_id else None - execution_order = self._get_execution_order(parent_run_id_) - start_time = datetime.utcnow() - if metadata: - kwargs.update({"metadata": metadata}) - tool_run = Run( - id=run_id, - parent_run_id=parent_run_id, - serialized=serialized, - inputs={"input": input_str}, - extra=kwargs, - events=[{"name": "start", "time": start_time}], - start_time=start_time, - execution_order=execution_order, - child_execution_order=execution_order, - child_runs=[], - run_type="tool", - tags=tags or [], - name=name, - ) - self._start_trace(tool_run) - self._on_tool_start(tool_run) - return tool_run - - def on_tool_end(self, output: str, *, run_id: UUID, **kwargs: Any) -> Run: - """End a trace for a tool run.""" - if not run_id: - raise TracerException("No run_id provided for on_tool_end callback.") - tool_run = self.run_map.get(str(run_id)) - if tool_run is None or tool_run.run_type != "tool": - raise TracerException(f"No tool Run found to be traced for {run_id}") - - tool_run.outputs = {"output": output} - tool_run.end_time = datetime.utcnow() - tool_run.events.append({"name": "end", "time": tool_run.end_time}) - self._end_trace(tool_run) - self._on_tool_end(tool_run) - return tool_run - - def on_tool_error( - self, - error: BaseException, - *, - run_id: UUID, - **kwargs: Any, - ) -> Run: - """Handle an error for a tool run.""" - if not run_id: - raise TracerException("No run_id provided for on_tool_error callback.") - tool_run = self.run_map.get(str(run_id)) - if tool_run is None or tool_run.run_type != "tool": - raise TracerException(f"No tool Run found to be traced for {run_id}") - - tool_run.error = repr(error) - tool_run.end_time = datetime.utcnow() - tool_run.events.append({"name": "error", "time": tool_run.end_time}) - self._end_trace(tool_run) - self._on_tool_error(tool_run) - return tool_run - - def on_retriever_start( - self, - serialized: Dict[str, Any], - query: str, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, - name: Optional[str] = None, - **kwargs: Any, - ) -> Run: - """Run when Retriever starts running.""" - parent_run_id_ = str(parent_run_id) if parent_run_id else None - execution_order = self._get_execution_order(parent_run_id_) - start_time = datetime.utcnow() - if metadata: - kwargs.update({"metadata": metadata}) - retrieval_run = Run( - id=run_id, - name=name or "Retriever", - parent_run_id=parent_run_id, - serialized=serialized, - inputs={"query": query}, - extra=kwargs, - events=[{"name": "start", "time": start_time}], - start_time=start_time, - execution_order=execution_order, - child_execution_order=execution_order, - tags=tags, - child_runs=[], - run_type="retriever", - ) - self._start_trace(retrieval_run) - self._on_retriever_start(retrieval_run) - return retrieval_run - - def on_retriever_error( - self, - error: BaseException, - *, - run_id: UUID, - **kwargs: Any, - ) -> Run: - """Run when Retriever errors.""" - if not run_id: - raise TracerException("No run_id provided for on_retriever_error callback.") - retrieval_run = self.run_map.get(str(run_id)) - if retrieval_run is None or retrieval_run.run_type != "retriever": - raise TracerException(f"No retriever Run found to be traced for {run_id}") - - retrieval_run.error = repr(error) - retrieval_run.end_time = datetime.utcnow() - retrieval_run.events.append({"name": "error", "time": retrieval_run.end_time}) - self._end_trace(retrieval_run) - self._on_retriever_error(retrieval_run) - return retrieval_run - - def on_retriever_end( - self, documents: Sequence[Document], *, run_id: UUID, **kwargs: Any - ) -> Run: - """Run when Retriever ends running.""" - if not run_id: - raise TracerException("No run_id provided for on_retriever_end callback.") - retrieval_run = self.run_map.get(str(run_id)) - if retrieval_run is None or retrieval_run.run_type != "retriever": - raise TracerException(f"No retriever Run found to be traced for {run_id}") - retrieval_run.outputs = {"documents": documents} - retrieval_run.end_time = datetime.utcnow() - retrieval_run.events.append({"name": "end", "time": retrieval_run.end_time}) - self._end_trace(retrieval_run) - self._on_retriever_end(retrieval_run) - return retrieval_run - - def __deepcopy__(self, memo: dict) -> BaseTracer: - """Deepcopy the tracer.""" - return self - - def __copy__(self) -> BaseTracer: - """Copy the tracer.""" - return self - - def _on_run_create(self, run: Run) -> None: - """Process a run upon creation.""" - - def _on_run_update(self, run: Run) -> None: - """Process a run upon update.""" - - def _on_llm_start(self, run: Run) -> None: - """Process the LLM Run upon start.""" - - def _on_llm_new_token( - self, - run: Run, - token: str, - chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]], - ) -> None: - """Process new LLM token.""" - - def _on_llm_end(self, run: Run) -> None: - """Process the LLM Run.""" - - def _on_llm_error(self, run: Run) -> None: - """Process the LLM Run upon error.""" - - def _on_chain_start(self, run: Run) -> None: - """Process the Chain Run upon start.""" - - def _on_chain_end(self, run: Run) -> None: - """Process the Chain Run.""" - - def _on_chain_error(self, run: Run) -> None: - """Process the Chain Run upon error.""" - - def _on_tool_start(self, run: Run) -> None: - """Process the Tool Run upon start.""" - - def _on_tool_end(self, run: Run) -> None: - """Process the Tool Run.""" - - def _on_tool_error(self, run: Run) -> None: - """Process the Tool Run upon error.""" - - def _on_chat_model_start(self, run: Run) -> None: - """Process the Chat Model Run upon start.""" - - def _on_retriever_start(self, run: Run) -> None: - """Process the Retriever Run upon start.""" - - def _on_retriever_end(self, run: Run) -> None: - """Process the Retriever Run.""" - - def _on_retriever_error(self, run: Run) -> None: - """Process the Retriever Run upon error.""" +__all__ = ["BaseTracer", "TracerException"] diff --git a/libs/langchain/langchain/callbacks/tracers/evaluation.py b/libs/langchain/langchain/callbacks/tracers/evaluation.py index b78b322ddda..8384ea6557d 100644 --- a/libs/langchain/langchain/callbacks/tracers/evaluation.py +++ b/libs/langchain/langchain/callbacks/tracers/evaluation.py @@ -1,222 +1,7 @@ """A tracer that runs evaluators over completed runs.""" -from __future__ import annotations +from langchain.schema.callbacks.tracers.evaluation import ( + EvaluatorCallbackHandler, + wait_for_all_evaluators, +) -import logging -import threading -import weakref -from concurrent.futures import Future, ThreadPoolExecutor, wait -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast -from uuid import UUID - -import langsmith -from langsmith.evaluation.evaluator import EvaluationResult, EvaluationResults - -from langchain.callbacks import manager -from langchain.callbacks.tracers import langchain as langchain_tracer -from langchain.callbacks.tracers.base import BaseTracer -from langchain.callbacks.tracers.langchain import _get_executor -from langchain.callbacks.tracers.schemas import Run - -logger = logging.getLogger(__name__) - -_TRACERS: weakref.WeakSet[EvaluatorCallbackHandler] = weakref.WeakSet() - - -def wait_for_all_evaluators() -> None: - """Wait for all tracers to finish.""" - global _TRACERS - for tracer in list(_TRACERS): - if tracer is not None: - tracer.wait_for_futures() - - -class EvaluatorCallbackHandler(BaseTracer): - """A tracer that runs a run evaluator whenever a run is persisted. - - Parameters - ---------- - evaluators : Sequence[RunEvaluator] - The run evaluators to apply to all top level runs. - client : LangSmith Client, optional - The LangSmith client instance to use for evaluating the runs. - If not specified, a new instance will be created. - example_id : Union[UUID, str], optional - The example ID to be associated with the runs. - project_name : str, optional - The LangSmith project name to be organize eval chain runs under. - - Attributes - ---------- - example_id : Union[UUID, None] - The example ID associated with the runs. - client : Client - The LangSmith client instance used for evaluating the runs. - evaluators : Sequence[RunEvaluator] - The sequence of run evaluators to be executed. - executor : ThreadPoolExecutor - The thread pool executor used for running the evaluators. - futures : Set[Future] - The set of futures representing the running evaluators. - skip_unfinished : bool - Whether to skip runs that are not finished or raised - an error. - project_name : Optional[str] - The LangSmith project name to be organize eval chain runs under. - """ - - name = "evaluator_callback_handler" - - def __init__( - self, - evaluators: Sequence[langsmith.RunEvaluator], - client: Optional[langsmith.Client] = None, - example_id: Optional[Union[UUID, str]] = None, - skip_unfinished: bool = True, - project_name: Optional[str] = "evaluators", - max_concurrency: Optional[int] = None, - **kwargs: Any, - ) -> None: - super().__init__(**kwargs) - self.example_id = ( - UUID(example_id) if isinstance(example_id, str) else example_id - ) - self.client = client or langchain_tracer.get_client() - self.evaluators = evaluators - if max_concurrency is None: - self.executor: Optional[ThreadPoolExecutor] = _get_executor() - elif max_concurrency > 0: - self.executor = ThreadPoolExecutor(max_workers=max_concurrency) - weakref.finalize( - self, - lambda: cast(ThreadPoolExecutor, self.executor).shutdown(wait=True), - ) - else: - self.executor = None - self.futures: weakref.WeakSet[Future] = weakref.WeakSet() - self.skip_unfinished = skip_unfinished - self.project_name = project_name - self.logged_eval_results: Dict[Tuple[str, str], List[EvaluationResult]] = {} - self.lock = threading.Lock() - global _TRACERS - _TRACERS.add(self) - - def _evaluate_in_project(self, run: Run, evaluator: langsmith.RunEvaluator) -> None: - """Evaluate the run in the project. - - Parameters - ---------- - run : Run - The run to be evaluated. - evaluator : RunEvaluator - The evaluator to use for evaluating the run. - - """ - try: - if self.project_name is None: - eval_result = self.client.evaluate_run(run, evaluator) - eval_results = [eval_result] - with manager.tracing_v2_enabled( - project_name=self.project_name, tags=["eval"], client=self.client - ) as cb: - reference_example = ( - self.client.read_example(run.reference_example_id) - if run.reference_example_id - else None - ) - evaluation_result = evaluator.evaluate_run( - run, - example=reference_example, - ) - eval_results = self._log_evaluation_feedback( - evaluation_result, - run, - source_run_id=cb.latest_run.id if cb.latest_run else None, - ) - except Exception as e: - logger.error( - f"Error evaluating run {run.id} with " - f"{evaluator.__class__.__name__}: {repr(e)}", - exc_info=True, - ) - raise e - example_id = str(run.reference_example_id) - with self.lock: - for res in eval_results: - run_id = ( - str(getattr(res, "target_run_id")) - if hasattr(res, "target_run_id") - else str(run.id) - ) - self.logged_eval_results.setdefault((run_id, example_id), []).append( - res - ) - - def _select_eval_results( - self, - results: Union[EvaluationResult, EvaluationResults], - ) -> List[EvaluationResult]: - if isinstance(results, EvaluationResult): - results_ = [results] - elif isinstance(results, dict) and "results" in results: - results_ = cast(List[EvaluationResult], results["results"]) - else: - raise TypeError( - f"Invalid evaluation result type {type(results)}." - " Expected EvaluationResult or EvaluationResults." - ) - return results_ - - def _log_evaluation_feedback( - self, - evaluator_response: Union[EvaluationResult, EvaluationResults], - run: Run, - source_run_id: Optional[UUID] = None, - ) -> List[EvaluationResult]: - results = self._select_eval_results(evaluator_response) - for res in results: - source_info_: Dict[str, Any] = {} - if res.evaluator_info: - source_info_ = {**res.evaluator_info, **source_info_} - run_id_ = ( - getattr(res, "target_run_id") - if hasattr(res, "target_run_id") and res.target_run_id is not None - else run.id - ) - self.client.create_feedback( - run_id_, - res.key, - score=res.score, - value=res.value, - comment=res.comment, - correction=res.correction, - source_info=source_info_, - source_run_id=res.source_run_id or source_run_id, - feedback_source_type=langsmith.schemas.FeedbackSourceType.MODEL, - ) - return results - - def _persist_run(self, run: Run) -> None: - """Run the evaluator on the run. - - Parameters - ---------- - run : Run - The run to be evaluated. - - """ - if self.skip_unfinished and not run.outputs: - logger.debug(f"Skipping unfinished run {run.id}") - return - run_ = run.copy() - run_.reference_example_id = self.example_id - for evaluator in self.evaluators: - if self.executor is None: - self._evaluate_in_project(run_, evaluator) - else: - self.futures.add( - self.executor.submit(self._evaluate_in_project, run_, evaluator) - ) - - def wait_for_futures(self) -> None: - """Wait for all futures to complete.""" - wait(self.futures) +__all__ = ["wait_for_all_evaluators", "EvaluatorCallbackHandler"] diff --git a/libs/langchain/langchain/callbacks/tracers/langchain.py b/libs/langchain/langchain/callbacks/tracers/langchain.py index 54619243cfe..1cfe3ffc0a0 100644 --- a/libs/langchain/langchain/callbacks/tracers/langchain.py +++ b/libs/langchain/langchain/callbacks/tracers/langchain.py @@ -1,262 +1,8 @@ """A Tracer implementation that records to LangChain endpoint.""" -from __future__ import annotations -import logging -import weakref -from concurrent.futures import Future, ThreadPoolExecutor, wait -from datetime import datetime -from typing import Any, Callable, Dict, List, Optional, Union -from uuid import UUID - -from langsmith import Client -from langsmith import utils as ls_utils -from tenacity import ( - Retrying, - retry_if_exception_type, - stop_after_attempt, - wait_exponential_jitter, +from langchain.schema.callbacks.tracers.langchain import ( + LangChainTracer, + wait_for_all_tracers, ) -from langchain.callbacks.tracers.base import BaseTracer -from langchain.callbacks.tracers.schemas import Run -from langchain.env import get_runtime_environment -from langchain.load.dump import dumpd -from langchain.schema.messages import BaseMessage - -logger = logging.getLogger(__name__) -_LOGGED = set() -_TRACERS: weakref.WeakSet[LangChainTracer] = weakref.WeakSet() -_CLIENT: Optional[Client] = None -_EXECUTOR: Optional[ThreadPoolExecutor] = None - - -def log_error_once(method: str, exception: Exception) -> None: - """Log an error once.""" - global _LOGGED - if (method, type(exception)) in _LOGGED: - return - _LOGGED.add((method, type(exception))) - logger.error(exception) - - -def wait_for_all_tracers() -> None: - """Wait for all tracers to finish.""" - global _TRACERS - for tracer in list(_TRACERS): - if tracer is not None: - tracer.wait_for_futures() - - -def get_client() -> Client: - """Get the client.""" - global _CLIENT - if _CLIENT is None: - _CLIENT = Client() - return _CLIENT - - -def _get_executor() -> ThreadPoolExecutor: - """Get the executor.""" - global _EXECUTOR - if _EXECUTOR is None: - _EXECUTOR = ThreadPoolExecutor() - return _EXECUTOR - - -def _copy(run: Run) -> Run: - """Copy a run.""" - try: - return run.copy(deep=True) - except TypeError: - # Fallback in case the object contains a lock or other - # non-pickleable object - return run.copy() - - -class LangChainTracer(BaseTracer): - """An implementation of the SharedTracer that POSTS to the langchain endpoint.""" - - def __init__( - self, - example_id: Optional[Union[UUID, str]] = None, - project_name: Optional[str] = None, - client: Optional[Client] = None, - tags: Optional[List[str]] = None, - use_threading: bool = True, - **kwargs: Any, - ) -> None: - """Initialize the LangChain tracer.""" - super().__init__(**kwargs) - self.example_id = ( - UUID(example_id) if isinstance(example_id, str) else example_id - ) - self.project_name = project_name or ls_utils.get_tracer_project() - self.client = client or get_client() - self._futures: weakref.WeakSet[Future] = weakref.WeakSet() - self.tags = tags or [] - self.executor = _get_executor() if use_threading else None - self.latest_run: Optional[Run] = None - global _TRACERS - _TRACERS.add(self) - - def on_chat_model_start( - self, - serialized: Dict[str, Any], - messages: List[List[BaseMessage]], - *, - run_id: UUID, - tags: Optional[List[str]] = None, - parent_run_id: Optional[UUID] = None, - metadata: Optional[Dict[str, Any]] = None, - name: Optional[str] = None, - **kwargs: Any, - ) -> None: - """Start a trace for an LLM run.""" - parent_run_id_ = str(parent_run_id) if parent_run_id else None - execution_order = self._get_execution_order(parent_run_id_) - start_time = datetime.utcnow() - if metadata: - kwargs.update({"metadata": metadata}) - chat_model_run = Run( - id=run_id, - parent_run_id=parent_run_id, - serialized=serialized, - inputs={"messages": [[dumpd(msg) for msg in batch] for batch in messages]}, - extra=kwargs, - events=[{"name": "start", "time": start_time}], - start_time=start_time, - execution_order=execution_order, - child_execution_order=execution_order, - run_type="llm", - tags=tags, - name=name, - ) - self._start_trace(chat_model_run) - self._on_chat_model_start(chat_model_run) - - def _persist_run(self, run: Run) -> None: - run_ = run.copy() - run_.reference_example_id = self.example_id - self.latest_run = run_ - - def get_run_url(self) -> str: - """Get the LangSmith root run URL""" - if not self.latest_run: - raise ValueError("No traced run found.") - # If this is the first run in a project, the project may not yet be created. - # This method is only really useful for debugging flows, so we will assume - # there is some tolerace for latency. - for attempt in Retrying( - stop=stop_after_attempt(5), - wait=wait_exponential_jitter(), - retry=retry_if_exception_type(ls_utils.LangSmithError), - ): - with attempt: - return self.client.get_run_url( - run=self.latest_run, project_name=self.project_name - ) - raise ValueError("Failed to get run URL.") - - def _get_tags(self, run: Run) -> List[str]: - """Get combined tags for a run.""" - tags = set(run.tags or []) - tags.update(self.tags or []) - return list(tags) - - def _persist_run_single(self, run: Run) -> None: - """Persist a run.""" - run_dict = run.dict(exclude={"child_runs"}) - run_dict["tags"] = self._get_tags(run) - extra = run_dict.get("extra", {}) - extra["runtime"] = get_runtime_environment() - run_dict["extra"] = extra - try: - self.client.create_run(**run_dict, project_name=self.project_name) - except Exception as e: - # Errors are swallowed by the thread executor so we need to log them here - log_error_once("post", e) - raise - - def _update_run_single(self, run: Run) -> None: - """Update a run.""" - try: - run_dict = run.dict() - run_dict["tags"] = self._get_tags(run) - self.client.update_run(run.id, **run_dict) - except Exception as e: - # Errors are swallowed by the thread executor so we need to log them here - log_error_once("patch", e) - raise - - def _submit(self, function: Callable[[Run], None], run: Run) -> None: - """Submit a function to the executor.""" - if self.executor is None: - function(run) - else: - self._futures.add(self.executor.submit(function, run)) - - def _on_llm_start(self, run: Run) -> None: - """Persist an LLM run.""" - if run.parent_run_id is None: - run.reference_example_id = self.example_id - self._submit(self._persist_run_single, _copy(run)) - - def _on_chat_model_start(self, run: Run) -> None: - """Persist an LLM run.""" - if run.parent_run_id is None: - run.reference_example_id = self.example_id - self._submit(self._persist_run_single, _copy(run)) - - def _on_llm_end(self, run: Run) -> None: - """Process the LLM Run.""" - self._submit(self._update_run_single, _copy(run)) - - def _on_llm_error(self, run: Run) -> None: - """Process the LLM Run upon error.""" - self._submit(self._update_run_single, _copy(run)) - - def _on_chain_start(self, run: Run) -> None: - """Process the Chain Run upon start.""" - if run.parent_run_id is None: - run.reference_example_id = self.example_id - self._submit(self._persist_run_single, _copy(run)) - - def _on_chain_end(self, run: Run) -> None: - """Process the Chain Run.""" - self._submit(self._update_run_single, _copy(run)) - - def _on_chain_error(self, run: Run) -> None: - """Process the Chain Run upon error.""" - self._submit(self._update_run_single, _copy(run)) - - def _on_tool_start(self, run: Run) -> None: - """Process the Tool Run upon start.""" - if run.parent_run_id is None: - run.reference_example_id = self.example_id - self._submit(self._persist_run_single, _copy(run)) - - def _on_tool_end(self, run: Run) -> None: - """Process the Tool Run.""" - self._submit(self._update_run_single, _copy(run)) - - def _on_tool_error(self, run: Run) -> None: - """Process the Tool Run upon error.""" - self._submit(self._update_run_single, _copy(run)) - - def _on_retriever_start(self, run: Run) -> None: - """Process the Retriever Run upon start.""" - if run.parent_run_id is None: - run.reference_example_id = self.example_id - self._submit(self._persist_run_single, _copy(run)) - - def _on_retriever_end(self, run: Run) -> None: - """Process the Retriever Run.""" - self._submit(self._update_run_single, _copy(run)) - - def _on_retriever_error(self, run: Run) -> None: - """Process the Retriever Run upon error.""" - self._submit(self._update_run_single, _copy(run)) - - def wait_for_futures(self) -> None: - """Wait for the given futures to complete.""" - wait(self._futures) +__all__ = ["LangChainTracer", "wait_for_all_tracers"] diff --git a/libs/langchain/langchain/callbacks/tracers/langchain_v1.py b/libs/langchain/langchain/callbacks/tracers/langchain_v1.py index b83c2ccd2b2..056b5c4786a 100644 --- a/libs/langchain/langchain/callbacks/tracers/langchain_v1.py +++ b/libs/langchain/langchain/callbacks/tracers/langchain_v1.py @@ -1,185 +1,3 @@ -from __future__ import annotations +from langchain.schema.callbacks.tracers.langchain_v1 import LangChainTracerV1 -import logging -import os -from typing import Any, Dict, Optional, Union - -import requests - -from langchain.callbacks.tracers.base import BaseTracer -from langchain.callbacks.tracers.schemas import ( - ChainRun, - LLMRun, - Run, - ToolRun, - TracerSession, - TracerSessionV1, - TracerSessionV1Base, -) -from langchain.schema.messages import get_buffer_string -from langchain.utils import raise_for_status_with_text - -logger = logging.getLogger(__name__) - - -def get_headers() -> Dict[str, Any]: - """Get the headers for the LangChain API.""" - headers: Dict[str, Any] = {"Content-Type": "application/json"} - if os.getenv("LANGCHAIN_API_KEY"): - headers["x-api-key"] = os.getenv("LANGCHAIN_API_KEY") - return headers - - -def _get_endpoint() -> str: - return os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:8000") - - -class LangChainTracerV1(BaseTracer): - """An implementation of the SharedTracer that POSTS to the langchain endpoint.""" - - def __init__(self, **kwargs: Any) -> None: - """Initialize the LangChain tracer.""" - super().__init__(**kwargs) - self.session: Optional[TracerSessionV1] = None - self._endpoint = _get_endpoint() - self._headers = get_headers() - - def _convert_to_v1_run(self, run: Run) -> Union[LLMRun, ChainRun, ToolRun]: - session = self.session or self.load_default_session() - if not isinstance(session, TracerSessionV1): - raise ValueError( - "LangChainTracerV1 is not compatible with" - f" session of type {type(session)}" - ) - - if run.run_type == "llm": - if "prompts" in run.inputs: - prompts = run.inputs["prompts"] - elif "messages" in run.inputs: - prompts = [get_buffer_string(batch) for batch in run.inputs["messages"]] - else: - raise ValueError("No prompts found in LLM run inputs") - return LLMRun( - uuid=str(run.id) if run.id else None, - parent_uuid=str(run.parent_run_id) if run.parent_run_id else None, - start_time=run.start_time, - end_time=run.end_time, - extra=run.extra, - execution_order=run.execution_order, - child_execution_order=run.child_execution_order, - serialized=run.serialized, - session_id=session.id, - error=run.error, - prompts=prompts, - response=run.outputs if run.outputs else None, - ) - if run.run_type == "chain": - child_runs = [self._convert_to_v1_run(run) for run in run.child_runs] - return ChainRun( - uuid=str(run.id) if run.id else None, - parent_uuid=str(run.parent_run_id) if run.parent_run_id else None, - start_time=run.start_time, - end_time=run.end_time, - execution_order=run.execution_order, - child_execution_order=run.child_execution_order, - serialized=run.serialized, - session_id=session.id, - inputs=run.inputs, - outputs=run.outputs, - error=run.error, - extra=run.extra, - child_llm_runs=[run for run in child_runs if isinstance(run, LLMRun)], - child_chain_runs=[ - run for run in child_runs if isinstance(run, ChainRun) - ], - child_tool_runs=[run for run in child_runs if isinstance(run, ToolRun)], - ) - if run.run_type == "tool": - child_runs = [self._convert_to_v1_run(run) for run in run.child_runs] - return ToolRun( - uuid=str(run.id) if run.id else None, - parent_uuid=str(run.parent_run_id) if run.parent_run_id else None, - start_time=run.start_time, - end_time=run.end_time, - execution_order=run.execution_order, - child_execution_order=run.child_execution_order, - serialized=run.serialized, - session_id=session.id, - action=str(run.serialized), - tool_input=run.inputs.get("input", ""), - output=None if run.outputs is None else run.outputs.get("output"), - error=run.error, - extra=run.extra, - child_chain_runs=[ - run for run in child_runs if isinstance(run, ChainRun) - ], - child_tool_runs=[run for run in child_runs if isinstance(run, ToolRun)], - child_llm_runs=[run for run in child_runs if isinstance(run, LLMRun)], - ) - raise ValueError(f"Unknown run type: {run.run_type}") - - def _persist_run(self, run: Union[Run, LLMRun, ChainRun, ToolRun]) -> None: - """Persist a run.""" - if isinstance(run, Run): - v1_run = self._convert_to_v1_run(run) - else: - v1_run = run - if isinstance(v1_run, LLMRun): - endpoint = f"{self._endpoint}/llm-runs" - elif isinstance(v1_run, ChainRun): - endpoint = f"{self._endpoint}/chain-runs" - else: - endpoint = f"{self._endpoint}/tool-runs" - - try: - response = requests.post( - endpoint, - data=v1_run.json(), - headers=self._headers, - ) - raise_for_status_with_text(response) - except Exception as e: - logger.warning(f"Failed to persist run: {e}") - - def _persist_session( - self, session_create: TracerSessionV1Base - ) -> Union[TracerSessionV1, TracerSession]: - """Persist a session.""" - try: - r = requests.post( - f"{self._endpoint}/sessions", - data=session_create.json(), - headers=self._headers, - ) - session = TracerSessionV1(id=r.json()["id"], **session_create.dict()) - except Exception as e: - logger.warning(f"Failed to create session, using default session: {e}") - session = TracerSessionV1(id=1, **session_create.dict()) - return session - - def _load_session(self, session_name: Optional[str] = None) -> TracerSessionV1: - """Load a session from the tracer.""" - try: - url = f"{self._endpoint}/sessions" - if session_name: - url += f"?name={session_name}" - r = requests.get(url, headers=self._headers) - - tracer_session = TracerSessionV1(**r.json()[0]) - except Exception as e: - session_type = "default" if not session_name else session_name - logger.warning( - f"Failed to load {session_type} session, using empty session: {e}" - ) - tracer_session = TracerSessionV1(id=1) - - self.session = tracer_session - return tracer_session - - def load_session(self, session_name: str) -> Union[TracerSessionV1, TracerSession]: - """Load a session with the given name from the tracer.""" - return self._load_session(session_name) - - def load_default_session(self) -> Union[TracerSessionV1, TracerSession]: - """Load the default tracing session and set it as the Tracer's session.""" - return self._load_session("default") +__all__ = ["LangChainTracerV1"] diff --git a/libs/langchain/langchain/callbacks/tracers/log_stream.py b/libs/langchain/langchain/callbacks/tracers/log_stream.py index 1bca4098f83..6630dd6e53f 100644 --- a/libs/langchain/langchain/callbacks/tracers/log_stream.py +++ b/libs/langchain/langchain/callbacks/tracers/log_stream.py @@ -1,311 +1,9 @@ -from __future__ import annotations - -import math -import threading -from collections import defaultdict -from typing import ( - Any, - AsyncIterator, - Dict, - List, - Optional, - Sequence, - TypedDict, - Union, +from langchain.schema.callbacks.tracers.log_stream import ( + LogEntry, + LogStreamCallbackHandler, + RunLog, + RunLogPatch, + RunState, ) -from uuid import UUID -import jsonpatch -from anyio import create_memory_object_stream - -from langchain.callbacks.tracers.base import BaseTracer -from langchain.callbacks.tracers.schemas import Run -from langchain.load.load import load -from langchain.schema.output import ChatGenerationChunk, GenerationChunk - - -class LogEntry(TypedDict): - """A single entry in the run log.""" - - id: str - """ID of the sub-run.""" - name: str - """Name of the object being run.""" - type: str - """Type of the object being run, eg. prompt, chain, llm, etc.""" - tags: List[str] - """List of tags for the run.""" - metadata: Dict[str, Any] - """Key-value pairs of metadata for the run.""" - start_time: str - """ISO-8601 timestamp of when the run started.""" - - streamed_output_str: List[str] - """List of LLM tokens streamed by this run, if applicable.""" - final_output: Optional[Any] - """Final output of this run. - Only available after the run has finished successfully.""" - end_time: Optional[str] - """ISO-8601 timestamp of when the run ended. - Only available after the run has finished.""" - - -class RunState(TypedDict): - """State of the run.""" - - id: str - """ID of the run.""" - streamed_output: List[Any] - """List of output chunks streamed by Runnable.stream()""" - final_output: Optional[Any] - """Final output of the run, usually the result of aggregating (`+`) streamed_output. - Only available after the run has finished successfully.""" - - logs: Dict[str, LogEntry] - """Map of run names to sub-runs. If filters were supplied, this list will - contain only the runs that matched the filters.""" - - -class RunLogPatch: - """A patch to the run log.""" - - ops: List[Dict[str, Any]] - """List of jsonpatch operations, which describe how to create the run state - from an empty dict. This is the minimal representation of the log, designed to - be serialized as JSON and sent over the wire to reconstruct the log on the other - side. Reconstruction of the state can be done with any jsonpatch-compliant library, - see https://jsonpatch.com for more information.""" - - def __init__(self, *ops: Dict[str, Any]) -> None: - self.ops = list(ops) - - def __add__(self, other: Union[RunLogPatch, Any]) -> RunLog: - if type(other) == RunLogPatch: - ops = self.ops + other.ops - state = jsonpatch.apply_patch(None, ops) - return RunLog(*ops, state=state) - - raise TypeError( - f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'" - ) - - def __repr__(self) -> str: - from pprint import pformat - - # 1:-1 to get rid of the [] around the list - return f"RunLogPatch({pformat(self.ops)[1:-1]})" - - def __eq__(self, other: object) -> bool: - return isinstance(other, RunLogPatch) and self.ops == other.ops - - -class RunLog(RunLogPatch): - """A run log.""" - - state: RunState - """Current state of the log, obtained from applying all ops in sequence.""" - - def __init__(self, *ops: Dict[str, Any], state: RunState) -> None: - super().__init__(*ops) - self.state = state - - def __add__(self, other: Union[RunLogPatch, Any]) -> RunLog: - if type(other) == RunLogPatch: - ops = self.ops + other.ops - state = jsonpatch.apply_patch(self.state, other.ops) - return RunLog(*ops, state=state) - - raise TypeError( - f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'" - ) - - def __repr__(self) -> str: - from pprint import pformat - - return f"RunLog({pformat(self.state)})" - - -class LogStreamCallbackHandler(BaseTracer): - """A tracer that streams run logs to a stream.""" - - def __init__( - self, - *, - auto_close: bool = True, - include_names: Optional[Sequence[str]] = None, - include_types: Optional[Sequence[str]] = None, - include_tags: Optional[Sequence[str]] = None, - exclude_names: Optional[Sequence[str]] = None, - exclude_types: Optional[Sequence[str]] = None, - exclude_tags: Optional[Sequence[str]] = None, - ) -> None: - super().__init__() - - self.auto_close = auto_close - self.include_names = include_names - self.include_types = include_types - self.include_tags = include_tags - self.exclude_names = exclude_names - self.exclude_types = exclude_types - self.exclude_tags = exclude_tags - - send_stream, receive_stream = create_memory_object_stream( - math.inf, item_type=RunLogPatch - ) - self.lock = threading.Lock() - self.send_stream = send_stream - self.receive_stream = receive_stream - self._key_map_by_run_id: Dict[UUID, str] = {} - self._counter_map_by_name: Dict[str, int] = defaultdict(int) - self.root_id: Optional[UUID] = None - - def __aiter__(self) -> AsyncIterator[RunLogPatch]: - return self.receive_stream.__aiter__() - - def include_run(self, run: Run) -> bool: - if run.id == self.root_id: - return False - - run_tags = run.tags or [] - - if ( - self.include_names is None - and self.include_types is None - and self.include_tags is None - ): - include = True - else: - include = False - - if self.include_names is not None: - include = include or run.name in self.include_names - if self.include_types is not None: - include = include or run.run_type in self.include_types - if self.include_tags is not None: - include = include or any(tag in self.include_tags for tag in run_tags) - - if self.exclude_names is not None: - include = include and run.name not in self.exclude_names - if self.exclude_types is not None: - include = include and run.run_type not in self.exclude_types - if self.exclude_tags is not None: - include = include and all(tag not in self.exclude_tags for tag in run_tags) - - return include - - def _persist_run(self, run: Run) -> None: - # This is a legacy method only called once for an entire run tree - # therefore not useful here - pass - - def _on_run_create(self, run: Run) -> None: - """Start a run.""" - if self.root_id is None: - self.root_id = run.id - self.send_stream.send_nowait( - RunLogPatch( - { - "op": "replace", - "path": "", - "value": RunState( - id=str(run.id), - streamed_output=[], - final_output=None, - logs={}, - ), - } - ) - ) - - if not self.include_run(run): - return - - # Determine previous index, increment by 1 - with self.lock: - self._counter_map_by_name[run.name] += 1 - count = self._counter_map_by_name[run.name] - self._key_map_by_run_id[run.id] = ( - run.name if count == 1 else f"{run.name}:{count}" - ) - - # Add the run to the stream - self.send_stream.send_nowait( - RunLogPatch( - { - "op": "add", - "path": f"/logs/{self._key_map_by_run_id[run.id]}", - "value": LogEntry( - id=str(run.id), - name=run.name, - type=run.run_type, - tags=run.tags or [], - metadata=(run.extra or {}).get("metadata", {}), - start_time=run.start_time.isoformat(timespec="milliseconds"), - streamed_output_str=[], - final_output=None, - end_time=None, - ), - } - ) - ) - - def _on_run_update(self, run: Run) -> None: - """Finish a run.""" - try: - index = self._key_map_by_run_id.get(run.id) - - if index is None: - return - - self.send_stream.send_nowait( - RunLogPatch( - { - "op": "add", - "path": f"/logs/{index}/final_output", - # to undo the dumpd done by some runnables / tracer / etc - "value": load(run.outputs), - }, - { - "op": "add", - "path": f"/logs/{index}/end_time", - "value": run.end_time.isoformat(timespec="milliseconds") - if run.end_time is not None - else None, - }, - ) - ) - finally: - if run.id == self.root_id: - self.send_stream.send_nowait( - RunLogPatch( - { - "op": "replace", - "path": "/final_output", - "value": load(run.outputs), - } - ) - ) - if self.auto_close: - self.send_stream.close() - - def _on_llm_new_token( - self, - run: Run, - token: str, - chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]], - ) -> None: - """Process new LLM token.""" - index = self._key_map_by_run_id.get(run.id) - - if index is None: - return - - self.send_stream.send_nowait( - RunLogPatch( - { - "op": "add", - "path": f"/logs/{index}/streamed_output_str/-", - "value": token, - } - ) - ) +__all__ = ["LogEntry", "RunState", "RunLog", "RunLogPatch", "LogStreamCallbackHandler"] diff --git a/libs/langchain/langchain/callbacks/tracers/root_listeners.py b/libs/langchain/langchain/callbacks/tracers/root_listeners.py index af489bfe9ab..2eceb6db7cb 100644 --- a/libs/langchain/langchain/callbacks/tracers/root_listeners.py +++ b/libs/langchain/langchain/callbacks/tracers/root_listeners.py @@ -1,54 +1,3 @@ -from typing import Callable, Optional, Union -from uuid import UUID +from langchain.schema.callbacks.tracers.root_listeners import RootListenersTracer -from langchain.callbacks.tracers.base import BaseTracer -from langchain.callbacks.tracers.schemas import Run -from langchain.schema.runnable.config import ( - RunnableConfig, - call_func_with_variable_args, -) - -Listener = Union[Callable[[Run], None], Callable[[Run, RunnableConfig], None]] - - -class RootListenersTracer(BaseTracer): - def __init__( - self, - *, - config: RunnableConfig, - on_start: Optional[Listener], - on_end: Optional[Listener], - on_error: Optional[Listener], - ) -> None: - super().__init__() - - self.config = config - self._arg_on_start = on_start - self._arg_on_end = on_end - self._arg_on_error = on_error - self.root_id: Optional[UUID] = None - - def _persist_run(self, run: Run) -> None: - # This is a legacy method only called once for an entire run tree - # therefore not useful here - pass - - def _on_run_create(self, run: Run) -> None: - if self.root_id is not None: - return - - self.root_id = run.id - - if self._arg_on_start is not None: - call_func_with_variable_args(self._arg_on_start, run, self.config) - - def _on_run_update(self, run: Run) -> None: - if run.id != self.root_id: - return - - if run.error is None: - if self._arg_on_end is not None: - call_func_with_variable_args(self._arg_on_end, run, self.config) - else: - if self._arg_on_error is not None: - call_func_with_variable_args(self._arg_on_error, run, self.config) +__all__ = ["RootListenersTracer"] diff --git a/libs/langchain/langchain/callbacks/tracers/run_collector.py b/libs/langchain/langchain/callbacks/tracers/run_collector.py index dc0d6eab1e5..da4b7ee8d8e 100644 --- a/libs/langchain/langchain/callbacks/tracers/run_collector.py +++ b/libs/langchain/langchain/callbacks/tracers/run_collector.py @@ -1,52 +1,3 @@ -"""A tracer that collects all nested runs in a list.""" +from langchain.schema.callbacks.tracers.run_collector import RunCollectorCallbackHandler -from typing import Any, List, Optional, Union -from uuid import UUID - -from langchain.callbacks.tracers.base import BaseTracer -from langchain.callbacks.tracers.schemas import Run - - -class RunCollectorCallbackHandler(BaseTracer): - """ - A tracer that collects all nested runs in a list. - - This tracer is useful for inspection and evaluation purposes. - - Parameters - ---------- - example_id : Optional[Union[UUID, str]], default=None - The ID of the example being traced. It can be either a UUID or a string. - """ - - name: str = "run-collector_callback_handler" - - def __init__( - self, example_id: Optional[Union[UUID, str]] = None, **kwargs: Any - ) -> None: - """ - Initialize the RunCollectorCallbackHandler. - - Parameters - ---------- - example_id : Optional[Union[UUID, str]], default=None - The ID of the example being traced. It can be either a UUID or a string. - """ - super().__init__(**kwargs) - self.example_id = ( - UUID(example_id) if isinstance(example_id, str) else example_id - ) - self.traced_runs: List[Run] = [] - - def _persist_run(self, run: Run) -> None: - """ - Persist a run by adding it to the traced_runs list. - - Parameters - ---------- - run : Run - The run to be persisted. - """ - run_ = run.copy() - run_.reference_example_id = self.example_id - self.traced_runs.append(run_) +__all__ = ["RunCollectorCallbackHandler"] diff --git a/libs/langchain/langchain/callbacks/tracers/schemas.py b/libs/langchain/langchain/callbacks/tracers/schemas.py index 4db455be2ea..b4445454891 100644 --- a/libs/langchain/langchain/callbacks/tracers/schemas.py +++ b/libs/langchain/langchain/callbacks/tracers/schemas.py @@ -1,129 +1,16 @@ -"""Schemas for tracers.""" -from __future__ import annotations - -import datetime -import warnings -from typing import Any, Dict, List, Optional, Type -from uuid import UUID - -from langsmith.schemas import RunBase as BaseRunV2 -from langsmith.schemas import RunTypeEnum as RunTypeEnumDep - -from langchain.pydantic_v1 import BaseModel, Field, root_validator -from langchain.schema import LLMResult - - -def RunTypeEnum() -> Type[RunTypeEnumDep]: - """RunTypeEnum.""" - warnings.warn( - "RunTypeEnum is deprecated. Please directly use a string instead" - " (e.g. 'llm', 'chain', 'tool').", - DeprecationWarning, - ) - return RunTypeEnumDep - - -class TracerSessionV1Base(BaseModel): - """Base class for TracerSessionV1.""" - - start_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow) - name: Optional[str] = None - extra: Optional[Dict[str, Any]] = None - - -class TracerSessionV1Create(TracerSessionV1Base): - """Create class for TracerSessionV1.""" - - -class TracerSessionV1(TracerSessionV1Base): - """TracerSessionV1 schema.""" - - id: int - - -class TracerSessionBase(TracerSessionV1Base): - """Base class for TracerSession.""" - - tenant_id: UUID - - -class TracerSession(TracerSessionBase): - """TracerSessionV1 schema for the V2 API.""" - - id: UUID - - -class BaseRun(BaseModel): - """Base class for Run.""" - - uuid: str - parent_uuid: Optional[str] = None - start_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow) - end_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow) - extra: Optional[Dict[str, Any]] = None - execution_order: int - child_execution_order: int - serialized: Dict[str, Any] - session_id: int - error: Optional[str] = None - - -class LLMRun(BaseRun): - """Class for LLMRun.""" - - prompts: List[str] - response: Optional[LLMResult] = None - - -class ChainRun(BaseRun): - """Class for ChainRun.""" - - inputs: Dict[str, Any] - outputs: Optional[Dict[str, Any]] = None - child_llm_runs: List[LLMRun] = Field(default_factory=list) - child_chain_runs: List[ChainRun] = Field(default_factory=list) - child_tool_runs: List[ToolRun] = Field(default_factory=list) - - -class ToolRun(BaseRun): - """Class for ToolRun.""" - - tool_input: str - output: Optional[str] = None - action: str - child_llm_runs: List[LLMRun] = Field(default_factory=list) - child_chain_runs: List[ChainRun] = Field(default_factory=list) - child_tool_runs: List[ToolRun] = Field(default_factory=list) - - -# Begin V2 API Schemas - - -class Run(BaseRunV2): - """Run schema for the V2 API in the Tracer.""" - - execution_order: int - child_execution_order: int - child_runs: List[Run] = Field(default_factory=list) - tags: Optional[List[str]] = Field(default_factory=list) - events: List[Dict[str, Any]] = Field(default_factory=list) - - @root_validator(pre=True) - def assign_name(cls, values: dict) -> dict: - """Assign name to the run.""" - if values.get("name") is None: - if "name" in values["serialized"]: - values["name"] = values["serialized"]["name"] - elif "id" in values["serialized"]: - values["name"] = values["serialized"]["id"][-1] - if values.get("events") is None: - values["events"] = [] - return values - - -ChainRun.update_forward_refs() -ToolRun.update_forward_refs() -Run.update_forward_refs() +from langchain.schema.callbacks.tracers.schemas import ( + BaseRun, + ChainRun, + LLMRun, + Run, + RunTypeEnum, + ToolRun, + TracerSession, + TracerSessionBase, + TracerSessionV1, + TracerSessionV1Base, + TracerSessionV1Create, +) __all__ = [ "BaseRun", diff --git a/libs/langchain/langchain/callbacks/tracers/stdout.py b/libs/langchain/langchain/callbacks/tracers/stdout.py index cd05f1898de..12e8a187da2 100644 --- a/libs/langchain/langchain/callbacks/tracers/stdout.py +++ b/libs/langchain/langchain/callbacks/tracers/stdout.py @@ -1,178 +1,6 @@ -import json -from typing import Any, Callable, List +from langchain.schema.callbacks.tracers.stdout import ( + ConsoleCallbackHandler, + FunctionCallbackHandler, +) -from langchain.callbacks.tracers.base import BaseTracer -from langchain.callbacks.tracers.schemas import Run -from langchain.utils.input import get_bolded_text, get_colored_text - - -def try_json_stringify(obj: Any, fallback: str) -> str: - """ - Try to stringify an object to JSON. - Args: - obj: Object to stringify. - fallback: Fallback string to return if the object cannot be stringified. - - Returns: - A JSON string if the object can be stringified, otherwise the fallback string. - - """ - try: - return json.dumps(obj, indent=2, ensure_ascii=False) - except Exception: - return fallback - - -def elapsed(run: Any) -> str: - """Get the elapsed time of a run. - - Args: - run: any object with a start_time and end_time attribute. - - Returns: - A string with the elapsed time in seconds or - milliseconds if time is less than a second. - - """ - elapsed_time = run.end_time - run.start_time - milliseconds = elapsed_time.total_seconds() * 1000 - if milliseconds < 1000: - return f"{milliseconds:.0f}ms" - return f"{(milliseconds / 1000):.2f}s" - - -class FunctionCallbackHandler(BaseTracer): - """Tracer that calls a function with a single str parameter.""" - - name: str = "function_callback_handler" - - def __init__(self, function: Callable[[str], None], **kwargs: Any) -> None: - super().__init__(**kwargs) - self.function_callback = function - - def _persist_run(self, run: Run) -> None: - pass - - def get_parents(self, run: Run) -> List[Run]: - parents = [] - current_run = run - while current_run.parent_run_id: - parent = self.run_map.get(str(current_run.parent_run_id)) - if parent: - parents.append(parent) - current_run = parent - else: - break - return parents - - def get_breadcrumbs(self, run: Run) -> str: - parents = self.get_parents(run)[::-1] - string = " > ".join( - f"{parent.execution_order}:{parent.run_type}:{parent.name}" - if i != len(parents) - 1 - else f"{parent.execution_order}:{parent.run_type}:{parent.name}" - for i, parent in enumerate(parents + [run]) - ) - return string - - # logging methods - def _on_chain_start(self, run: Run) -> None: - crumbs = self.get_breadcrumbs(run) - run_type = run.run_type.capitalize() - self.function_callback( - f"{get_colored_text('[chain/start]', color='green')} " - + get_bolded_text(f"[{crumbs}] Entering {run_type} run with input:\n") - + f"{try_json_stringify(run.inputs, '[inputs]')}" - ) - - def _on_chain_end(self, run: Run) -> None: - crumbs = self.get_breadcrumbs(run) - run_type = run.run_type.capitalize() - self.function_callback( - f"{get_colored_text('[chain/end]', color='blue')} " - + get_bolded_text( - f"[{crumbs}] [{elapsed(run)}] Exiting {run_type} run with output:\n" - ) - + f"{try_json_stringify(run.outputs, '[outputs]')}" - ) - - def _on_chain_error(self, run: Run) -> None: - crumbs = self.get_breadcrumbs(run) - run_type = run.run_type.capitalize() - self.function_callback( - f"{get_colored_text('[chain/error]', color='red')} " - + get_bolded_text( - f"[{crumbs}] [{elapsed(run)}] {run_type} run errored with error:\n" - ) - + f"{try_json_stringify(run.error, '[error]')}" - ) - - def _on_llm_start(self, run: Run) -> None: - crumbs = self.get_breadcrumbs(run) - inputs = ( - {"prompts": [p.strip() for p in run.inputs["prompts"]]} - if "prompts" in run.inputs - else run.inputs - ) - self.function_callback( - f"{get_colored_text('[llm/start]', color='green')} " - + get_bolded_text(f"[{crumbs}] Entering LLM run with input:\n") - + f"{try_json_stringify(inputs, '[inputs]')}" - ) - - def _on_llm_end(self, run: Run) -> None: - crumbs = self.get_breadcrumbs(run) - self.function_callback( - f"{get_colored_text('[llm/end]', color='blue')} " - + get_bolded_text( - f"[{crumbs}] [{elapsed(run)}] Exiting LLM run with output:\n" - ) - + f"{try_json_stringify(run.outputs, '[response]')}" - ) - - def _on_llm_error(self, run: Run) -> None: - crumbs = self.get_breadcrumbs(run) - self.function_callback( - f"{get_colored_text('[llm/error]', color='red')} " - + get_bolded_text( - f"[{crumbs}] [{elapsed(run)}] LLM run errored with error:\n" - ) - + f"{try_json_stringify(run.error, '[error]')}" - ) - - def _on_tool_start(self, run: Run) -> None: - crumbs = self.get_breadcrumbs(run) - self.function_callback( - f'{get_colored_text("[tool/start]", color="green")} ' - + get_bolded_text(f"[{crumbs}] Entering Tool run with input:\n") - + f'"{run.inputs["input"].strip()}"' - ) - - def _on_tool_end(self, run: Run) -> None: - crumbs = self.get_breadcrumbs(run) - if run.outputs: - self.function_callback( - f'{get_colored_text("[tool/end]", color="blue")} ' - + get_bolded_text( - f"[{crumbs}] [{elapsed(run)}] Exiting Tool run with output:\n" - ) - + f'"{run.outputs["output"].strip()}"' - ) - - def _on_tool_error(self, run: Run) -> None: - crumbs = self.get_breadcrumbs(run) - self.function_callback( - f"{get_colored_text('[tool/error]', color='red')} " - + get_bolded_text(f"[{crumbs}] [{elapsed(run)}] ") - + f"Tool run errored with error:\n" - f"{run.error}" - ) - - -class ConsoleCallbackHandler(FunctionCallbackHandler): - """Tracer that prints to the console.""" - - name: str = "console_callback_handler" - - def __init__(self, **kwargs: Any) -> None: - super().__init__(function=print, **kwargs) +__all__ = ["FunctionCallbackHandler", "ConsoleCallbackHandler"] diff --git a/libs/langchain/langchain/schema/callbacks/__init__.py b/libs/langchain/langchain/schema/callbacks/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/libs/langchain/langchain/schema/callbacks/base.py b/libs/langchain/langchain/schema/callbacks/base.py new file mode 100644 index 00000000000..359496aa32a --- /dev/null +++ b/libs/langchain/langchain/schema/callbacks/base.py @@ -0,0 +1,598 @@ +"""Base callback handler that can be used to handle callbacks in langchain.""" +from __future__ import annotations + +from typing import Any, Dict, List, Optional, Sequence, TypeVar, Union +from uuid import UUID + +from tenacity import RetryCallState + +from langchain.schema.agent import AgentAction, AgentFinish +from langchain.schema.document import Document +from langchain.schema.messages import BaseMessage +from langchain.schema.output import ChatGenerationChunk, GenerationChunk, LLMResult + + +class RetrieverManagerMixin: + """Mixin for Retriever callbacks.""" + + def on_retriever_error( + self, + error: BaseException, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run when Retriever errors.""" + + def on_retriever_end( + self, + documents: Sequence[Document], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run when Retriever ends running.""" + + +class LLMManagerMixin: + """Mixin for LLM callbacks.""" + + def on_llm_new_token( + self, + token: str, + *, + chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run on new LLM token. Only available when streaming is enabled. + + Args: + token (str): The new token. + chunk (GenerationChunk | ChatGenerationChunk): The new generated chunk, + containing content and other information. + """ + + def on_llm_end( + self, + response: LLMResult, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run when LLM ends running.""" + + def on_llm_error( + self, + error: BaseException, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run when LLM errors.""" + + +class ChainManagerMixin: + """Mixin for chain callbacks.""" + + def on_chain_end( + self, + outputs: Dict[str, Any], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run when chain ends running.""" + + def on_chain_error( + self, + error: BaseException, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run when chain errors.""" + + def on_agent_action( + self, + action: AgentAction, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run on agent action.""" + + def on_agent_finish( + self, + finish: AgentFinish, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run on agent end.""" + + +class ToolManagerMixin: + """Mixin for tool callbacks.""" + + def on_tool_end( + self, + output: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run when tool ends running.""" + + def on_tool_error( + self, + error: BaseException, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run when tool errors.""" + + +class CallbackManagerMixin: + """Mixin for callback manager.""" + + def on_llm_start( + self, + serialized: Dict[str, Any], + prompts: List[str], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> Any: + """Run when LLM starts running.""" + + def on_chat_model_start( + self, + serialized: Dict[str, Any], + messages: List[List[BaseMessage]], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> Any: + """Run when a chat model starts running.""" + raise NotImplementedError( + f"{self.__class__.__name__} does not implement `on_chat_model_start`" + ) + + def on_retriever_start( + self, + serialized: Dict[str, Any], + query: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> Any: + """Run when Retriever starts running.""" + + def on_chain_start( + self, + serialized: Dict[str, Any], + inputs: Dict[str, Any], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> Any: + """Run when chain starts running.""" + + def on_tool_start( + self, + serialized: Dict[str, Any], + input_str: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> Any: + """Run when tool starts running.""" + + +class RunManagerMixin: + """Mixin for run manager.""" + + def on_text( + self, + text: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run on arbitrary text.""" + + def on_retry( + self, + retry_state: RetryCallState, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run on a retry event.""" + + +class BaseCallbackHandler( + LLMManagerMixin, + ChainManagerMixin, + ToolManagerMixin, + RetrieverManagerMixin, + CallbackManagerMixin, + RunManagerMixin, +): + """Base callback handler that handles callbacks from LangChain.""" + + raise_error: bool = False + + run_inline: bool = False + + @property + def ignore_llm(self) -> bool: + """Whether to ignore LLM callbacks.""" + return False + + @property + def ignore_retry(self) -> bool: + """Whether to ignore retry callbacks.""" + return False + + @property + def ignore_chain(self) -> bool: + """Whether to ignore chain callbacks.""" + return False + + @property + def ignore_agent(self) -> bool: + """Whether to ignore agent callbacks.""" + return False + + @property + def ignore_retriever(self) -> bool: + """Whether to ignore retriever callbacks.""" + return False + + @property + def ignore_chat_model(self) -> bool: + """Whether to ignore chat model callbacks.""" + return False + + +class AsyncCallbackHandler(BaseCallbackHandler): + """Async callback handler that handles callbacks from LangChain.""" + + async def on_llm_start( + self, + serialized: Dict[str, Any], + prompts: List[str], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> None: + """Run when LLM starts running.""" + + async def on_chat_model_start( + self, + serialized: Dict[str, Any], + messages: List[List[BaseMessage]], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> Any: + """Run when a chat model starts running.""" + raise NotImplementedError( + f"{self.__class__.__name__} does not implement `on_chat_model_start`" + ) + + async def on_llm_new_token( + self, + token: str, + *, + chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + """Run on new LLM token. Only available when streaming is enabled.""" + + async def on_llm_end( + self, + response: LLMResult, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + """Run when LLM ends running.""" + + async def on_llm_error( + self, + error: BaseException, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + """Run when LLM errors.""" + + async def on_chain_start( + self, + serialized: Dict[str, Any], + inputs: Dict[str, Any], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> None: + """Run when chain starts running.""" + + async def on_chain_end( + self, + outputs: Dict[str, Any], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + """Run when chain ends running.""" + + async def on_chain_error( + self, + error: BaseException, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + """Run when chain errors.""" + + async def on_tool_start( + self, + serialized: Dict[str, Any], + input_str: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> None: + """Run when tool starts running.""" + + async def on_tool_end( + self, + output: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + """Run when tool ends running.""" + + async def on_tool_error( + self, + error: BaseException, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + """Run when tool errors.""" + + async def on_text( + self, + text: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + """Run on arbitrary text.""" + + async def on_retry( + self, + retry_state: RetryCallState, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run on a retry event.""" + + async def on_agent_action( + self, + action: AgentAction, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + """Run on agent action.""" + + async def on_agent_finish( + self, + finish: AgentFinish, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + """Run on agent end.""" + + async def on_retriever_start( + self, + serialized: Dict[str, Any], + query: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> None: + """Run on retriever start.""" + + async def on_retriever_end( + self, + documents: Sequence[Document], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + """Run on retriever end.""" + + async def on_retriever_error( + self, + error: BaseException, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + """Run on retriever error.""" + + +T = TypeVar("T", bound="BaseCallbackManager") + + +class BaseCallbackManager(CallbackManagerMixin): + """Base callback manager that handles callbacks from LangChain.""" + + def __init__( + self, + handlers: List[BaseCallbackHandler], + inheritable_handlers: Optional[List[BaseCallbackHandler]] = None, + parent_run_id: Optional[UUID] = None, + *, + tags: Optional[List[str]] = None, + inheritable_tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + inheritable_metadata: Optional[Dict[str, Any]] = None, + ) -> None: + """Initialize callback manager.""" + self.handlers: List[BaseCallbackHandler] = handlers + self.inheritable_handlers: List[BaseCallbackHandler] = ( + inheritable_handlers or [] + ) + self.parent_run_id: Optional[UUID] = parent_run_id + self.tags = tags or [] + self.inheritable_tags = inheritable_tags or [] + self.metadata = metadata or {} + self.inheritable_metadata = inheritable_metadata or {} + + def copy(self: T) -> T: + """Copy the callback manager.""" + return self.__class__( + handlers=self.handlers, + inheritable_handlers=self.inheritable_handlers, + parent_run_id=self.parent_run_id, + tags=self.tags, + inheritable_tags=self.inheritable_tags, + metadata=self.metadata, + inheritable_metadata=self.inheritable_metadata, + ) + + @property + def is_async(self) -> bool: + """Whether the callback manager is async.""" + return False + + def add_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None: + """Add a handler to the callback manager.""" + if handler not in self.handlers: + self.handlers.append(handler) + if inherit and handler not in self.inheritable_handlers: + self.inheritable_handlers.append(handler) + + def remove_handler(self, handler: BaseCallbackHandler) -> None: + """Remove a handler from the callback manager.""" + self.handlers.remove(handler) + self.inheritable_handlers.remove(handler) + + def set_handlers( + self, handlers: List[BaseCallbackHandler], inherit: bool = True + ) -> None: + """Set handlers as the only handlers on the callback manager.""" + self.handlers = [] + self.inheritable_handlers = [] + for handler in handlers: + self.add_handler(handler, inherit=inherit) + + def set_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None: + """Set handler as the only handler on the callback manager.""" + self.set_handlers([handler], inherit=inherit) + + def add_tags(self, tags: List[str], inherit: bool = True) -> None: + for tag in tags: + if tag in self.tags: + self.remove_tags([tag]) + self.tags.extend(tags) + if inherit: + self.inheritable_tags.extend(tags) + + def remove_tags(self, tags: List[str]) -> None: + for tag in tags: + self.tags.remove(tag) + self.inheritable_tags.remove(tag) + + def add_metadata(self, metadata: Dict[str, Any], inherit: bool = True) -> None: + self.metadata.update(metadata) + if inherit: + self.inheritable_metadata.update(metadata) + + def remove_metadata(self, keys: List[str]) -> None: + for key in keys: + self.metadata.pop(key) + self.inheritable_metadata.pop(key) + + +Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]] diff --git a/libs/langchain/langchain/schema/callbacks/manager.py b/libs/langchain/langchain/schema/callbacks/manager.py new file mode 100644 index 00000000000..0491f1fed74 --- /dev/null +++ b/libs/langchain/langchain/schema/callbacks/manager.py @@ -0,0 +1,2075 @@ +from __future__ import annotations + +import asyncio +import functools +import logging +import os +import uuid +from concurrent.futures import ThreadPoolExecutor +from contextlib import asynccontextmanager, contextmanager +from contextvars import ContextVar +from typing import ( + TYPE_CHECKING, + Any, + AsyncGenerator, + Coroutine, + Dict, + Generator, + List, + Optional, + Sequence, + Tuple, + Type, + TypeVar, + Union, + cast, +) +from uuid import UUID + +from langsmith import utils as ls_utils +from langsmith.run_helpers import get_run_tree_context +from tenacity import RetryCallState + +from langchain.schema import ( + AgentAction, + AgentFinish, + Document, + LLMResult, +) +from langchain.schema.callbacks.base import ( + BaseCallbackHandler, + BaseCallbackManager, + Callbacks, + ChainManagerMixin, + LLMManagerMixin, + RetrieverManagerMixin, + RunManagerMixin, + ToolManagerMixin, +) +from langchain.schema.callbacks.stdout import StdOutCallbackHandler +from langchain.schema.callbacks.tracers import run_collector +from langchain.schema.callbacks.tracers.langchain import ( + LangChainTracer, +) +from langchain.schema.callbacks.tracers.langchain_v1 import ( + LangChainTracerV1, + TracerSessionV1, +) +from langchain.schema.callbacks.tracers.stdout import ConsoleCallbackHandler +from langchain.schema.messages import BaseMessage, get_buffer_string +from langchain.schema.output import ChatGenerationChunk, GenerationChunk + +if TYPE_CHECKING: + from langsmith import Client as LangSmithClient + +logger = logging.getLogger(__name__) + +tracing_callback_var: ContextVar[Optional[LangChainTracerV1]] = ContextVar( # noqa: E501 + "tracing_callback", default=None +) + +tracing_v2_callback_var: ContextVar[Optional[LangChainTracer]] = ContextVar( # noqa: E501 + "tracing_callback_v2", default=None +) +run_collector_var: ContextVar[ + Optional[run_collector.RunCollectorCallbackHandler] +] = ContextVar( # noqa: E501 + "run_collector", default=None +) + + +def _get_debug() -> bool: + from langchain.globals import get_debug + + return get_debug() + + +@contextmanager +def tracing_enabled( + session_name: str = "default", +) -> Generator[TracerSessionV1, None, None]: + """Get the Deprecated LangChainTracer in a context manager. + + Args: + session_name (str, optional): The name of the session. + Defaults to "default". + + Returns: + TracerSessionV1: The LangChainTracer session. + + Example: + >>> with tracing_enabled() as session: + ... # Use the LangChainTracer session + """ + cb = LangChainTracerV1() + session = cast(TracerSessionV1, cb.load_session(session_name)) + try: + tracing_callback_var.set(cb) + yield session + finally: + tracing_callback_var.set(None) + + +@contextmanager +def tracing_v2_enabled( + project_name: Optional[str] = None, + *, + example_id: Optional[Union[str, UUID]] = None, + tags: Optional[List[str]] = None, + client: Optional[LangSmithClient] = None, +) -> Generator[LangChainTracer, None, None]: + """Instruct LangChain to log all runs in context to LangSmith. + + Args: + project_name (str, optional): The name of the project. + Defaults to "default". + example_id (str or UUID, optional): The ID of the example. + Defaults to None. + tags (List[str], optional): The tags to add to the run. + Defaults to None. + + Returns: + None + + Example: + >>> with tracing_v2_enabled(): + ... # LangChain code will automatically be traced + + You can use this to fetch the LangSmith run URL: + + >>> with tracing_v2_enabled() as cb: + ... chain.invoke("foo") + ... run_url = cb.get_run_url() + """ + if isinstance(example_id, str): + example_id = UUID(example_id) + cb = LangChainTracer( + example_id=example_id, + project_name=project_name, + tags=tags, + client=client, + ) + try: + tracing_v2_callback_var.set(cb) + yield cb + finally: + tracing_v2_callback_var.set(None) + + +@contextmanager +def collect_runs() -> Generator[run_collector.RunCollectorCallbackHandler, None, None]: + """Collect all run traces in context. + + Returns: + run_collector.RunCollectorCallbackHandler: The run collector callback handler. + + Example: + >>> with collect_runs() as runs_cb: + chain.invoke("foo") + run_id = runs_cb.traced_runs[0].id + """ + cb = run_collector.RunCollectorCallbackHandler() + run_collector_var.set(cb) + yield cb + run_collector_var.set(None) + + +def _get_trace_callbacks( + project_name: Optional[str] = None, + example_id: Optional[Union[str, UUID]] = None, + callback_manager: Optional[Union[CallbackManager, AsyncCallbackManager]] = None, +) -> Callbacks: + if _tracing_v2_is_enabled(): + project_name_ = project_name or _get_tracer_project() + tracer = tracing_v2_callback_var.get() or LangChainTracer( + project_name=project_name_, + example_id=example_id, + ) + if callback_manager is None: + cb = cast(Callbacks, [tracer]) + else: + if not any( + isinstance(handler, LangChainTracer) + for handler in callback_manager.handlers + ): + callback_manager.add_handler(tracer, True) + # If it already has a LangChainTracer, we don't need to add another one. + # this would likely mess up the trace hierarchy. + cb = callback_manager + else: + cb = None + return cb + + +@contextmanager +def trace_as_chain_group( + group_name: str, + callback_manager: Optional[CallbackManager] = None, + *, + inputs: Optional[Dict[str, Any]] = None, + project_name: Optional[str] = None, + example_id: Optional[Union[str, UUID]] = None, + run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, +) -> Generator[CallbackManagerForChainGroup, None, None]: + """Get a callback manager for a chain group in a context manager. + Useful for grouping different calls together as a single run even if + they aren't composed in a single chain. + + Args: + group_name (str): The name of the chain group. + callback_manager (CallbackManager, optional): The callback manager to use. + inputs (Dict[str, Any], optional): The inputs to the chain group. + project_name (str, optional): The name of the project. + Defaults to None. + example_id (str or UUID, optional): The ID of the example. + Defaults to None. + run_id (UUID, optional): The ID of the run. + tags (List[str], optional): The inheritable tags to apply to all runs. + Defaults to None. + + Note: must have LANGCHAIN_TRACING_V2 env var set to true to see the trace in LangSmith. + + Returns: + CallbackManagerForChainGroup: The callback manager for the chain group. + + Example: + .. code-block:: python + + llm_input = "Foo" + with trace_as_chain_group("group_name", inputs={"input": llm_input}) as manager: + # Use the callback manager for the chain group + res = llm.predict(llm_input, callbacks=manager) + manager.on_chain_end({"output": res}) + """ # noqa: E501 + cb = _get_trace_callbacks( + project_name, example_id, callback_manager=callback_manager + ) + cm = CallbackManager.configure( + inheritable_callbacks=cb, + inheritable_tags=tags, + ) + + run_manager = cm.on_chain_start({"name": group_name}, inputs or {}, run_id=run_id) + child_cm = run_manager.get_child() + group_cm = CallbackManagerForChainGroup( + child_cm.handlers, + child_cm.inheritable_handlers, + child_cm.parent_run_id, + parent_run_manager=run_manager, + tags=child_cm.tags, + inheritable_tags=child_cm.inheritable_tags, + metadata=child_cm.metadata, + inheritable_metadata=child_cm.inheritable_metadata, + ) + try: + yield group_cm + except Exception as e: + if not group_cm.ended: + run_manager.on_chain_error(e) + raise e + else: + if not group_cm.ended: + run_manager.on_chain_end({}) + + +@asynccontextmanager +async def atrace_as_chain_group( + group_name: str, + callback_manager: Optional[AsyncCallbackManager] = None, + *, + inputs: Optional[Dict[str, Any]] = None, + project_name: Optional[str] = None, + example_id: Optional[Union[str, UUID]] = None, + run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, +) -> AsyncGenerator[AsyncCallbackManagerForChainGroup, None]: + """Get an async callback manager for a chain group in a context manager. + Useful for grouping different async calls together as a single run even if + they aren't composed in a single chain. + + Args: + group_name (str): The name of the chain group. + callback_manager (AsyncCallbackManager, optional): The async callback manager to use, + which manages tracing and other callback behavior. + project_name (str, optional): The name of the project. + Defaults to None. + example_id (str or UUID, optional): The ID of the example. + Defaults to None. + run_id (UUID, optional): The ID of the run. + tags (List[str], optional): The inheritable tags to apply to all runs. + Defaults to None. + Returns: + AsyncCallbackManager: The async callback manager for the chain group. + + Note: must have LANGCHAIN_TRACING_V2 env var set to true to see the trace in LangSmith. + + Example: + .. code-block:: python + + llm_input = "Foo" + async with atrace_as_chain_group("group_name", inputs={"input": llm_input}) as manager: + # Use the async callback manager for the chain group + res = await llm.apredict(llm_input, callbacks=manager) + await manager.on_chain_end({"output": res}) + """ # noqa: E501 + cb = _get_trace_callbacks( + project_name, example_id, callback_manager=callback_manager + ) + cm = AsyncCallbackManager.configure(inheritable_callbacks=cb, inheritable_tags=tags) + + run_manager = await cm.on_chain_start( + {"name": group_name}, inputs or {}, run_id=run_id + ) + child_cm = run_manager.get_child() + group_cm = AsyncCallbackManagerForChainGroup( + child_cm.handlers, + child_cm.inheritable_handlers, + child_cm.parent_run_id, + parent_run_manager=run_manager, + tags=child_cm.tags, + inheritable_tags=child_cm.inheritable_tags, + metadata=child_cm.metadata, + inheritable_metadata=child_cm.inheritable_metadata, + ) + try: + yield group_cm + except Exception as e: + if not group_cm.ended: + await run_manager.on_chain_error(e) + raise e + else: + if not group_cm.ended: + await run_manager.on_chain_end({}) + + +def handle_event( + handlers: List[BaseCallbackHandler], + event_name: str, + ignore_condition_name: Optional[str], + *args: Any, + **kwargs: Any, +) -> None: + """Generic event handler for CallbackManager. + + Note: This function is used by langserve to handle events. + + Args: + handlers: The list of handlers that will handle the event + event_name: The name of the event (e.g., "on_llm_start") + ignore_condition_name: Name of the attribute defined on handler + that if True will cause the handler to be skipped for the given event + *args: The arguments to pass to the event handler + **kwargs: The keyword arguments to pass to the event handler + """ + coros: List[Coroutine[Any, Any, Any]] = [] + + try: + message_strings: Optional[List[str]] = None + for handler in handlers: + try: + if ignore_condition_name is None or not getattr( + handler, ignore_condition_name + ): + event = getattr(handler, event_name)(*args, **kwargs) + if asyncio.iscoroutine(event): + coros.append(event) + except NotImplementedError as e: + if event_name == "on_chat_model_start": + if message_strings is None: + message_strings = [get_buffer_string(m) for m in args[1]] + handle_event( + [handler], + "on_llm_start", + "ignore_llm", + args[0], + message_strings, + *args[2:], + **kwargs, + ) + else: + handler_name = handler.__class__.__name__ + logger.warning( + f"NotImplementedError in {handler_name}.{event_name}" + f" callback: {repr(e)}" + ) + except Exception as e: + logger.warning( + f"Error in {handler.__class__.__name__}.{event_name} callback:" + f" {repr(e)}" + ) + if handler.raise_error: + raise e + finally: + if coros: + try: + # Raises RuntimeError if there is no current event loop. + asyncio.get_running_loop() + loop_running = True + except RuntimeError: + loop_running = False + + if loop_running: + # If we try to submit this coroutine to the running loop + # we end up in a deadlock, as we'd have gotten here from a + # running coroutine, which we cannot interrupt to run this one. + # The solution is to create a new loop in a new thread. + with ThreadPoolExecutor(1) as executor: + executor.submit(_run_coros, coros).result() + else: + _run_coros(coros) + + +def _run_coros(coros: List[Coroutine[Any, Any, Any]]) -> None: + if hasattr(asyncio, "Runner"): + # Python 3.11+ + # Run the coroutines in a new event loop, taking care to + # - install signal handlers + # - run pending tasks scheduled by `coros` + # - close asyncgens and executors + # - close the loop + with asyncio.Runner() as runner: + # Run the coroutine, get the result + for coro in coros: + runner.run(coro) + + # Run pending tasks scheduled by coros until they are all done + while pending := asyncio.all_tasks(runner.get_loop()): + runner.run(asyncio.wait(pending)) + else: + # Before Python 3.11 we need to run each coroutine in a new event loop + # as the Runner api is not available. + for coro in coros: + asyncio.run(coro) + + +async def _ahandle_event_for_handler( + handler: BaseCallbackHandler, + event_name: str, + ignore_condition_name: Optional[str], + *args: Any, + **kwargs: Any, +) -> None: + try: + if ignore_condition_name is None or not getattr(handler, ignore_condition_name): + event = getattr(handler, event_name) + if asyncio.iscoroutinefunction(event): + await event(*args, **kwargs) + else: + if handler.run_inline: + event(*args, **kwargs) + else: + await asyncio.get_event_loop().run_in_executor( + None, functools.partial(event, *args, **kwargs) + ) + except NotImplementedError as e: + if event_name == "on_chat_model_start": + message_strings = [get_buffer_string(m) for m in args[1]] + await _ahandle_event_for_handler( + handler, + "on_llm_start", + "ignore_llm", + args[0], + message_strings, + *args[2:], + **kwargs, + ) + else: + logger.warning( + f"NotImplementedError in {handler.__class__.__name__}.{event_name}" + f" callback: {repr(e)}" + ) + except Exception as e: + logger.warning( + f"Error in {handler.__class__.__name__}.{event_name} callback:" + f" {repr(e)}" + ) + if handler.raise_error: + raise e + + +async def ahandle_event( + handlers: List[BaseCallbackHandler], + event_name: str, + ignore_condition_name: Optional[str], + *args: Any, + **kwargs: Any, +) -> None: + """Generic event handler for AsyncCallbackManager. + + Note: This function is used by langserve to handle events. + + Args: + handlers: The list of handlers that will handle the event + event_name: The name of the event (e.g., "on_llm_start") + ignore_condition_name: Name of the attribute defined on handler + that if True will cause the handler to be skipped for the given event + *args: The arguments to pass to the event handler + **kwargs: The keyword arguments to pass to the event handler + """ + for handler in [h for h in handlers if h.run_inline]: + await _ahandle_event_for_handler( + handler, event_name, ignore_condition_name, *args, **kwargs + ) + await asyncio.gather( + *( + _ahandle_event_for_handler( + handler, event_name, ignore_condition_name, *args, **kwargs + ) + for handler in handlers + if not handler.run_inline + ) + ) + + +BRM = TypeVar("BRM", bound="BaseRunManager") + + +class BaseRunManager(RunManagerMixin): + """Base class for run manager (a bound callback manager).""" + + def __init__( + self, + *, + run_id: UUID, + handlers: List[BaseCallbackHandler], + inheritable_handlers: List[BaseCallbackHandler], + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + inheritable_tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + inheritable_metadata: Optional[Dict[str, Any]] = None, + ) -> None: + """Initialize the run manager. + + Args: + run_id (UUID): The ID of the run. + handlers (List[BaseCallbackHandler]): The list of handlers. + inheritable_handlers (List[BaseCallbackHandler]): + The list of inheritable handlers. + parent_run_id (UUID, optional): The ID of the parent run. + Defaults to None. + tags (Optional[List[str]]): The list of tags. + inheritable_tags (Optional[List[str]]): The list of inheritable tags. + metadata (Optional[Dict[str, Any]]): The metadata. + inheritable_metadata (Optional[Dict[str, Any]]): The inheritable metadata. + """ + self.run_id = run_id + self.handlers = handlers + self.inheritable_handlers = inheritable_handlers + self.parent_run_id = parent_run_id + self.tags = tags or [] + self.inheritable_tags = inheritable_tags or [] + self.metadata = metadata or {} + self.inheritable_metadata = inheritable_metadata or {} + + @classmethod + def get_noop_manager(cls: Type[BRM]) -> BRM: + """Return a manager that doesn't perform any operations. + + Returns: + BaseRunManager: The noop manager. + """ + return cls( + run_id=uuid.uuid4(), + handlers=[], + inheritable_handlers=[], + tags=[], + inheritable_tags=[], + metadata={}, + inheritable_metadata={}, + ) + + +class RunManager(BaseRunManager): + """Sync Run Manager.""" + + def on_text( + self, + text: str, + **kwargs: Any, + ) -> Any: + """Run when text is received. + + Args: + text (str): The received text. + + Returns: + Any: The result of the callback. + """ + handle_event( + self.handlers, + "on_text", + None, + text, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + + def on_retry( + self, + retry_state: RetryCallState, + **kwargs: Any, + ) -> None: + handle_event( + self.handlers, + "on_retry", + "ignore_retry", + retry_state, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + + +class ParentRunManager(RunManager): + """Sync Parent Run Manager.""" + + def get_child(self, tag: Optional[str] = None) -> CallbackManager: + """Get a child callback manager. + + Args: + tag (str, optional): The tag for the child callback manager. + Defaults to None. + + Returns: + CallbackManager: The child callback manager. + """ + manager = CallbackManager(handlers=[], parent_run_id=self.run_id) + manager.set_handlers(self.inheritable_handlers) + manager.add_tags(self.inheritable_tags) + manager.add_metadata(self.inheritable_metadata) + if tag is not None: + manager.add_tags([tag], False) + return manager + + +class AsyncRunManager(BaseRunManager): + """Async Run Manager.""" + + async def on_text( + self, + text: str, + **kwargs: Any, + ) -> Any: + """Run when text is received. + + Args: + text (str): The received text. + + Returns: + Any: The result of the callback. + """ + await ahandle_event( + self.handlers, + "on_text", + None, + text, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + + async def on_retry( + self, + retry_state: RetryCallState, + **kwargs: Any, + ) -> None: + await ahandle_event( + self.handlers, + "on_retry", + "ignore_retry", + retry_state, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + + +class AsyncParentRunManager(AsyncRunManager): + """Async Parent Run Manager.""" + + def get_child(self, tag: Optional[str] = None) -> AsyncCallbackManager: + """Get a child callback manager. + + Args: + tag (str, optional): The tag for the child callback manager. + Defaults to None. + + Returns: + AsyncCallbackManager: The child callback manager. + """ + manager = AsyncCallbackManager(handlers=[], parent_run_id=self.run_id) + manager.set_handlers(self.inheritable_handlers) + manager.add_tags(self.inheritable_tags) + manager.add_metadata(self.inheritable_metadata) + if tag is not None: + manager.add_tags([tag], False) + return manager + + +class CallbackManagerForLLMRun(RunManager, LLMManagerMixin): + """Callback manager for LLM run.""" + + def on_llm_new_token( + self, + token: str, + *, + chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, + **kwargs: Any, + ) -> None: + """Run when LLM generates a new token. + + Args: + token (str): The new token. + """ + handle_event( + self.handlers, + "on_llm_new_token", + "ignore_llm", + token=token, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + chunk=chunk, + **kwargs, + ) + + def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + """Run when LLM ends running. + + Args: + response (LLMResult): The LLM result. + """ + handle_event( + self.handlers, + "on_llm_end", + "ignore_llm", + response, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + + def on_llm_error( + self, + error: BaseException, + **kwargs: Any, + ) -> None: + """Run when LLM errors. + + Args: + error (Exception or KeyboardInterrupt): The error. + """ + handle_event( + self.handlers, + "on_llm_error", + "ignore_llm", + error, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + + +class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin): + """Async callback manager for LLM run.""" + + async def on_llm_new_token( + self, + token: str, + *, + chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, + **kwargs: Any, + ) -> None: + """Run when LLM generates a new token. + + Args: + token (str): The new token. + """ + await ahandle_event( + self.handlers, + "on_llm_new_token", + "ignore_llm", + token, + chunk=chunk, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + + async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + """Run when LLM ends running. + + Args: + response (LLMResult): The LLM result. + """ + await ahandle_event( + self.handlers, + "on_llm_end", + "ignore_llm", + response, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + + async def on_llm_error( + self, + error: BaseException, + **kwargs: Any, + ) -> None: + """Run when LLM errors. + + Args: + error (Exception or KeyboardInterrupt): The error. + """ + await ahandle_event( + self.handlers, + "on_llm_error", + "ignore_llm", + error, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + + +class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin): + """Callback manager for chain run.""" + + def on_chain_end(self, outputs: Union[Dict[str, Any], Any], **kwargs: Any) -> None: + """Run when chain ends running. + + Args: + outputs (Union[Dict[str, Any], Any]): The outputs of the chain. + """ + handle_event( + self.handlers, + "on_chain_end", + "ignore_chain", + outputs, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + + def on_chain_error( + self, + error: BaseException, + **kwargs: Any, + ) -> None: + """Run when chain errors. + + Args: + error (Exception or KeyboardInterrupt): The error. + """ + handle_event( + self.handlers, + "on_chain_error", + "ignore_chain", + error, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + + def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: + """Run when agent action is received. + + Args: + action (AgentAction): The agent action. + + Returns: + Any: The result of the callback. + """ + handle_event( + self.handlers, + "on_agent_action", + "ignore_agent", + action, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + + def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any: + """Run when agent finish is received. + + Args: + finish (AgentFinish): The agent finish. + + Returns: + Any: The result of the callback. + """ + handle_event( + self.handlers, + "on_agent_finish", + "ignore_agent", + finish, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + + +class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin): + """Async callback manager for chain run.""" + + async def on_chain_end( + self, outputs: Union[Dict[str, Any], Any], **kwargs: Any + ) -> None: + """Run when chain ends running. + + Args: + outputs (Union[Dict[str, Any], Any]): The outputs of the chain. + """ + await ahandle_event( + self.handlers, + "on_chain_end", + "ignore_chain", + outputs, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + + async def on_chain_error( + self, + error: BaseException, + **kwargs: Any, + ) -> None: + """Run when chain errors. + + Args: + error (Exception or KeyboardInterrupt): The error. + """ + await ahandle_event( + self.handlers, + "on_chain_error", + "ignore_chain", + error, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + + async def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: + """Run when agent action is received. + + Args: + action (AgentAction): The agent action. + + Returns: + Any: The result of the callback. + """ + await ahandle_event( + self.handlers, + "on_agent_action", + "ignore_agent", + action, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + + async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any: + """Run when agent finish is received. + + Args: + finish (AgentFinish): The agent finish. + + Returns: + Any: The result of the callback. + """ + await ahandle_event( + self.handlers, + "on_agent_finish", + "ignore_agent", + finish, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + + +class CallbackManagerForToolRun(ParentRunManager, ToolManagerMixin): + """Callback manager for tool run.""" + + def on_tool_end( + self, + output: str, + **kwargs: Any, + ) -> None: + """Run when tool ends running. + + Args: + output (str): The output of the tool. + """ + handle_event( + self.handlers, + "on_tool_end", + "ignore_agent", + output, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + + def on_tool_error( + self, + error: BaseException, + **kwargs: Any, + ) -> None: + """Run when tool errors. + + Args: + error (Exception or KeyboardInterrupt): The error. + """ + handle_event( + self.handlers, + "on_tool_error", + "ignore_agent", + error, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + + +class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin): + """Async callback manager for tool run.""" + + async def on_tool_end(self, output: str, **kwargs: Any) -> None: + """Run when tool ends running. + + Args: + output (str): The output of the tool. + """ + await ahandle_event( + self.handlers, + "on_tool_end", + "ignore_agent", + output, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + + async def on_tool_error( + self, + error: BaseException, + **kwargs: Any, + ) -> None: + """Run when tool errors. + + Args: + error (Exception or KeyboardInterrupt): The error. + """ + await ahandle_event( + self.handlers, + "on_tool_error", + "ignore_agent", + error, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + + +class CallbackManagerForRetrieverRun(ParentRunManager, RetrieverManagerMixin): + """Callback manager for retriever run.""" + + def on_retriever_end( + self, + documents: Sequence[Document], + **kwargs: Any, + ) -> None: + """Run when retriever ends running.""" + handle_event( + self.handlers, + "on_retriever_end", + "ignore_retriever", + documents, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + + def on_retriever_error( + self, + error: BaseException, + **kwargs: Any, + ) -> None: + """Run when retriever errors.""" + handle_event( + self.handlers, + "on_retriever_error", + "ignore_retriever", + error, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + + +class AsyncCallbackManagerForRetrieverRun( + AsyncParentRunManager, + RetrieverManagerMixin, +): + """Async callback manager for retriever run.""" + + async def on_retriever_end( + self, documents: Sequence[Document], **kwargs: Any + ) -> None: + """Run when retriever ends running.""" + await ahandle_event( + self.handlers, + "on_retriever_end", + "ignore_retriever", + documents, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + + async def on_retriever_error( + self, + error: BaseException, + **kwargs: Any, + ) -> None: + """Run when retriever errors.""" + await ahandle_event( + self.handlers, + "on_retriever_error", + "ignore_retriever", + error, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + + +class CallbackManager(BaseCallbackManager): + """Callback manager that handles callbacks from LangChain.""" + + def on_llm_start( + self, + serialized: Dict[str, Any], + prompts: List[str], + **kwargs: Any, + ) -> List[CallbackManagerForLLMRun]: + """Run when LLM starts running. + + Args: + serialized (Dict[str, Any]): The serialized LLM. + prompts (List[str]): The list of prompts. + run_id (UUID, optional): The ID of the run. Defaults to None. + + Returns: + List[CallbackManagerForLLMRun]: A callback manager for each + prompt as an LLM run. + """ + managers = [] + for prompt in prompts: + run_id_ = uuid.uuid4() + handle_event( + self.handlers, + "on_llm_start", + "ignore_llm", + serialized, + [prompt], + run_id=run_id_, + parent_run_id=self.parent_run_id, + tags=self.tags, + metadata=self.metadata, + **kwargs, + ) + + managers.append( + CallbackManagerForLLMRun( + run_id=run_id_, + handlers=self.handlers, + inheritable_handlers=self.inheritable_handlers, + parent_run_id=self.parent_run_id, + tags=self.tags, + inheritable_tags=self.inheritable_tags, + metadata=self.metadata, + inheritable_metadata=self.inheritable_metadata, + ) + ) + + return managers + + def on_chat_model_start( + self, + serialized: Dict[str, Any], + messages: List[List[BaseMessage]], + **kwargs: Any, + ) -> List[CallbackManagerForLLMRun]: + """Run when LLM starts running. + + Args: + serialized (Dict[str, Any]): The serialized LLM. + messages (List[List[BaseMessage]]): The list of messages. + run_id (UUID, optional): The ID of the run. Defaults to None. + + Returns: + List[CallbackManagerForLLMRun]: A callback manager for each + list of messages as an LLM run. + """ + + managers = [] + for message_list in messages: + run_id_ = uuid.uuid4() + handle_event( + self.handlers, + "on_chat_model_start", + "ignore_chat_model", + serialized, + [message_list], + run_id=run_id_, + parent_run_id=self.parent_run_id, + tags=self.tags, + metadata=self.metadata, + **kwargs, + ) + + managers.append( + CallbackManagerForLLMRun( + run_id=run_id_, + handlers=self.handlers, + inheritable_handlers=self.inheritable_handlers, + parent_run_id=self.parent_run_id, + tags=self.tags, + inheritable_tags=self.inheritable_tags, + metadata=self.metadata, + inheritable_metadata=self.inheritable_metadata, + ) + ) + + return managers + + def on_chain_start( + self, + serialized: Dict[str, Any], + inputs: Union[Dict[str, Any], Any], + run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> CallbackManagerForChainRun: + """Run when chain starts running. + + Args: + serialized (Dict[str, Any]): The serialized chain. + inputs (Union[Dict[str, Any], Any]): The inputs to the chain. + run_id (UUID, optional): The ID of the run. Defaults to None. + + Returns: + CallbackManagerForChainRun: The callback manager for the chain run. + """ + if run_id is None: + run_id = uuid.uuid4() + handle_event( + self.handlers, + "on_chain_start", + "ignore_chain", + serialized, + inputs, + run_id=run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + metadata=self.metadata, + **kwargs, + ) + + return CallbackManagerForChainRun( + run_id=run_id, + handlers=self.handlers, + inheritable_handlers=self.inheritable_handlers, + parent_run_id=self.parent_run_id, + tags=self.tags, + inheritable_tags=self.inheritable_tags, + metadata=self.metadata, + inheritable_metadata=self.inheritable_metadata, + ) + + def on_tool_start( + self, + serialized: Dict[str, Any], + input_str: str, + run_id: Optional[UUID] = None, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> CallbackManagerForToolRun: + """Run when tool starts running. + + Args: + serialized (Dict[str, Any]): The serialized tool. + input_str (str): The input to the tool. + run_id (UUID, optional): The ID of the run. Defaults to None. + parent_run_id (UUID, optional): The ID of the parent run. Defaults to None. + + Returns: + CallbackManagerForToolRun: The callback manager for the tool run. + """ + if run_id is None: + run_id = uuid.uuid4() + + handle_event( + self.handlers, + "on_tool_start", + "ignore_agent", + serialized, + input_str, + run_id=run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + metadata=self.metadata, + **kwargs, + ) + + return CallbackManagerForToolRun( + run_id=run_id, + handlers=self.handlers, + inheritable_handlers=self.inheritable_handlers, + parent_run_id=self.parent_run_id, + tags=self.tags, + inheritable_tags=self.inheritable_tags, + metadata=self.metadata, + inheritable_metadata=self.inheritable_metadata, + ) + + def on_retriever_start( + self, + serialized: Dict[str, Any], + query: str, + run_id: Optional[UUID] = None, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> CallbackManagerForRetrieverRun: + """Run when retriever starts running.""" + if run_id is None: + run_id = uuid.uuid4() + + handle_event( + self.handlers, + "on_retriever_start", + "ignore_retriever", + serialized, + query, + run_id=run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + metadata=self.metadata, + **kwargs, + ) + + return CallbackManagerForRetrieverRun( + run_id=run_id, + handlers=self.handlers, + inheritable_handlers=self.inheritable_handlers, + parent_run_id=self.parent_run_id, + tags=self.tags, + inheritable_tags=self.inheritable_tags, + metadata=self.metadata, + inheritable_metadata=self.inheritable_metadata, + ) + + @classmethod + def configure( + cls, + inheritable_callbacks: Callbacks = None, + local_callbacks: Callbacks = None, + verbose: bool = False, + inheritable_tags: Optional[List[str]] = None, + local_tags: Optional[List[str]] = None, + inheritable_metadata: Optional[Dict[str, Any]] = None, + local_metadata: Optional[Dict[str, Any]] = None, + ) -> CallbackManager: + """Configure the callback manager. + + Args: + inheritable_callbacks (Optional[Callbacks], optional): The inheritable + callbacks. Defaults to None. + local_callbacks (Optional[Callbacks], optional): The local callbacks. + Defaults to None. + verbose (bool, optional): Whether to enable verbose mode. Defaults to False. + inheritable_tags (Optional[List[str]], optional): The inheritable tags. + Defaults to None. + local_tags (Optional[List[str]], optional): The local tags. + Defaults to None. + inheritable_metadata (Optional[Dict[str, Any]], optional): The inheritable + metadata. Defaults to None. + local_metadata (Optional[Dict[str, Any]], optional): The local metadata. + Defaults to None. + + Returns: + CallbackManager: The configured callback manager. + """ + return _configure( + cls, + inheritable_callbacks, + local_callbacks, + verbose, + inheritable_tags, + local_tags, + inheritable_metadata, + local_metadata, + ) + + +class CallbackManagerForChainGroup(CallbackManager): + """Callback manager for the chain group.""" + + def __init__( + self, + handlers: List[BaseCallbackHandler], + inheritable_handlers: Optional[List[BaseCallbackHandler]] = None, + parent_run_id: Optional[UUID] = None, + *, + parent_run_manager: CallbackManagerForChainRun, + **kwargs: Any, + ) -> None: + super().__init__( + handlers, + inheritable_handlers, + parent_run_id, + **kwargs, + ) + self.parent_run_manager = parent_run_manager + self.ended = False + + def copy(self) -> CallbackManagerForChainGroup: + return self.__class__( + handlers=self.handlers, + inheritable_handlers=self.inheritable_handlers, + parent_run_id=self.parent_run_id, + tags=self.tags, + inheritable_tags=self.inheritable_tags, + metadata=self.metadata, + inheritable_metadata=self.inheritable_metadata, + parent_run_manager=self.parent_run_manager, + ) + + def on_chain_end(self, outputs: Union[Dict[str, Any], Any], **kwargs: Any) -> None: + """Run when traced chain group ends. + + Args: + outputs (Union[Dict[str, Any], Any]): The outputs of the chain. + """ + self.ended = True + return self.parent_run_manager.on_chain_end(outputs, **kwargs) + + def on_chain_error( + self, + error: BaseException, + **kwargs: Any, + ) -> None: + """Run when chain errors. + + Args: + error (Exception or KeyboardInterrupt): The error. + """ + self.ended = True + return self.parent_run_manager.on_chain_error(error, **kwargs) + + +class AsyncCallbackManager(BaseCallbackManager): + """Async callback manager that handles callbacks from LangChain.""" + + @property + def is_async(self) -> bool: + """Return whether the handler is async.""" + return True + + async def on_llm_start( + self, + serialized: Dict[str, Any], + prompts: List[str], + **kwargs: Any, + ) -> List[AsyncCallbackManagerForLLMRun]: + """Run when LLM starts running. + + Args: + serialized (Dict[str, Any]): The serialized LLM. + prompts (List[str]): The list of prompts. + run_id (UUID, optional): The ID of the run. Defaults to None. + + Returns: + List[AsyncCallbackManagerForLLMRun]: The list of async + callback managers, one for each LLM Run corresponding + to each prompt. + """ + + tasks = [] + managers = [] + + for prompt in prompts: + run_id_ = uuid.uuid4() + + tasks.append( + ahandle_event( + self.handlers, + "on_llm_start", + "ignore_llm", + serialized, + [prompt], + run_id=run_id_, + parent_run_id=self.parent_run_id, + tags=self.tags, + metadata=self.metadata, + **kwargs, + ) + ) + + managers.append( + AsyncCallbackManagerForLLMRun( + run_id=run_id_, + handlers=self.handlers, + inheritable_handlers=self.inheritable_handlers, + parent_run_id=self.parent_run_id, + tags=self.tags, + inheritable_tags=self.inheritable_tags, + metadata=self.metadata, + inheritable_metadata=self.inheritable_metadata, + ) + ) + + await asyncio.gather(*tasks) + + return managers + + async def on_chat_model_start( + self, + serialized: Dict[str, Any], + messages: List[List[BaseMessage]], + **kwargs: Any, + ) -> List[AsyncCallbackManagerForLLMRun]: + """Run when LLM starts running. + + Args: + serialized (Dict[str, Any]): The serialized LLM. + messages (List[List[BaseMessage]]): The list of messages. + run_id (UUID, optional): The ID of the run. Defaults to None. + + Returns: + List[AsyncCallbackManagerForLLMRun]: The list of + async callback managers, one for each LLM Run + corresponding to each inner message list. + """ + tasks = [] + managers = [] + + for message_list in messages: + run_id_ = uuid.uuid4() + + tasks.append( + ahandle_event( + self.handlers, + "on_chat_model_start", + "ignore_chat_model", + serialized, + [message_list], + run_id=run_id_, + parent_run_id=self.parent_run_id, + tags=self.tags, + metadata=self.metadata, + **kwargs, + ) + ) + + managers.append( + AsyncCallbackManagerForLLMRun( + run_id=run_id_, + handlers=self.handlers, + inheritable_handlers=self.inheritable_handlers, + parent_run_id=self.parent_run_id, + tags=self.tags, + inheritable_tags=self.inheritable_tags, + metadata=self.metadata, + inheritable_metadata=self.inheritable_metadata, + ) + ) + + await asyncio.gather(*tasks) + return managers + + async def on_chain_start( + self, + serialized: Dict[str, Any], + inputs: Union[Dict[str, Any], Any], + run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> AsyncCallbackManagerForChainRun: + """Run when chain starts running. + + Args: + serialized (Dict[str, Any]): The serialized chain. + inputs (Union[Dict[str, Any], Any]): The inputs to the chain. + run_id (UUID, optional): The ID of the run. Defaults to None. + + Returns: + AsyncCallbackManagerForChainRun: The async callback manager + for the chain run. + """ + if run_id is None: + run_id = uuid.uuid4() + + await ahandle_event( + self.handlers, + "on_chain_start", + "ignore_chain", + serialized, + inputs, + run_id=run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + metadata=self.metadata, + **kwargs, + ) + + return AsyncCallbackManagerForChainRun( + run_id=run_id, + handlers=self.handlers, + inheritable_handlers=self.inheritable_handlers, + parent_run_id=self.parent_run_id, + tags=self.tags, + inheritable_tags=self.inheritable_tags, + metadata=self.metadata, + inheritable_metadata=self.inheritable_metadata, + ) + + async def on_tool_start( + self, + serialized: Dict[str, Any], + input_str: str, + run_id: Optional[UUID] = None, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> AsyncCallbackManagerForToolRun: + """Run when tool starts running. + + Args: + serialized (Dict[str, Any]): The serialized tool. + input_str (str): The input to the tool. + run_id (UUID, optional): The ID of the run. Defaults to None. + parent_run_id (UUID, optional): The ID of the parent run. + Defaults to None. + + Returns: + AsyncCallbackManagerForToolRun: The async callback manager + for the tool run. + """ + if run_id is None: + run_id = uuid.uuid4() + + await ahandle_event( + self.handlers, + "on_tool_start", + "ignore_agent", + serialized, + input_str, + run_id=run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + metadata=self.metadata, + **kwargs, + ) + + return AsyncCallbackManagerForToolRun( + run_id=run_id, + handlers=self.handlers, + inheritable_handlers=self.inheritable_handlers, + parent_run_id=self.parent_run_id, + tags=self.tags, + inheritable_tags=self.inheritable_tags, + metadata=self.metadata, + inheritable_metadata=self.inheritable_metadata, + ) + + async def on_retriever_start( + self, + serialized: Dict[str, Any], + query: str, + run_id: Optional[UUID] = None, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> AsyncCallbackManagerForRetrieverRun: + """Run when retriever starts running.""" + if run_id is None: + run_id = uuid.uuid4() + + await ahandle_event( + self.handlers, + "on_retriever_start", + "ignore_retriever", + serialized, + query, + run_id=run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + metadata=self.metadata, + **kwargs, + ) + + return AsyncCallbackManagerForRetrieverRun( + run_id=run_id, + handlers=self.handlers, + inheritable_handlers=self.inheritable_handlers, + parent_run_id=self.parent_run_id, + tags=self.tags, + inheritable_tags=self.inheritable_tags, + metadata=self.metadata, + inheritable_metadata=self.inheritable_metadata, + ) + + @classmethod + def configure( + cls, + inheritable_callbacks: Callbacks = None, + local_callbacks: Callbacks = None, + verbose: bool = False, + inheritable_tags: Optional[List[str]] = None, + local_tags: Optional[List[str]] = None, + inheritable_metadata: Optional[Dict[str, Any]] = None, + local_metadata: Optional[Dict[str, Any]] = None, + ) -> AsyncCallbackManager: + """Configure the async callback manager. + + Args: + inheritable_callbacks (Optional[Callbacks], optional): The inheritable + callbacks. Defaults to None. + local_callbacks (Optional[Callbacks], optional): The local callbacks. + Defaults to None. + verbose (bool, optional): Whether to enable verbose mode. Defaults to False. + inheritable_tags (Optional[List[str]], optional): The inheritable tags. + Defaults to None. + local_tags (Optional[List[str]], optional): The local tags. + Defaults to None. + inheritable_metadata (Optional[Dict[str, Any]], optional): The inheritable + metadata. Defaults to None. + local_metadata (Optional[Dict[str, Any]], optional): The local metadata. + Defaults to None. + + Returns: + AsyncCallbackManager: The configured async callback manager. + """ + return _configure( + cls, + inheritable_callbacks, + local_callbacks, + verbose, + inheritable_tags, + local_tags, + inheritable_metadata, + local_metadata, + ) + + +class AsyncCallbackManagerForChainGroup(AsyncCallbackManager): + """Async callback manager for the chain group.""" + + def __init__( + self, + handlers: List[BaseCallbackHandler], + inheritable_handlers: Optional[List[BaseCallbackHandler]] = None, + parent_run_id: Optional[UUID] = None, + *, + parent_run_manager: AsyncCallbackManagerForChainRun, + **kwargs: Any, + ) -> None: + super().__init__( + handlers, + inheritable_handlers, + parent_run_id, + **kwargs, + ) + self.parent_run_manager = parent_run_manager + self.ended = False + + def copy(self) -> AsyncCallbackManagerForChainGroup: + return self.__class__( + handlers=self.handlers, + inheritable_handlers=self.inheritable_handlers, + parent_run_id=self.parent_run_id, + tags=self.tags, + inheritable_tags=self.inheritable_tags, + metadata=self.metadata, + inheritable_metadata=self.inheritable_metadata, + parent_run_manager=self.parent_run_manager, + ) + + async def on_chain_end( + self, outputs: Union[Dict[str, Any], Any], **kwargs: Any + ) -> None: + """Run when traced chain group ends. + + Args: + outputs (Union[Dict[str, Any], Any]): The outputs of the chain. + """ + self.ended = True + await self.parent_run_manager.on_chain_end(outputs, **kwargs) + + async def on_chain_error( + self, + error: BaseException, + **kwargs: Any, + ) -> None: + """Run when chain errors. + + Args: + error (Exception or KeyboardInterrupt): The error. + """ + self.ended = True + await self.parent_run_manager.on_chain_error(error, **kwargs) + + +T = TypeVar("T", CallbackManager, AsyncCallbackManager) + + +def env_var_is_set(env_var: str) -> bool: + """Check if an environment variable is set. + + Args: + env_var (str): The name of the environment variable. + + Returns: + bool: True if the environment variable is set, False otherwise. + """ + return env_var in os.environ and os.environ[env_var] not in ( + "", + "0", + "false", + "False", + ) + + +def _tracing_v2_is_enabled() -> bool: + return ( + env_var_is_set("LANGCHAIN_TRACING_V2") + or tracing_v2_callback_var.get() is not None + or get_run_tree_context() is not None + ) + + +def _get_tracer_project() -> str: + run_tree = get_run_tree_context() + return getattr( + run_tree, + "session_name", + getattr( + # Note, if people are trying to nest @traceable functions and the + # tracing_v2_enabled context manager, this will likely mess up the + # tree structure. + tracing_v2_callback_var.get(), + "project", + # Have to set this to a string even though it always will return + # a string because `get_tracer_project` technically can return + # None, but only when a specific argument is supplied. + # Therefore, this just tricks the mypy type checker + str(ls_utils.get_tracer_project()), + ), + ) + + +_configure_hooks: List[ + Tuple[ + ContextVar[Optional[BaseCallbackHandler]], + bool, + Optional[Type[BaseCallbackHandler]], + Optional[str], + ] +] = [] + +H = TypeVar("H", bound=BaseCallbackHandler, covariant=True) + + +def register_configure_hook( + context_var: ContextVar[Optional[Any]], + inheritable: bool, + handle_class: Optional[Type[BaseCallbackHandler]] = None, + env_var: Optional[str] = None, +) -> None: + if env_var is not None and handle_class is None: + raise ValueError( + "If env_var is set, handle_class must also be set to a non-None value." + ) + _configure_hooks.append( + ( + # the typings of ContextVar do not have the generic arg set as covariant + # so we have to cast it + cast(ContextVar[Optional[BaseCallbackHandler]], context_var), + inheritable, + handle_class, + env_var, + ) + ) + + +register_configure_hook(run_collector_var, False) + + +def _configure( + callback_manager_cls: Type[T], + inheritable_callbacks: Callbacks = None, + local_callbacks: Callbacks = None, + verbose: bool = False, + inheritable_tags: Optional[List[str]] = None, + local_tags: Optional[List[str]] = None, + inheritable_metadata: Optional[Dict[str, Any]] = None, + local_metadata: Optional[Dict[str, Any]] = None, +) -> T: + """Configure the callback manager. + + Args: + callback_manager_cls (Type[T]): The callback manager class. + inheritable_callbacks (Optional[Callbacks], optional): The inheritable + callbacks. Defaults to None. + local_callbacks (Optional[Callbacks], optional): The local callbacks. + Defaults to None. + verbose (bool, optional): Whether to enable verbose mode. Defaults to False. + inheritable_tags (Optional[List[str]], optional): The inheritable tags. + Defaults to None. + local_tags (Optional[List[str]], optional): The local tags. Defaults to None. + inheritable_metadata (Optional[Dict[str, Any]], optional): The inheritable + metadata. Defaults to None. + local_metadata (Optional[Dict[str, Any]], optional): The local metadata. + Defaults to None. + + Returns: + T: The configured callback manager. + """ + run_tree = get_run_tree_context() + parent_run_id = None if run_tree is None else getattr(run_tree, "id") + callback_manager = callback_manager_cls(handlers=[], parent_run_id=parent_run_id) + if inheritable_callbacks or local_callbacks: + if isinstance(inheritable_callbacks, list) or inheritable_callbacks is None: + inheritable_callbacks_ = inheritable_callbacks or [] + callback_manager = callback_manager_cls( + handlers=inheritable_callbacks_.copy(), + inheritable_handlers=inheritable_callbacks_.copy(), + parent_run_id=parent_run_id, + ) + else: + callback_manager = callback_manager_cls( + handlers=inheritable_callbacks.handlers.copy(), + inheritable_handlers=inheritable_callbacks.inheritable_handlers.copy(), + parent_run_id=inheritable_callbacks.parent_run_id, + tags=inheritable_callbacks.tags.copy(), + inheritable_tags=inheritable_callbacks.inheritable_tags.copy(), + metadata=inheritable_callbacks.metadata.copy(), + inheritable_metadata=inheritable_callbacks.inheritable_metadata.copy(), + ) + local_handlers_ = ( + local_callbacks + if isinstance(local_callbacks, list) + else (local_callbacks.handlers if local_callbacks else []) + ) + for handler in local_handlers_: + callback_manager.add_handler(handler, False) + if inheritable_tags or local_tags: + callback_manager.add_tags(inheritable_tags or []) + callback_manager.add_tags(local_tags or [], False) + if inheritable_metadata or local_metadata: + callback_manager.add_metadata(inheritable_metadata or {}) + callback_manager.add_metadata(local_metadata or {}, False) + + tracer = tracing_callback_var.get() + tracing_enabled_ = ( + env_var_is_set("LANGCHAIN_TRACING") + or tracer is not None + or env_var_is_set("LANGCHAIN_HANDLER") + ) + + tracer_v2 = tracing_v2_callback_var.get() + tracing_v2_enabled_ = _tracing_v2_is_enabled() + tracer_project = _get_tracer_project() + debug = _get_debug() + if verbose or debug or tracing_enabled_ or tracing_v2_enabled_: + if verbose and not any( + isinstance(handler, StdOutCallbackHandler) + for handler in callback_manager.handlers + ): + if debug: + pass + else: + callback_manager.add_handler(StdOutCallbackHandler(), False) + if debug and not any( + isinstance(handler, ConsoleCallbackHandler) + for handler in callback_manager.handlers + ): + callback_manager.add_handler(ConsoleCallbackHandler(), True) + if tracing_enabled_ and not any( + isinstance(handler, LangChainTracerV1) + for handler in callback_manager.handlers + ): + if tracer: + callback_manager.add_handler(tracer, True) + else: + handler = LangChainTracerV1() + handler.load_session(tracer_project) + callback_manager.add_handler(handler, True) + if tracing_v2_enabled_ and not any( + isinstance(handler, LangChainTracer) + for handler in callback_manager.handlers + ): + if tracer_v2: + callback_manager.add_handler(tracer_v2, True) + else: + try: + handler = LangChainTracer(project_name=tracer_project) + callback_manager.add_handler(handler, True) + except Exception as e: + logger.warning( + "Unable to load requested LangChainTracer." + " To disable this warning," + " unset the LANGCHAIN_TRACING_V2 environment variables.", + e, + ) + for var, inheritable, handler_class, env_var in _configure_hooks: + create_one = ( + env_var is not None + and env_var_is_set(env_var) + and handler_class is not None + ) + if var.get() is not None or create_one: + var_handler = var.get() or cast(Type[BaseCallbackHandler], handler_class)() + if handler_class is None: + if not any( + handler is var_handler # direct pointer comparison + for handler in callback_manager.handlers + ): + callback_manager.add_handler(var_handler, inheritable) + else: + if not any( + isinstance(handler, handler_class) + for handler in callback_manager.handlers + ): + callback_manager.add_handler(var_handler, inheritable) + return callback_manager diff --git a/libs/langchain/langchain/schema/callbacks/stdout.py b/libs/langchain/langchain/schema/callbacks/stdout.py new file mode 100644 index 00000000000..63e71a30dd5 --- /dev/null +++ b/libs/langchain/langchain/schema/callbacks/stdout.py @@ -0,0 +1,97 @@ +"""Callback Handler that prints to std out.""" +from typing import Any, Dict, List, Optional + +from langchain.schema import AgentAction, AgentFinish, LLMResult +from langchain.schema.callbacks.base import BaseCallbackHandler +from langchain.utils.input import print_text + + +class StdOutCallbackHandler(BaseCallbackHandler): + """Callback Handler that prints to std out.""" + + def __init__(self, color: Optional[str] = None) -> None: + """Initialize callback handler.""" + self.color = color + + def on_llm_start( + self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any + ) -> None: + """Print out the prompts.""" + pass + + def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + """Do nothing.""" + pass + + def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + """Do nothing.""" + pass + + def on_llm_error(self, error: BaseException, **kwargs: Any) -> None: + """Do nothing.""" + pass + + def on_chain_start( + self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any + ) -> None: + """Print out that we are entering a chain.""" + class_name = serialized.get("name", serialized.get("id", [""])[-1]) + print(f"\n\n\033[1m> Entering new {class_name} chain...\033[0m") + + def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: + """Print out that we finished a chain.""" + print("\n\033[1m> Finished chain.\033[0m") + + def on_chain_error(self, error: BaseException, **kwargs: Any) -> None: + """Do nothing.""" + pass + + def on_tool_start( + self, + serialized: Dict[str, Any], + input_str: str, + **kwargs: Any, + ) -> None: + """Do nothing.""" + pass + + def on_agent_action( + self, action: AgentAction, color: Optional[str] = None, **kwargs: Any + ) -> Any: + """Run on agent action.""" + print_text(action.log, color=color or self.color) + + def on_tool_end( + self, + output: str, + color: Optional[str] = None, + observation_prefix: Optional[str] = None, + llm_prefix: Optional[str] = None, + **kwargs: Any, + ) -> None: + """If not the final action, print out observation.""" + if observation_prefix is not None: + print_text(f"\n{observation_prefix}") + print_text(output, color=color or self.color) + if llm_prefix is not None: + print_text(f"\n{llm_prefix}") + + def on_tool_error(self, error: BaseException, **kwargs: Any) -> None: + """Do nothing.""" + pass + + def on_text( + self, + text: str, + color: Optional[str] = None, + end: str = "", + **kwargs: Any, + ) -> None: + """Run when agent ends.""" + print_text(text, color=color or self.color, end=end) + + def on_agent_finish( + self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any + ) -> None: + """Run on agent end.""" + print_text(finish.log, color=color or self.color, end="\n") diff --git a/libs/langchain/langchain/schema/callbacks/tracers/__init__.py b/libs/langchain/langchain/schema/callbacks/tracers/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/libs/langchain/langchain/schema/callbacks/tracers/base.py b/libs/langchain/langchain/schema/callbacks/tracers/base.py new file mode 100644 index 00000000000..4f136a1911d --- /dev/null +++ b/libs/langchain/langchain/schema/callbacks/tracers/base.py @@ -0,0 +1,537 @@ +"""Base interfaces for tracing runs.""" +from __future__ import annotations + +import logging +from abc import ABC, abstractmethod +from datetime import datetime +from typing import Any, Dict, List, Optional, Sequence, Union, cast +from uuid import UUID + +from tenacity import RetryCallState + +from langchain.load.dump import dumpd +from langchain.schema.callbacks.base import BaseCallbackHandler +from langchain.schema.callbacks.tracers.schemas import Run +from langchain.schema.document import Document +from langchain.schema.output import ( + ChatGeneration, + ChatGenerationChunk, + GenerationChunk, + LLMResult, +) + +logger = logging.getLogger(__name__) + + +class TracerException(Exception): + """Base class for exceptions in tracers module.""" + + +class BaseTracer(BaseCallbackHandler, ABC): + """Base interface for tracers.""" + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.run_map: Dict[str, Run] = {} + + @staticmethod + def _add_child_run( + parent_run: Run, + child_run: Run, + ) -> None: + """Add child run to a chain run or tool run.""" + parent_run.child_runs.append(child_run) + + @abstractmethod + def _persist_run(self, run: Run) -> None: + """Persist a run.""" + + def _start_trace(self, run: Run) -> None: + """Start a trace for a run.""" + if run.parent_run_id: + parent_run = self.run_map.get(str(run.parent_run_id)) + if parent_run: + self._add_child_run(parent_run, run) + parent_run.child_execution_order = max( + parent_run.child_execution_order, run.child_execution_order + ) + else: + logger.debug(f"Parent run with UUID {run.parent_run_id} not found.") + self.run_map[str(run.id)] = run + self._on_run_create(run) + + def _end_trace(self, run: Run) -> None: + """End a trace for a run.""" + if not run.parent_run_id: + self._persist_run(run) + else: + parent_run = self.run_map.get(str(run.parent_run_id)) + if parent_run is None: + logger.debug(f"Parent run with UUID {run.parent_run_id} not found.") + elif ( + run.child_execution_order is not None + and parent_run.child_execution_order is not None + and run.child_execution_order > parent_run.child_execution_order + ): + parent_run.child_execution_order = run.child_execution_order + self.run_map.pop(str(run.id)) + self._on_run_update(run) + + def _get_execution_order(self, parent_run_id: Optional[str] = None) -> int: + """Get the execution order for a run.""" + if parent_run_id is None: + return 1 + + parent_run = self.run_map.get(parent_run_id) + if parent_run is None: + logger.debug(f"Parent run with UUID {parent_run_id} not found.") + return 1 + if parent_run.child_execution_order is None: + raise TracerException( + f"Parent run with UUID {parent_run_id} has no child execution order." + ) + + return parent_run.child_execution_order + 1 + + def on_llm_start( + self, + serialized: Dict[str, Any], + prompts: List[str], + *, + run_id: UUID, + tags: Optional[List[str]] = None, + parent_run_id: Optional[UUID] = None, + metadata: Optional[Dict[str, Any]] = None, + name: Optional[str] = None, + **kwargs: Any, + ) -> Run: + """Start a trace for an LLM run.""" + parent_run_id_ = str(parent_run_id) if parent_run_id else None + execution_order = self._get_execution_order(parent_run_id_) + start_time = datetime.utcnow() + if metadata: + kwargs.update({"metadata": metadata}) + llm_run = Run( + id=run_id, + parent_run_id=parent_run_id, + serialized=serialized, + inputs={"prompts": prompts}, + extra=kwargs, + events=[{"name": "start", "time": start_time}], + start_time=start_time, + execution_order=execution_order, + child_execution_order=execution_order, + run_type="llm", + tags=tags or [], + name=name, + ) + self._start_trace(llm_run) + self._on_llm_start(llm_run) + return llm_run + + def on_llm_new_token( + self, + token: str, + *, + chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Run: + """Run on new LLM token. Only available when streaming is enabled.""" + if not run_id: + raise TracerException("No run_id provided for on_llm_new_token callback.") + + run_id_ = str(run_id) + llm_run = self.run_map.get(run_id_) + if llm_run is None or llm_run.run_type != "llm": + raise TracerException(f"No LLM Run found to be traced for {run_id}") + event_kwargs: Dict[str, Any] = {"token": token} + if chunk: + event_kwargs["chunk"] = chunk + llm_run.events.append( + { + "name": "new_token", + "time": datetime.utcnow(), + "kwargs": event_kwargs, + }, + ) + self._on_llm_new_token(llm_run, token, chunk) + return llm_run + + def on_retry( + self, + retry_state: RetryCallState, + *, + run_id: UUID, + **kwargs: Any, + ) -> Run: + if not run_id: + raise TracerException("No run_id provided for on_retry callback.") + run_id_ = str(run_id) + llm_run = self.run_map.get(run_id_) + if llm_run is None: + raise TracerException("No Run found to be traced for on_retry") + retry_d: Dict[str, Any] = { + "slept": retry_state.idle_for, + "attempt": retry_state.attempt_number, + } + if retry_state.outcome is None: + retry_d["outcome"] = "N/A" + elif retry_state.outcome.failed: + retry_d["outcome"] = "failed" + exception = retry_state.outcome.exception() + retry_d["exception"] = str(exception) + retry_d["exception_type"] = exception.__class__.__name__ + else: + retry_d["outcome"] = "success" + retry_d["result"] = str(retry_state.outcome.result()) + llm_run.events.append( + { + "name": "retry", + "time": datetime.utcnow(), + "kwargs": retry_d, + }, + ) + return llm_run + + def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> Run: + """End a trace for an LLM run.""" + if not run_id: + raise TracerException("No run_id provided for on_llm_end callback.") + + run_id_ = str(run_id) + llm_run = self.run_map.get(run_id_) + if llm_run is None or llm_run.run_type != "llm": + raise TracerException(f"No LLM Run found to be traced for {run_id}") + llm_run.outputs = response.dict() + for i, generations in enumerate(response.generations): + for j, generation in enumerate(generations): + output_generation = llm_run.outputs["generations"][i][j] + if "message" in output_generation: + output_generation["message"] = dumpd( + cast(ChatGeneration, generation).message + ) + llm_run.end_time = datetime.utcnow() + llm_run.events.append({"name": "end", "time": llm_run.end_time}) + self._end_trace(llm_run) + self._on_llm_end(llm_run) + return llm_run + + def on_llm_error( + self, + error: BaseException, + *, + run_id: UUID, + **kwargs: Any, + ) -> Run: + """Handle an error for an LLM run.""" + if not run_id: + raise TracerException("No run_id provided for on_llm_error callback.") + + run_id_ = str(run_id) + llm_run = self.run_map.get(run_id_) + if llm_run is None or llm_run.run_type != "llm": + raise TracerException(f"No LLM Run found to be traced for {run_id}") + llm_run.error = repr(error) + llm_run.end_time = datetime.utcnow() + llm_run.events.append({"name": "error", "time": llm_run.end_time}) + self._end_trace(llm_run) + self._on_chain_error(llm_run) + return llm_run + + def on_chain_start( + self, + serialized: Dict[str, Any], + inputs: Dict[str, Any], + *, + run_id: UUID, + tags: Optional[List[str]] = None, + parent_run_id: Optional[UUID] = None, + metadata: Optional[Dict[str, Any]] = None, + run_type: Optional[str] = None, + name: Optional[str] = None, + **kwargs: Any, + ) -> Run: + """Start a trace for a chain run.""" + parent_run_id_ = str(parent_run_id) if parent_run_id else None + execution_order = self._get_execution_order(parent_run_id_) + start_time = datetime.utcnow() + if metadata: + kwargs.update({"metadata": metadata}) + chain_run = Run( + id=run_id, + parent_run_id=parent_run_id, + serialized=serialized, + inputs=inputs if isinstance(inputs, dict) else {"input": inputs}, + extra=kwargs, + events=[{"name": "start", "time": start_time}], + start_time=start_time, + execution_order=execution_order, + child_execution_order=execution_order, + child_runs=[], + run_type=run_type or "chain", + name=name, + tags=tags or [], + ) + self._start_trace(chain_run) + self._on_chain_start(chain_run) + return chain_run + + def on_chain_end( + self, + outputs: Dict[str, Any], + *, + run_id: UUID, + inputs: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> Run: + """End a trace for a chain run.""" + if not run_id: + raise TracerException("No run_id provided for on_chain_end callback.") + chain_run = self.run_map.get(str(run_id)) + if chain_run is None: + raise TracerException(f"No chain Run found to be traced for {run_id}") + + chain_run.outputs = ( + outputs if isinstance(outputs, dict) else {"output": outputs} + ) + chain_run.end_time = datetime.utcnow() + chain_run.events.append({"name": "end", "time": chain_run.end_time}) + if inputs is not None: + chain_run.inputs = inputs if isinstance(inputs, dict) else {"input": inputs} + self._end_trace(chain_run) + self._on_chain_end(chain_run) + return chain_run + + def on_chain_error( + self, + error: BaseException, + *, + inputs: Optional[Dict[str, Any]] = None, + run_id: UUID, + **kwargs: Any, + ) -> Run: + """Handle an error for a chain run.""" + if not run_id: + raise TracerException("No run_id provided for on_chain_error callback.") + chain_run = self.run_map.get(str(run_id)) + if chain_run is None: + raise TracerException(f"No chain Run found to be traced for {run_id}") + + chain_run.error = repr(error) + chain_run.end_time = datetime.utcnow() + chain_run.events.append({"name": "error", "time": chain_run.end_time}) + if inputs is not None: + chain_run.inputs = inputs if isinstance(inputs, dict) else {"input": inputs} + self._end_trace(chain_run) + self._on_chain_error(chain_run) + return chain_run + + def on_tool_start( + self, + serialized: Dict[str, Any], + input_str: str, + *, + run_id: UUID, + tags: Optional[List[str]] = None, + parent_run_id: Optional[UUID] = None, + metadata: Optional[Dict[str, Any]] = None, + name: Optional[str] = None, + **kwargs: Any, + ) -> Run: + """Start a trace for a tool run.""" + parent_run_id_ = str(parent_run_id) if parent_run_id else None + execution_order = self._get_execution_order(parent_run_id_) + start_time = datetime.utcnow() + if metadata: + kwargs.update({"metadata": metadata}) + tool_run = Run( + id=run_id, + parent_run_id=parent_run_id, + serialized=serialized, + inputs={"input": input_str}, + extra=kwargs, + events=[{"name": "start", "time": start_time}], + start_time=start_time, + execution_order=execution_order, + child_execution_order=execution_order, + child_runs=[], + run_type="tool", + tags=tags or [], + name=name, + ) + self._start_trace(tool_run) + self._on_tool_start(tool_run) + return tool_run + + def on_tool_end(self, output: str, *, run_id: UUID, **kwargs: Any) -> Run: + """End a trace for a tool run.""" + if not run_id: + raise TracerException("No run_id provided for on_tool_end callback.") + tool_run = self.run_map.get(str(run_id)) + if tool_run is None or tool_run.run_type != "tool": + raise TracerException(f"No tool Run found to be traced for {run_id}") + + tool_run.outputs = {"output": output} + tool_run.end_time = datetime.utcnow() + tool_run.events.append({"name": "end", "time": tool_run.end_time}) + self._end_trace(tool_run) + self._on_tool_end(tool_run) + return tool_run + + def on_tool_error( + self, + error: BaseException, + *, + run_id: UUID, + **kwargs: Any, + ) -> Run: + """Handle an error for a tool run.""" + if not run_id: + raise TracerException("No run_id provided for on_tool_error callback.") + tool_run = self.run_map.get(str(run_id)) + if tool_run is None or tool_run.run_type != "tool": + raise TracerException(f"No tool Run found to be traced for {run_id}") + + tool_run.error = repr(error) + tool_run.end_time = datetime.utcnow() + tool_run.events.append({"name": "error", "time": tool_run.end_time}) + self._end_trace(tool_run) + self._on_tool_error(tool_run) + return tool_run + + def on_retriever_start( + self, + serialized: Dict[str, Any], + query: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + name: Optional[str] = None, + **kwargs: Any, + ) -> Run: + """Run when Retriever starts running.""" + parent_run_id_ = str(parent_run_id) if parent_run_id else None + execution_order = self._get_execution_order(parent_run_id_) + start_time = datetime.utcnow() + if metadata: + kwargs.update({"metadata": metadata}) + retrieval_run = Run( + id=run_id, + name=name or "Retriever", + parent_run_id=parent_run_id, + serialized=serialized, + inputs={"query": query}, + extra=kwargs, + events=[{"name": "start", "time": start_time}], + start_time=start_time, + execution_order=execution_order, + child_execution_order=execution_order, + tags=tags, + child_runs=[], + run_type="retriever", + ) + self._start_trace(retrieval_run) + self._on_retriever_start(retrieval_run) + return retrieval_run + + def on_retriever_error( + self, + error: BaseException, + *, + run_id: UUID, + **kwargs: Any, + ) -> Run: + """Run when Retriever errors.""" + if not run_id: + raise TracerException("No run_id provided for on_retriever_error callback.") + retrieval_run = self.run_map.get(str(run_id)) + if retrieval_run is None or retrieval_run.run_type != "retriever": + raise TracerException(f"No retriever Run found to be traced for {run_id}") + + retrieval_run.error = repr(error) + retrieval_run.end_time = datetime.utcnow() + retrieval_run.events.append({"name": "error", "time": retrieval_run.end_time}) + self._end_trace(retrieval_run) + self._on_retriever_error(retrieval_run) + return retrieval_run + + def on_retriever_end( + self, documents: Sequence[Document], *, run_id: UUID, **kwargs: Any + ) -> Run: + """Run when Retriever ends running.""" + if not run_id: + raise TracerException("No run_id provided for on_retriever_end callback.") + retrieval_run = self.run_map.get(str(run_id)) + if retrieval_run is None or retrieval_run.run_type != "retriever": + raise TracerException(f"No retriever Run found to be traced for {run_id}") + retrieval_run.outputs = {"documents": documents} + retrieval_run.end_time = datetime.utcnow() + retrieval_run.events.append({"name": "end", "time": retrieval_run.end_time}) + self._end_trace(retrieval_run) + self._on_retriever_end(retrieval_run) + return retrieval_run + + def __deepcopy__(self, memo: dict) -> BaseTracer: + """Deepcopy the tracer.""" + return self + + def __copy__(self) -> BaseTracer: + """Copy the tracer.""" + return self + + def _on_run_create(self, run: Run) -> None: + """Process a run upon creation.""" + + def _on_run_update(self, run: Run) -> None: + """Process a run upon update.""" + + def _on_llm_start(self, run: Run) -> None: + """Process the LLM Run upon start.""" + + def _on_llm_new_token( + self, + run: Run, + token: str, + chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]], + ) -> None: + """Process new LLM token.""" + + def _on_llm_end(self, run: Run) -> None: + """Process the LLM Run.""" + + def _on_llm_error(self, run: Run) -> None: + """Process the LLM Run upon error.""" + + def _on_chain_start(self, run: Run) -> None: + """Process the Chain Run upon start.""" + + def _on_chain_end(self, run: Run) -> None: + """Process the Chain Run.""" + + def _on_chain_error(self, run: Run) -> None: + """Process the Chain Run upon error.""" + + def _on_tool_start(self, run: Run) -> None: + """Process the Tool Run upon start.""" + + def _on_tool_end(self, run: Run) -> None: + """Process the Tool Run.""" + + def _on_tool_error(self, run: Run) -> None: + """Process the Tool Run upon error.""" + + def _on_chat_model_start(self, run: Run) -> None: + """Process the Chat Model Run upon start.""" + + def _on_retriever_start(self, run: Run) -> None: + """Process the Retriever Run upon start.""" + + def _on_retriever_end(self, run: Run) -> None: + """Process the Retriever Run.""" + + def _on_retriever_error(self, run: Run) -> None: + """Process the Retriever Run upon error.""" diff --git a/libs/langchain/langchain/schema/callbacks/tracers/evaluation.py b/libs/langchain/langchain/schema/callbacks/tracers/evaluation.py new file mode 100644 index 00000000000..eac08e6c0e2 --- /dev/null +++ b/libs/langchain/langchain/schema/callbacks/tracers/evaluation.py @@ -0,0 +1,222 @@ +"""A tracer that runs evaluators over completed runs.""" +from __future__ import annotations + +import logging +import threading +import weakref +from concurrent.futures import Future, ThreadPoolExecutor, wait +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast +from uuid import UUID + +import langsmith +from langsmith.evaluation.evaluator import EvaluationResult, EvaluationResults + +from langchain.schema.callbacks import manager +from langchain.schema.callbacks.tracers import langchain as langchain_tracer +from langchain.schema.callbacks.tracers.base import BaseTracer +from langchain.schema.callbacks.tracers.langchain import _get_executor +from langchain.schema.callbacks.tracers.schemas import Run + +logger = logging.getLogger(__name__) + +_TRACERS: weakref.WeakSet[EvaluatorCallbackHandler] = weakref.WeakSet() + + +def wait_for_all_evaluators() -> None: + """Wait for all tracers to finish.""" + global _TRACERS + for tracer in list(_TRACERS): + if tracer is not None: + tracer.wait_for_futures() + + +class EvaluatorCallbackHandler(BaseTracer): + """A tracer that runs a run evaluator whenever a run is persisted. + + Parameters + ---------- + evaluators : Sequence[RunEvaluator] + The run evaluators to apply to all top level runs. + client : LangSmith Client, optional + The LangSmith client instance to use for evaluating the runs. + If not specified, a new instance will be created. + example_id : Union[UUID, str], optional + The example ID to be associated with the runs. + project_name : str, optional + The LangSmith project name to be organize eval chain runs under. + + Attributes + ---------- + example_id : Union[UUID, None] + The example ID associated with the runs. + client : Client + The LangSmith client instance used for evaluating the runs. + evaluators : Sequence[RunEvaluator] + The sequence of run evaluators to be executed. + executor : ThreadPoolExecutor + The thread pool executor used for running the evaluators. + futures : Set[Future] + The set of futures representing the running evaluators. + skip_unfinished : bool + Whether to skip runs that are not finished or raised + an error. + project_name : Optional[str] + The LangSmith project name to be organize eval chain runs under. + """ + + name = "evaluator_callback_handler" + + def __init__( + self, + evaluators: Sequence[langsmith.RunEvaluator], + client: Optional[langsmith.Client] = None, + example_id: Optional[Union[UUID, str]] = None, + skip_unfinished: bool = True, + project_name: Optional[str] = "evaluators", + max_concurrency: Optional[int] = None, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.example_id = ( + UUID(example_id) if isinstance(example_id, str) else example_id + ) + self.client = client or langchain_tracer.get_client() + self.evaluators = evaluators + if max_concurrency is None: + self.executor: Optional[ThreadPoolExecutor] = _get_executor() + elif max_concurrency > 0: + self.executor = ThreadPoolExecutor(max_workers=max_concurrency) + weakref.finalize( + self, + lambda: cast(ThreadPoolExecutor, self.executor).shutdown(wait=True), + ) + else: + self.executor = None + self.futures: weakref.WeakSet[Future] = weakref.WeakSet() + self.skip_unfinished = skip_unfinished + self.project_name = project_name + self.logged_eval_results: Dict[Tuple[str, str], List[EvaluationResult]] = {} + self.lock = threading.Lock() + global _TRACERS + _TRACERS.add(self) + + def _evaluate_in_project(self, run: Run, evaluator: langsmith.RunEvaluator) -> None: + """Evaluate the run in the project. + + Parameters + ---------- + run : Run + The run to be evaluated. + evaluator : RunEvaluator + The evaluator to use for evaluating the run. + + """ + try: + if self.project_name is None: + eval_result = self.client.evaluate_run(run, evaluator) + eval_results = [eval_result] + with manager.tracing_v2_enabled( + project_name=self.project_name, tags=["eval"], client=self.client + ) as cb: + reference_example = ( + self.client.read_example(run.reference_example_id) + if run.reference_example_id + else None + ) + evaluation_result = evaluator.evaluate_run( + run, + example=reference_example, + ) + eval_results = self._log_evaluation_feedback( + evaluation_result, + run, + source_run_id=cb.latest_run.id if cb.latest_run else None, + ) + except Exception as e: + logger.error( + f"Error evaluating run {run.id} with " + f"{evaluator.__class__.__name__}: {repr(e)}", + exc_info=True, + ) + raise e + example_id = str(run.reference_example_id) + with self.lock: + for res in eval_results: + run_id = ( + str(getattr(res, "target_run_id")) + if hasattr(res, "target_run_id") + else str(run.id) + ) + self.logged_eval_results.setdefault((run_id, example_id), []).append( + res + ) + + def _select_eval_results( + self, + results: Union[EvaluationResult, EvaluationResults], + ) -> List[EvaluationResult]: + if isinstance(results, EvaluationResult): + results_ = [results] + elif isinstance(results, dict) and "results" in results: + results_ = cast(List[EvaluationResult], results["results"]) + else: + raise TypeError( + f"Invalid evaluation result type {type(results)}." + " Expected EvaluationResult or EvaluationResults." + ) + return results_ + + def _log_evaluation_feedback( + self, + evaluator_response: Union[EvaluationResult, EvaluationResults], + run: Run, + source_run_id: Optional[UUID] = None, + ) -> List[EvaluationResult]: + results = self._select_eval_results(evaluator_response) + for res in results: + source_info_: Dict[str, Any] = {} + if res.evaluator_info: + source_info_ = {**res.evaluator_info, **source_info_} + run_id_ = ( + getattr(res, "target_run_id") + if hasattr(res, "target_run_id") and res.target_run_id is not None + else run.id + ) + self.client.create_feedback( + run_id_, + res.key, + score=res.score, + value=res.value, + comment=res.comment, + correction=res.correction, + source_info=source_info_, + source_run_id=res.source_run_id or source_run_id, + feedback_source_type=langsmith.schemas.FeedbackSourceType.MODEL, + ) + return results + + def _persist_run(self, run: Run) -> None: + """Run the evaluator on the run. + + Parameters + ---------- + run : Run + The run to be evaluated. + + """ + if self.skip_unfinished and not run.outputs: + logger.debug(f"Skipping unfinished run {run.id}") + return + run_ = run.copy() + run_.reference_example_id = self.example_id + for evaluator in self.evaluators: + if self.executor is None: + self._evaluate_in_project(run_, evaluator) + else: + self.futures.add( + self.executor.submit(self._evaluate_in_project, run_, evaluator) + ) + + def wait_for_futures(self) -> None: + """Wait for all futures to complete.""" + wait(self.futures) diff --git a/libs/langchain/langchain/schema/callbacks/tracers/langchain.py b/libs/langchain/langchain/schema/callbacks/tracers/langchain.py new file mode 100644 index 00000000000..c109c8b2a9d --- /dev/null +++ b/libs/langchain/langchain/schema/callbacks/tracers/langchain.py @@ -0,0 +1,262 @@ +"""A Tracer implementation that records to LangChain endpoint.""" +from __future__ import annotations + +import logging +import weakref +from concurrent.futures import Future, ThreadPoolExecutor, wait +from datetime import datetime +from typing import Any, Callable, Dict, List, Optional, Union +from uuid import UUID + +from langsmith import Client +from langsmith import utils as ls_utils +from tenacity import ( + Retrying, + retry_if_exception_type, + stop_after_attempt, + wait_exponential_jitter, +) + +from langchain.env import get_runtime_environment +from langchain.load.dump import dumpd +from langchain.schema.callbacks.tracers.base import BaseTracer +from langchain.schema.callbacks.tracers.schemas import Run +from langchain.schema.messages import BaseMessage + +logger = logging.getLogger(__name__) +_LOGGED = set() +_TRACERS: weakref.WeakSet[LangChainTracer] = weakref.WeakSet() +_CLIENT: Optional[Client] = None +_EXECUTOR: Optional[ThreadPoolExecutor] = None + + +def log_error_once(method: str, exception: Exception) -> None: + """Log an error once.""" + global _LOGGED + if (method, type(exception)) in _LOGGED: + return + _LOGGED.add((method, type(exception))) + logger.error(exception) + + +def wait_for_all_tracers() -> None: + """Wait for all tracers to finish.""" + global _TRACERS + for tracer in list(_TRACERS): + if tracer is not None: + tracer.wait_for_futures() + + +def get_client() -> Client: + """Get the client.""" + global _CLIENT + if _CLIENT is None: + _CLIENT = Client() + return _CLIENT + + +def _get_executor() -> ThreadPoolExecutor: + """Get the executor.""" + global _EXECUTOR + if _EXECUTOR is None: + _EXECUTOR = ThreadPoolExecutor() + return _EXECUTOR + + +def _copy(run: Run) -> Run: + """Copy a run.""" + try: + return run.copy(deep=True) + except TypeError: + # Fallback in case the object contains a lock or other + # non-pickleable object + return run.copy() + + +class LangChainTracer(BaseTracer): + """An implementation of the SharedTracer that POSTS to the langchain endpoint.""" + + def __init__( + self, + example_id: Optional[Union[UUID, str]] = None, + project_name: Optional[str] = None, + client: Optional[Client] = None, + tags: Optional[List[str]] = None, + use_threading: bool = True, + **kwargs: Any, + ) -> None: + """Initialize the LangChain tracer.""" + super().__init__(**kwargs) + self.example_id = ( + UUID(example_id) if isinstance(example_id, str) else example_id + ) + self.project_name = project_name or ls_utils.get_tracer_project() + self.client = client or get_client() + self._futures: weakref.WeakSet[Future] = weakref.WeakSet() + self.tags = tags or [] + self.executor = _get_executor() if use_threading else None + self.latest_run: Optional[Run] = None + global _TRACERS + _TRACERS.add(self) + + def on_chat_model_start( + self, + serialized: Dict[str, Any], + messages: List[List[BaseMessage]], + *, + run_id: UUID, + tags: Optional[List[str]] = None, + parent_run_id: Optional[UUID] = None, + metadata: Optional[Dict[str, Any]] = None, + name: Optional[str] = None, + **kwargs: Any, + ) -> None: + """Start a trace for an LLM run.""" + parent_run_id_ = str(parent_run_id) if parent_run_id else None + execution_order = self._get_execution_order(parent_run_id_) + start_time = datetime.utcnow() + if metadata: + kwargs.update({"metadata": metadata}) + chat_model_run = Run( + id=run_id, + parent_run_id=parent_run_id, + serialized=serialized, + inputs={"messages": [[dumpd(msg) for msg in batch] for batch in messages]}, + extra=kwargs, + events=[{"name": "start", "time": start_time}], + start_time=start_time, + execution_order=execution_order, + child_execution_order=execution_order, + run_type="llm", + tags=tags, + name=name, + ) + self._start_trace(chat_model_run) + self._on_chat_model_start(chat_model_run) + + def _persist_run(self, run: Run) -> None: + run_ = run.copy() + run_.reference_example_id = self.example_id + self.latest_run = run_ + + def get_run_url(self) -> str: + """Get the LangSmith root run URL""" + if not self.latest_run: + raise ValueError("No traced run found.") + # If this is the first run in a project, the project may not yet be created. + # This method is only really useful for debugging flows, so we will assume + # there is some tolerace for latency. + for attempt in Retrying( + stop=stop_after_attempt(5), + wait=wait_exponential_jitter(), + retry=retry_if_exception_type(ls_utils.LangSmithError), + ): + with attempt: + return self.client.get_run_url( + run=self.latest_run, project_name=self.project_name + ) + raise ValueError("Failed to get run URL.") + + def _get_tags(self, run: Run) -> List[str]: + """Get combined tags for a run.""" + tags = set(run.tags or []) + tags.update(self.tags or []) + return list(tags) + + def _persist_run_single(self, run: Run) -> None: + """Persist a run.""" + run_dict = run.dict(exclude={"child_runs"}) + run_dict["tags"] = self._get_tags(run) + extra = run_dict.get("extra", {}) + extra["runtime"] = get_runtime_environment() + run_dict["extra"] = extra + try: + self.client.create_run(**run_dict, project_name=self.project_name) + except Exception as e: + # Errors are swallowed by the thread executor so we need to log them here + log_error_once("post", e) + raise + + def _update_run_single(self, run: Run) -> None: + """Update a run.""" + try: + run_dict = run.dict() + run_dict["tags"] = self._get_tags(run) + self.client.update_run(run.id, **run_dict) + except Exception as e: + # Errors are swallowed by the thread executor so we need to log them here + log_error_once("patch", e) + raise + + def _submit(self, function: Callable[[Run], None], run: Run) -> None: + """Submit a function to the executor.""" + if self.executor is None: + function(run) + else: + self._futures.add(self.executor.submit(function, run)) + + def _on_llm_start(self, run: Run) -> None: + """Persist an LLM run.""" + if run.parent_run_id is None: + run.reference_example_id = self.example_id + self._submit(self._persist_run_single, _copy(run)) + + def _on_chat_model_start(self, run: Run) -> None: + """Persist an LLM run.""" + if run.parent_run_id is None: + run.reference_example_id = self.example_id + self._submit(self._persist_run_single, _copy(run)) + + def _on_llm_end(self, run: Run) -> None: + """Process the LLM Run.""" + self._submit(self._update_run_single, _copy(run)) + + def _on_llm_error(self, run: Run) -> None: + """Process the LLM Run upon error.""" + self._submit(self._update_run_single, _copy(run)) + + def _on_chain_start(self, run: Run) -> None: + """Process the Chain Run upon start.""" + if run.parent_run_id is None: + run.reference_example_id = self.example_id + self._submit(self._persist_run_single, _copy(run)) + + def _on_chain_end(self, run: Run) -> None: + """Process the Chain Run.""" + self._submit(self._update_run_single, _copy(run)) + + def _on_chain_error(self, run: Run) -> None: + """Process the Chain Run upon error.""" + self._submit(self._update_run_single, _copy(run)) + + def _on_tool_start(self, run: Run) -> None: + """Process the Tool Run upon start.""" + if run.parent_run_id is None: + run.reference_example_id = self.example_id + self._submit(self._persist_run_single, _copy(run)) + + def _on_tool_end(self, run: Run) -> None: + """Process the Tool Run.""" + self._submit(self._update_run_single, _copy(run)) + + def _on_tool_error(self, run: Run) -> None: + """Process the Tool Run upon error.""" + self._submit(self._update_run_single, _copy(run)) + + def _on_retriever_start(self, run: Run) -> None: + """Process the Retriever Run upon start.""" + if run.parent_run_id is None: + run.reference_example_id = self.example_id + self._submit(self._persist_run_single, _copy(run)) + + def _on_retriever_end(self, run: Run) -> None: + """Process the Retriever Run.""" + self._submit(self._update_run_single, _copy(run)) + + def _on_retriever_error(self, run: Run) -> None: + """Process the Retriever Run upon error.""" + self._submit(self._update_run_single, _copy(run)) + + def wait_for_futures(self) -> None: + """Wait for the given futures to complete.""" + wait(self._futures) diff --git a/libs/langchain/langchain/schema/callbacks/tracers/langchain_v1.py b/libs/langchain/langchain/schema/callbacks/tracers/langchain_v1.py new file mode 100644 index 00000000000..957ae85875d --- /dev/null +++ b/libs/langchain/langchain/schema/callbacks/tracers/langchain_v1.py @@ -0,0 +1,185 @@ +from __future__ import annotations + +import logging +import os +from typing import Any, Dict, Optional, Union + +import requests + +from langchain.schema.callbacks.tracers.base import BaseTracer +from langchain.schema.callbacks.tracers.schemas import ( + ChainRun, + LLMRun, + Run, + ToolRun, + TracerSession, + TracerSessionV1, + TracerSessionV1Base, +) +from langchain.schema.messages import get_buffer_string +from langchain.utils import raise_for_status_with_text + +logger = logging.getLogger(__name__) + + +def get_headers() -> Dict[str, Any]: + """Get the headers for the LangChain API.""" + headers: Dict[str, Any] = {"Content-Type": "application/json"} + if os.getenv("LANGCHAIN_API_KEY"): + headers["x-api-key"] = os.getenv("LANGCHAIN_API_KEY") + return headers + + +def _get_endpoint() -> str: + return os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:8000") + + +class LangChainTracerV1(BaseTracer): + """An implementation of the SharedTracer that POSTS to the langchain endpoint.""" + + def __init__(self, **kwargs: Any) -> None: + """Initialize the LangChain tracer.""" + super().__init__(**kwargs) + self.session: Optional[TracerSessionV1] = None + self._endpoint = _get_endpoint() + self._headers = get_headers() + + def _convert_to_v1_run(self, run: Run) -> Union[LLMRun, ChainRun, ToolRun]: + session = self.session or self.load_default_session() + if not isinstance(session, TracerSessionV1): + raise ValueError( + "LangChainTracerV1 is not compatible with" + f" session of type {type(session)}" + ) + + if run.run_type == "llm": + if "prompts" in run.inputs: + prompts = run.inputs["prompts"] + elif "messages" in run.inputs: + prompts = [get_buffer_string(batch) for batch in run.inputs["messages"]] + else: + raise ValueError("No prompts found in LLM run inputs") + return LLMRun( + uuid=str(run.id) if run.id else None, + parent_uuid=str(run.parent_run_id) if run.parent_run_id else None, + start_time=run.start_time, + end_time=run.end_time, + extra=run.extra, + execution_order=run.execution_order, + child_execution_order=run.child_execution_order, + serialized=run.serialized, + session_id=session.id, + error=run.error, + prompts=prompts, + response=run.outputs if run.outputs else None, + ) + if run.run_type == "chain": + child_runs = [self._convert_to_v1_run(run) for run in run.child_runs] + return ChainRun( + uuid=str(run.id) if run.id else None, + parent_uuid=str(run.parent_run_id) if run.parent_run_id else None, + start_time=run.start_time, + end_time=run.end_time, + execution_order=run.execution_order, + child_execution_order=run.child_execution_order, + serialized=run.serialized, + session_id=session.id, + inputs=run.inputs, + outputs=run.outputs, + error=run.error, + extra=run.extra, + child_llm_runs=[run for run in child_runs if isinstance(run, LLMRun)], + child_chain_runs=[ + run for run in child_runs if isinstance(run, ChainRun) + ], + child_tool_runs=[run for run in child_runs if isinstance(run, ToolRun)], + ) + if run.run_type == "tool": + child_runs = [self._convert_to_v1_run(run) for run in run.child_runs] + return ToolRun( + uuid=str(run.id) if run.id else None, + parent_uuid=str(run.parent_run_id) if run.parent_run_id else None, + start_time=run.start_time, + end_time=run.end_time, + execution_order=run.execution_order, + child_execution_order=run.child_execution_order, + serialized=run.serialized, + session_id=session.id, + action=str(run.serialized), + tool_input=run.inputs.get("input", ""), + output=None if run.outputs is None else run.outputs.get("output"), + error=run.error, + extra=run.extra, + child_chain_runs=[ + run for run in child_runs if isinstance(run, ChainRun) + ], + child_tool_runs=[run for run in child_runs if isinstance(run, ToolRun)], + child_llm_runs=[run for run in child_runs if isinstance(run, LLMRun)], + ) + raise ValueError(f"Unknown run type: {run.run_type}") + + def _persist_run(self, run: Union[Run, LLMRun, ChainRun, ToolRun]) -> None: + """Persist a run.""" + if isinstance(run, Run): + v1_run = self._convert_to_v1_run(run) + else: + v1_run = run + if isinstance(v1_run, LLMRun): + endpoint = f"{self._endpoint}/llm-runs" + elif isinstance(v1_run, ChainRun): + endpoint = f"{self._endpoint}/chain-runs" + else: + endpoint = f"{self._endpoint}/tool-runs" + + try: + response = requests.post( + endpoint, + data=v1_run.json(), + headers=self._headers, + ) + raise_for_status_with_text(response) + except Exception as e: + logger.warning(f"Failed to persist run: {e}") + + def _persist_session( + self, session_create: TracerSessionV1Base + ) -> Union[TracerSessionV1, TracerSession]: + """Persist a session.""" + try: + r = requests.post( + f"{self._endpoint}/sessions", + data=session_create.json(), + headers=self._headers, + ) + session = TracerSessionV1(id=r.json()["id"], **session_create.dict()) + except Exception as e: + logger.warning(f"Failed to create session, using default session: {e}") + session = TracerSessionV1(id=1, **session_create.dict()) + return session + + def _load_session(self, session_name: Optional[str] = None) -> TracerSessionV1: + """Load a session from the tracer.""" + try: + url = f"{self._endpoint}/sessions" + if session_name: + url += f"?name={session_name}" + r = requests.get(url, headers=self._headers) + + tracer_session = TracerSessionV1(**r.json()[0]) + except Exception as e: + session_type = "default" if not session_name else session_name + logger.warning( + f"Failed to load {session_type} session, using empty session: {e}" + ) + tracer_session = TracerSessionV1(id=1) + + self.session = tracer_session + return tracer_session + + def load_session(self, session_name: str) -> Union[TracerSessionV1, TracerSession]: + """Load a session with the given name from the tracer.""" + return self._load_session(session_name) + + def load_default_session(self) -> Union[TracerSessionV1, TracerSession]: + """Load the default tracing session and set it as the Tracer's session.""" + return self._load_session("default") diff --git a/libs/langchain/langchain/schema/callbacks/tracers/log_stream.py b/libs/langchain/langchain/schema/callbacks/tracers/log_stream.py new file mode 100644 index 00000000000..6b2acd3cbd2 --- /dev/null +++ b/libs/langchain/langchain/schema/callbacks/tracers/log_stream.py @@ -0,0 +1,311 @@ +from __future__ import annotations + +import math +import threading +from collections import defaultdict +from typing import ( + Any, + AsyncIterator, + Dict, + List, + Optional, + Sequence, + TypedDict, + Union, +) +from uuid import UUID + +import jsonpatch +from anyio import create_memory_object_stream + +from langchain.load.load import load +from langchain.schema.callbacks.tracers.base import BaseTracer +from langchain.schema.callbacks.tracers.schemas import Run +from langchain.schema.output import ChatGenerationChunk, GenerationChunk + + +class LogEntry(TypedDict): + """A single entry in the run log.""" + + id: str + """ID of the sub-run.""" + name: str + """Name of the object being run.""" + type: str + """Type of the object being run, eg. prompt, chain, llm, etc.""" + tags: List[str] + """List of tags for the run.""" + metadata: Dict[str, Any] + """Key-value pairs of metadata for the run.""" + start_time: str + """ISO-8601 timestamp of when the run started.""" + + streamed_output_str: List[str] + """List of LLM tokens streamed by this run, if applicable.""" + final_output: Optional[Any] + """Final output of this run. + Only available after the run has finished successfully.""" + end_time: Optional[str] + """ISO-8601 timestamp of when the run ended. + Only available after the run has finished.""" + + +class RunState(TypedDict): + """State of the run.""" + + id: str + """ID of the run.""" + streamed_output: List[Any] + """List of output chunks streamed by Runnable.stream()""" + final_output: Optional[Any] + """Final output of the run, usually the result of aggregating (`+`) streamed_output. + Only available after the run has finished successfully.""" + + logs: Dict[str, LogEntry] + """Map of run names to sub-runs. If filters were supplied, this list will + contain only the runs that matched the filters.""" + + +class RunLogPatch: + """A patch to the run log.""" + + ops: List[Dict[str, Any]] + """List of jsonpatch operations, which describe how to create the run state + from an empty dict. This is the minimal representation of the log, designed to + be serialized as JSON and sent over the wire to reconstruct the log on the other + side. Reconstruction of the state can be done with any jsonpatch-compliant library, + see https://jsonpatch.com for more information.""" + + def __init__(self, *ops: Dict[str, Any]) -> None: + self.ops = list(ops) + + def __add__(self, other: Union[RunLogPatch, Any]) -> RunLog: + if type(other) == RunLogPatch: + ops = self.ops + other.ops + state = jsonpatch.apply_patch(None, ops) + return RunLog(*ops, state=state) + + raise TypeError( + f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'" + ) + + def __repr__(self) -> str: + from pprint import pformat + + # 1:-1 to get rid of the [] around the list + return f"RunLogPatch({pformat(self.ops)[1:-1]})" + + def __eq__(self, other: object) -> bool: + return isinstance(other, RunLogPatch) and self.ops == other.ops + + +class RunLog(RunLogPatch): + """A run log.""" + + state: RunState + """Current state of the log, obtained from applying all ops in sequence.""" + + def __init__(self, *ops: Dict[str, Any], state: RunState) -> None: + super().__init__(*ops) + self.state = state + + def __add__(self, other: Union[RunLogPatch, Any]) -> RunLog: + if type(other) == RunLogPatch: + ops = self.ops + other.ops + state = jsonpatch.apply_patch(self.state, other.ops) + return RunLog(*ops, state=state) + + raise TypeError( + f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'" + ) + + def __repr__(self) -> str: + from pprint import pformat + + return f"RunLog({pformat(self.state)})" + + +class LogStreamCallbackHandler(BaseTracer): + """A tracer that streams run logs to a stream.""" + + def __init__( + self, + *, + auto_close: bool = True, + include_names: Optional[Sequence[str]] = None, + include_types: Optional[Sequence[str]] = None, + include_tags: Optional[Sequence[str]] = None, + exclude_names: Optional[Sequence[str]] = None, + exclude_types: Optional[Sequence[str]] = None, + exclude_tags: Optional[Sequence[str]] = None, + ) -> None: + super().__init__() + + self.auto_close = auto_close + self.include_names = include_names + self.include_types = include_types + self.include_tags = include_tags + self.exclude_names = exclude_names + self.exclude_types = exclude_types + self.exclude_tags = exclude_tags + + send_stream, receive_stream = create_memory_object_stream( + math.inf, item_type=RunLogPatch + ) + self.lock = threading.Lock() + self.send_stream = send_stream + self.receive_stream = receive_stream + self._key_map_by_run_id: Dict[UUID, str] = {} + self._counter_map_by_name: Dict[str, int] = defaultdict(int) + self.root_id: Optional[UUID] = None + + def __aiter__(self) -> AsyncIterator[RunLogPatch]: + return self.receive_stream.__aiter__() + + def include_run(self, run: Run) -> bool: + if run.id == self.root_id: + return False + + run_tags = run.tags or [] + + if ( + self.include_names is None + and self.include_types is None + and self.include_tags is None + ): + include = True + else: + include = False + + if self.include_names is not None: + include = include or run.name in self.include_names + if self.include_types is not None: + include = include or run.run_type in self.include_types + if self.include_tags is not None: + include = include or any(tag in self.include_tags for tag in run_tags) + + if self.exclude_names is not None: + include = include and run.name not in self.exclude_names + if self.exclude_types is not None: + include = include and run.run_type not in self.exclude_types + if self.exclude_tags is not None: + include = include and all(tag not in self.exclude_tags for tag in run_tags) + + return include + + def _persist_run(self, run: Run) -> None: + # This is a legacy method only called once for an entire run tree + # therefore not useful here + pass + + def _on_run_create(self, run: Run) -> None: + """Start a run.""" + if self.root_id is None: + self.root_id = run.id + self.send_stream.send_nowait( + RunLogPatch( + { + "op": "replace", + "path": "", + "value": RunState( + id=str(run.id), + streamed_output=[], + final_output=None, + logs={}, + ), + } + ) + ) + + if not self.include_run(run): + return + + # Determine previous index, increment by 1 + with self.lock: + self._counter_map_by_name[run.name] += 1 + count = self._counter_map_by_name[run.name] + self._key_map_by_run_id[run.id] = ( + run.name if count == 1 else f"{run.name}:{count}" + ) + + # Add the run to the stream + self.send_stream.send_nowait( + RunLogPatch( + { + "op": "add", + "path": f"/logs/{self._key_map_by_run_id[run.id]}", + "value": LogEntry( + id=str(run.id), + name=run.name, + type=run.run_type, + tags=run.tags or [], + metadata=(run.extra or {}).get("metadata", {}), + start_time=run.start_time.isoformat(timespec="milliseconds"), + streamed_output_str=[], + final_output=None, + end_time=None, + ), + } + ) + ) + + def _on_run_update(self, run: Run) -> None: + """Finish a run.""" + try: + index = self._key_map_by_run_id.get(run.id) + + if index is None: + return + + self.send_stream.send_nowait( + RunLogPatch( + { + "op": "add", + "path": f"/logs/{index}/final_output", + # to undo the dumpd done by some runnables / tracer / etc + "value": load(run.outputs), + }, + { + "op": "add", + "path": f"/logs/{index}/end_time", + "value": run.end_time.isoformat(timespec="milliseconds") + if run.end_time is not None + else None, + }, + ) + ) + finally: + if run.id == self.root_id: + self.send_stream.send_nowait( + RunLogPatch( + { + "op": "replace", + "path": "/final_output", + "value": load(run.outputs), + } + ) + ) + if self.auto_close: + self.send_stream.close() + + def _on_llm_new_token( + self, + run: Run, + token: str, + chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]], + ) -> None: + """Process new LLM token.""" + index = self._key_map_by_run_id.get(run.id) + + if index is None: + return + + self.send_stream.send_nowait( + RunLogPatch( + { + "op": "add", + "path": f"/logs/{index}/streamed_output_str/-", + "value": token, + } + ) + ) diff --git a/libs/langchain/langchain/schema/callbacks/tracers/root_listeners.py b/libs/langchain/langchain/schema/callbacks/tracers/root_listeners.py new file mode 100644 index 00000000000..5a134e5946c --- /dev/null +++ b/libs/langchain/langchain/schema/callbacks/tracers/root_listeners.py @@ -0,0 +1,54 @@ +from typing import Callable, Optional, Union +from uuid import UUID + +from langchain.schema.callbacks.tracers.base import BaseTracer +from langchain.schema.callbacks.tracers.schemas import Run +from langchain.schema.runnable.config import ( + RunnableConfig, + call_func_with_variable_args, +) + +Listener = Union[Callable[[Run], None], Callable[[Run, RunnableConfig], None]] + + +class RootListenersTracer(BaseTracer): + def __init__( + self, + *, + config: RunnableConfig, + on_start: Optional[Listener], + on_end: Optional[Listener], + on_error: Optional[Listener], + ) -> None: + super().__init__() + + self.config = config + self._arg_on_start = on_start + self._arg_on_end = on_end + self._arg_on_error = on_error + self.root_id: Optional[UUID] = None + + def _persist_run(self, run: Run) -> None: + # This is a legacy method only called once for an entire run tree + # therefore not useful here + pass + + def _on_run_create(self, run: Run) -> None: + if self.root_id is not None: + return + + self.root_id = run.id + + if self._arg_on_start is not None: + call_func_with_variable_args(self._arg_on_start, run, self.config) + + def _on_run_update(self, run: Run) -> None: + if run.id != self.root_id: + return + + if run.error is None: + if self._arg_on_end is not None: + call_func_with_variable_args(self._arg_on_end, run, self.config) + else: + if self._arg_on_error is not None: + call_func_with_variable_args(self._arg_on_error, run, self.config) diff --git a/libs/langchain/langchain/schema/callbacks/tracers/run_collector.py b/libs/langchain/langchain/schema/callbacks/tracers/run_collector.py new file mode 100644 index 00000000000..8087121a13d --- /dev/null +++ b/libs/langchain/langchain/schema/callbacks/tracers/run_collector.py @@ -0,0 +1,52 @@ +"""A tracer that collects all nested runs in a list.""" + +from typing import Any, List, Optional, Union +from uuid import UUID + +from langchain.schema.callbacks.tracers.base import BaseTracer +from langchain.schema.callbacks.tracers.schemas import Run + + +class RunCollectorCallbackHandler(BaseTracer): + """ + A tracer that collects all nested runs in a list. + + This tracer is useful for inspection and evaluation purposes. + + Parameters + ---------- + example_id : Optional[Union[UUID, str]], default=None + The ID of the example being traced. It can be either a UUID or a string. + """ + + name: str = "run-collector_callback_handler" + + def __init__( + self, example_id: Optional[Union[UUID, str]] = None, **kwargs: Any + ) -> None: + """ + Initialize the RunCollectorCallbackHandler. + + Parameters + ---------- + example_id : Optional[Union[UUID, str]], default=None + The ID of the example being traced. It can be either a UUID or a string. + """ + super().__init__(**kwargs) + self.example_id = ( + UUID(example_id) if isinstance(example_id, str) else example_id + ) + self.traced_runs: List[Run] = [] + + def _persist_run(self, run: Run) -> None: + """ + Persist a run by adding it to the traced_runs list. + + Parameters + ---------- + run : Run + The run to be persisted. + """ + run_ = run.copy() + run_.reference_example_id = self.example_id + self.traced_runs.append(run_) diff --git a/libs/langchain/langchain/schema/callbacks/tracers/schemas.py b/libs/langchain/langchain/schema/callbacks/tracers/schemas.py new file mode 100644 index 00000000000..4db455be2ea --- /dev/null +++ b/libs/langchain/langchain/schema/callbacks/tracers/schemas.py @@ -0,0 +1,140 @@ +"""Schemas for tracers.""" +from __future__ import annotations + +import datetime +import warnings +from typing import Any, Dict, List, Optional, Type +from uuid import UUID + +from langsmith.schemas import RunBase as BaseRunV2 +from langsmith.schemas import RunTypeEnum as RunTypeEnumDep + +from langchain.pydantic_v1 import BaseModel, Field, root_validator +from langchain.schema import LLMResult + + +def RunTypeEnum() -> Type[RunTypeEnumDep]: + """RunTypeEnum.""" + warnings.warn( + "RunTypeEnum is deprecated. Please directly use a string instead" + " (e.g. 'llm', 'chain', 'tool').", + DeprecationWarning, + ) + return RunTypeEnumDep + + +class TracerSessionV1Base(BaseModel): + """Base class for TracerSessionV1.""" + + start_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow) + name: Optional[str] = None + extra: Optional[Dict[str, Any]] = None + + +class TracerSessionV1Create(TracerSessionV1Base): + """Create class for TracerSessionV1.""" + + +class TracerSessionV1(TracerSessionV1Base): + """TracerSessionV1 schema.""" + + id: int + + +class TracerSessionBase(TracerSessionV1Base): + """Base class for TracerSession.""" + + tenant_id: UUID + + +class TracerSession(TracerSessionBase): + """TracerSessionV1 schema for the V2 API.""" + + id: UUID + + +class BaseRun(BaseModel): + """Base class for Run.""" + + uuid: str + parent_uuid: Optional[str] = None + start_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow) + end_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow) + extra: Optional[Dict[str, Any]] = None + execution_order: int + child_execution_order: int + serialized: Dict[str, Any] + session_id: int + error: Optional[str] = None + + +class LLMRun(BaseRun): + """Class for LLMRun.""" + + prompts: List[str] + response: Optional[LLMResult] = None + + +class ChainRun(BaseRun): + """Class for ChainRun.""" + + inputs: Dict[str, Any] + outputs: Optional[Dict[str, Any]] = None + child_llm_runs: List[LLMRun] = Field(default_factory=list) + child_chain_runs: List[ChainRun] = Field(default_factory=list) + child_tool_runs: List[ToolRun] = Field(default_factory=list) + + +class ToolRun(BaseRun): + """Class for ToolRun.""" + + tool_input: str + output: Optional[str] = None + action: str + child_llm_runs: List[LLMRun] = Field(default_factory=list) + child_chain_runs: List[ChainRun] = Field(default_factory=list) + child_tool_runs: List[ToolRun] = Field(default_factory=list) + + +# Begin V2 API Schemas + + +class Run(BaseRunV2): + """Run schema for the V2 API in the Tracer.""" + + execution_order: int + child_execution_order: int + child_runs: List[Run] = Field(default_factory=list) + tags: Optional[List[str]] = Field(default_factory=list) + events: List[Dict[str, Any]] = Field(default_factory=list) + + @root_validator(pre=True) + def assign_name(cls, values: dict) -> dict: + """Assign name to the run.""" + if values.get("name") is None: + if "name" in values["serialized"]: + values["name"] = values["serialized"]["name"] + elif "id" in values["serialized"]: + values["name"] = values["serialized"]["id"][-1] + if values.get("events") is None: + values["events"] = [] + return values + + +ChainRun.update_forward_refs() +ToolRun.update_forward_refs() +Run.update_forward_refs() + +__all__ = [ + "BaseRun", + "ChainRun", + "LLMRun", + "Run", + "RunTypeEnum", + "ToolRun", + "TracerSession", + "TracerSessionBase", + "TracerSessionV1", + "TracerSessionV1Base", + "TracerSessionV1Create", +] diff --git a/libs/langchain/langchain/schema/callbacks/tracers/stdout.py b/libs/langchain/langchain/schema/callbacks/tracers/stdout.py new file mode 100644 index 00000000000..564f419a5bf --- /dev/null +++ b/libs/langchain/langchain/schema/callbacks/tracers/stdout.py @@ -0,0 +1,178 @@ +import json +from typing import Any, Callable, List + +from langchain.schema.callbacks.tracers.base import BaseTracer +from langchain.schema.callbacks.tracers.schemas import Run +from langchain.utils.input import get_bolded_text, get_colored_text + + +def try_json_stringify(obj: Any, fallback: str) -> str: + """ + Try to stringify an object to JSON. + Args: + obj: Object to stringify. + fallback: Fallback string to return if the object cannot be stringified. + + Returns: + A JSON string if the object can be stringified, otherwise the fallback string. + + """ + try: + return json.dumps(obj, indent=2, ensure_ascii=False) + except Exception: + return fallback + + +def elapsed(run: Any) -> str: + """Get the elapsed time of a run. + + Args: + run: any object with a start_time and end_time attribute. + + Returns: + A string with the elapsed time in seconds or + milliseconds if time is less than a second. + + """ + elapsed_time = run.end_time - run.start_time + milliseconds = elapsed_time.total_seconds() * 1000 + if milliseconds < 1000: + return f"{milliseconds:.0f}ms" + return f"{(milliseconds / 1000):.2f}s" + + +class FunctionCallbackHandler(BaseTracer): + """Tracer that calls a function with a single str parameter.""" + + name: str = "function_callback_handler" + + def __init__(self, function: Callable[[str], None], **kwargs: Any) -> None: + super().__init__(**kwargs) + self.function_callback = function + + def _persist_run(self, run: Run) -> None: + pass + + def get_parents(self, run: Run) -> List[Run]: + parents = [] + current_run = run + while current_run.parent_run_id: + parent = self.run_map.get(str(current_run.parent_run_id)) + if parent: + parents.append(parent) + current_run = parent + else: + break + return parents + + def get_breadcrumbs(self, run: Run) -> str: + parents = self.get_parents(run)[::-1] + string = " > ".join( + f"{parent.execution_order}:{parent.run_type}:{parent.name}" + if i != len(parents) - 1 + else f"{parent.execution_order}:{parent.run_type}:{parent.name}" + for i, parent in enumerate(parents + [run]) + ) + return string + + # logging methods + def _on_chain_start(self, run: Run) -> None: + crumbs = self.get_breadcrumbs(run) + run_type = run.run_type.capitalize() + self.function_callback( + f"{get_colored_text('[chain/start]', color='green')} " + + get_bolded_text(f"[{crumbs}] Entering {run_type} run with input:\n") + + f"{try_json_stringify(run.inputs, '[inputs]')}" + ) + + def _on_chain_end(self, run: Run) -> None: + crumbs = self.get_breadcrumbs(run) + run_type = run.run_type.capitalize() + self.function_callback( + f"{get_colored_text('[chain/end]', color='blue')} " + + get_bolded_text( + f"[{crumbs}] [{elapsed(run)}] Exiting {run_type} run with output:\n" + ) + + f"{try_json_stringify(run.outputs, '[outputs]')}" + ) + + def _on_chain_error(self, run: Run) -> None: + crumbs = self.get_breadcrumbs(run) + run_type = run.run_type.capitalize() + self.function_callback( + f"{get_colored_text('[chain/error]', color='red')} " + + get_bolded_text( + f"[{crumbs}] [{elapsed(run)}] {run_type} run errored with error:\n" + ) + + f"{try_json_stringify(run.error, '[error]')}" + ) + + def _on_llm_start(self, run: Run) -> None: + crumbs = self.get_breadcrumbs(run) + inputs = ( + {"prompts": [p.strip() for p in run.inputs["prompts"]]} + if "prompts" in run.inputs + else run.inputs + ) + self.function_callback( + f"{get_colored_text('[llm/start]', color='green')} " + + get_bolded_text(f"[{crumbs}] Entering LLM run with input:\n") + + f"{try_json_stringify(inputs, '[inputs]')}" + ) + + def _on_llm_end(self, run: Run) -> None: + crumbs = self.get_breadcrumbs(run) + self.function_callback( + f"{get_colored_text('[llm/end]', color='blue')} " + + get_bolded_text( + f"[{crumbs}] [{elapsed(run)}] Exiting LLM run with output:\n" + ) + + f"{try_json_stringify(run.outputs, '[response]')}" + ) + + def _on_llm_error(self, run: Run) -> None: + crumbs = self.get_breadcrumbs(run) + self.function_callback( + f"{get_colored_text('[llm/error]', color='red')} " + + get_bolded_text( + f"[{crumbs}] [{elapsed(run)}] LLM run errored with error:\n" + ) + + f"{try_json_stringify(run.error, '[error]')}" + ) + + def _on_tool_start(self, run: Run) -> None: + crumbs = self.get_breadcrumbs(run) + self.function_callback( + f'{get_colored_text("[tool/start]", color="green")} ' + + get_bolded_text(f"[{crumbs}] Entering Tool run with input:\n") + + f'"{run.inputs["input"].strip()}"' + ) + + def _on_tool_end(self, run: Run) -> None: + crumbs = self.get_breadcrumbs(run) + if run.outputs: + self.function_callback( + f'{get_colored_text("[tool/end]", color="blue")} ' + + get_bolded_text( + f"[{crumbs}] [{elapsed(run)}] Exiting Tool run with output:\n" + ) + + f'"{run.outputs["output"].strip()}"' + ) + + def _on_tool_error(self, run: Run) -> None: + crumbs = self.get_breadcrumbs(run) + self.function_callback( + f"{get_colored_text('[tool/error]', color='red')} " + + get_bolded_text(f"[{crumbs}] [{elapsed(run)}] ") + + f"Tool run errored with error:\n" + f"{run.error}" + ) + + +class ConsoleCallbackHandler(FunctionCallbackHandler): + """Tracer that prints to the console.""" + + name: str = "console_callback_handler" + + def __init__(self, **kwargs: Any) -> None: + super().__init__(function=print, **kwargs) diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 0a7e2449ad2..54bbdca4a93 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -32,12 +32,12 @@ from typing import ( from typing_extensions import Literal, get_args if TYPE_CHECKING: - from langchain.callbacks.manager import ( + from langchain.schema.callbacks.manager import ( AsyncCallbackManagerForChainRun, CallbackManagerForChainRun, ) - from langchain.callbacks.tracers.log_stream import RunLog, RunLogPatch - from langchain.callbacks.tracers.root_listeners import Listener + from langchain.schema.callbacks.tracers.log_stream import RunLog, RunLogPatch + from langchain.schema.callbacks.tracers.root_listeners import Listener from langchain.schema.runnable.fallbacks import ( RunnableWithFallbacks as RunnableWithFallbacksT, ) diff --git a/libs/langchain/scripts/check_imports.sh b/libs/langchain/scripts/check_imports.sh index 2440eb8b9e9..80c38b30f58 100755 --- a/libs/langchain/scripts/check_imports.sh +++ b/libs/langchain/scripts/check_imports.sh @@ -10,7 +10,7 @@ git grep '^from langchain import' langchain | grep -vE 'from langchain import (_ git grep '^from langchain ' langchain/pydantic_v1 | grep -vE 'from langchain.(pydantic_v1)' && errors=$((errors+1)) git grep '^from langchain' langchain/load | grep -vE 'from langchain.(pydantic_v1|load)' && errors=$((errors+1)) git grep '^from langchain' langchain/utils | grep -vE 'from langchain.(pydantic_v1|utils)' && errors=$((errors+1)) -git grep '^from langchain' langchain/schema | grep -vE 'from langchain.(pydantic_v1|utils|schema|load)' && errors=$((errors+1)) +git grep '^from langchain' langchain/schema | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|env)' && errors=$((errors+1)) git grep '^from langchain' langchain/adapters | grep -vE 'from langchain.(pydantic_v1|utils|schema|load)' && errors=$((errors+1)) git grep '^from langchain' langchain/callbacks | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env)' && errors=$((errors+1)) # TODO: it's probably not amazing so that so many other modules depend on `langchain.utilities`, because there can be a lot of imports there diff --git a/libs/langchain/tests/unit_tests/callbacks/tracers/test_langchain_v1.py b/libs/langchain/tests/unit_tests/callbacks/tracers/test_langchain_v1.py index 162c65a9b17..a737b98dcff 100644 --- a/libs/langchain/tests/unit_tests/callbacks/tracers/test_langchain_v1.py +++ b/libs/langchain/tests/unit_tests/callbacks/tracers/test_langchain_v1.py @@ -10,15 +10,15 @@ from freezegun import freeze_time from langchain.callbacks.manager import CallbackManager from langchain.callbacks.tracers.base import BaseTracer, TracerException -from langchain.callbacks.tracers.langchain_v1 import ( +from langchain.callbacks.tracers.schemas import Run, TracerSessionV1Base +from langchain.schema import LLMResult +from langchain.schema.callbacks.tracers.langchain_v1 import ( ChainRun, LangChainTracerV1, LLMRun, ToolRun, TracerSessionV1, ) -from langchain.callbacks.tracers.schemas import Run, TracerSessionV1Base -from langchain.schema import LLMResult from langchain.schema.messages import HumanMessage TEST_SESSION_ID = 2023 diff --git a/libs/langchain/tests/unit_tests/test_globals.py b/libs/langchain/tests/unit_tests/test_globals.py index 76b9d437e03..8249d39d65d 100644 --- a/libs/langchain/tests/unit_tests/test_globals.py +++ b/libs/langchain/tests/unit_tests/test_globals.py @@ -3,7 +3,7 @@ from langchain.globals import get_debug, get_verbose, set_debug, set_verbose def test_debug_is_settable_directly() -> None: import langchain - from langchain.callbacks.manager import _get_debug + from langchain.schema.callbacks.manager import _get_debug previous_value = langchain.debug previous_fn_reading = _get_debug() @@ -33,7 +33,7 @@ def test_debug_is_settable_directly() -> None: def test_debug_is_settable_via_setter() -> None: from langchain import globals - from langchain.callbacks.manager import _get_debug + from langchain.schema.callbacks.manager import _get_debug previous_value = globals._debug previous_fn_reading = _get_debug()