mirror of
https://github.com/hwchase17/langchain.git
synced 2026-01-21 21:56:38 +00:00
Compare commits
4 Commits
replace_ap
...
eugene/tra
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
05c457a994 | ||
|
|
956a4489e3 | ||
|
|
d7f7c96f95 | ||
|
|
0384093fcf |
162
libs/core/langchain_core/callbacks/events.py
Normal file
162
libs/core/langchain_core/callbacks/events.py
Normal file
@@ -0,0 +1,162 @@
|
||||
from typing import Optional, Dict, Literal, List
|
||||
from typing import TypedDict, Any, Union, Callable
|
||||
|
||||
from tenacity import RetryCallState
|
||||
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.outputs import GenerationChunk, LLMResult, ChatGenerationChunk
|
||||
|
||||
|
||||
class RetrieverErrorEvent(TypedDict):
|
||||
type: Literal["on_retriever_error"]
|
||||
error: BaseException
|
||||
|
||||
|
||||
class RetrieverEndEvent(TypedDict):
|
||||
type: Literal["on_retriever_end"]
|
||||
documents: List[Document]
|
||||
|
||||
|
||||
class LLMNewTokenEvent(TypedDict):
|
||||
type: Literal["on_llm_new_token"]
|
||||
token: str
|
||||
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]]
|
||||
|
||||
|
||||
class LLMEndEvent(TypedDict):
|
||||
type: Literal["on_llm_end"]
|
||||
response: LLMResult
|
||||
|
||||
|
||||
class LLMErrorEvent(TypedDict):
|
||||
type: Literal["on_llm_error"]
|
||||
error: BaseException
|
||||
|
||||
|
||||
class ChainEndEvent(TypedDict):
|
||||
type: Literal["on_chain_end"]
|
||||
outputs: Dict[str, Any]
|
||||
|
||||
|
||||
class ChainErrorEvent(TypedDict):
|
||||
"""Event for a chain error."""
|
||||
|
||||
type: Literal["on_chain_error"]
|
||||
error: BaseException
|
||||
|
||||
|
||||
class AgentActionEvent(TypedDict):
|
||||
"""Event for an agent action."""
|
||||
|
||||
type: Literal["on_agent_action"]
|
||||
action: AgentAction
|
||||
|
||||
|
||||
class AgentFinishEvent(TypedDict):
|
||||
"""Event for an agent action."""
|
||||
|
||||
type: Literal["on_agent_finish"]
|
||||
finish: AgentFinish
|
||||
|
||||
|
||||
class ToolEndEvent(TypedDict):
|
||||
"""Event for a tool end."""
|
||||
|
||||
type: Literal["on_tool_end"]
|
||||
output: Any
|
||||
|
||||
|
||||
class ToolErrorEvent(TypedDict):
|
||||
"""Event for a tool error."""
|
||||
|
||||
type: Literal["on_tool_error"]
|
||||
error: BaseException
|
||||
|
||||
|
||||
class LLMStartEvent(TypedDict):
|
||||
type: Literal["on_llm_start"]
|
||||
serialized: Dict[str, Any]
|
||||
prompts: List[str]
|
||||
|
||||
|
||||
class ChatModelStartEvent(TypedDict):
|
||||
type: Literal["on_chat_model_start"]
|
||||
serialized: Dict[str, Any]
|
||||
messages: List[List[BaseMessage]]
|
||||
|
||||
|
||||
class AdHocEvent(TypedDict):
|
||||
"""Ad hoc event."""
|
||||
type: Literal["on_ad_hoc"]
|
||||
data: Any
|
||||
|
||||
|
||||
class RetrieverStartEvent(TypedDict):
|
||||
type: Literal["on_retriever_start"]
|
||||
serialized: Dict[str, Any]
|
||||
query: str
|
||||
|
||||
|
||||
class ChainStartEvent(TypedDict):
|
||||
type: Literal["on_chain_start"]
|
||||
serialized: Dict[str, Any]
|
||||
inputs: Dict[str, Any]
|
||||
|
||||
|
||||
class ToolStartEvent(TypedDict):
|
||||
type: Literal["on_tool_start"]
|
||||
serialized: Dict[str, Any]
|
||||
input_str: str
|
||||
inputs: Optional[Dict[str, Any]]
|
||||
|
||||
|
||||
class TextEvent(TypedDict):
|
||||
type: Literal["on_text"]
|
||||
text: str
|
||||
|
||||
|
||||
class RetryEvent(TypedDict):
|
||||
type: Literal["on_retry"]
|
||||
retry_state: RetryCallState
|
||||
|
||||
|
||||
# define possible callback events
|
||||
Event = Union[
|
||||
RetrieverErrorEvent,
|
||||
RetrieverEndEvent,
|
||||
LLMNewTokenEvent,
|
||||
LLMEndEvent,
|
||||
LLMErrorEvent,
|
||||
ChainEndEvent,
|
||||
ChainErrorEvent,
|
||||
AgentActionEvent,
|
||||
AgentFinishEvent,
|
||||
ToolEndEvent,
|
||||
ToolErrorEvent,
|
||||
LLMStartEvent,
|
||||
ChatModelStartEvent,
|
||||
RetrieverStartEvent,
|
||||
ChainStartEvent,
|
||||
ToolStartEvent,
|
||||
TextEvent,
|
||||
RetryEvent,
|
||||
]
|
||||
|
||||
|
||||
class Handlers(TypedDict, total=False):
|
||||
on_chain_error: Union[Callable[[ChainErrorEvent], Any]]
|
||||
on_chain_start: Union[Callable[[ChainStartEvent], Any]]
|
||||
|
||||
|
||||
def func(inputs: Any, callbacks: Optional[Handlers]):
|
||||
pass
|
||||
|
||||
|
||||
def _on_chain_error(event: ChainErrorEvent):
|
||||
pass
|
||||
|
||||
|
||||
def zoo(inputs: Any):
|
||||
func(inputs, callbacks={"on_chain_error": _on_chain_error})
|
||||
@@ -259,6 +259,11 @@ def handle_event(
|
||||
"""
|
||||
coros: List[Coroutine[Any, Any, Any]] = []
|
||||
|
||||
# if '_event' in kwargs:
|
||||
# event = kwargs['_event']
|
||||
#
|
||||
#
|
||||
|
||||
try:
|
||||
message_strings: Optional[List[str]] = None
|
||||
for handler in handlers:
|
||||
@@ -353,6 +358,10 @@ async def _ahandle_event_for_handler(
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
from langchain_core.callbacks.schema import GenericCallbackHandler
|
||||
if isinstance(handler, GenericCallbackHandler):
|
||||
raise ValueError("here")
|
||||
|
||||
try:
|
||||
if ignore_condition_name is None or not getattr(handler, ignore_condition_name):
|
||||
event = getattr(handler, event_name)
|
||||
@@ -2011,6 +2020,7 @@ def _configure(
|
||||
local_tags: Optional[List[str]] = None,
|
||||
inheritable_metadata: Optional[Dict[str, Any]] = None,
|
||||
local_metadata: Optional[Dict[str, Any]] = None,
|
||||
run_id: Optional[UUID] = None,
|
||||
) -> T:
|
||||
"""Configure the callback manager.
|
||||
|
||||
@@ -2038,10 +2048,11 @@ def _configure(
|
||||
_tracing_v2_is_enabled,
|
||||
tracing_v2_callback_var,
|
||||
)
|
||||
run_id = run_id or uuid.uuid4()
|
||||
|
||||
run_tree = get_run_tree_context()
|
||||
# This can pick up run tree context from trace(), but not traceable?
|
||||
run_tree = get_run_tree_context() # Why is this here?
|
||||
parent_run_id = None if run_tree is None else run_tree.id
|
||||
callback_manager = callback_manager_cls(handlers=[], parent_run_id=parent_run_id)
|
||||
if inheritable_callbacks or local_callbacks:
|
||||
if isinstance(inheritable_callbacks, list) or inheritable_callbacks is None:
|
||||
inheritable_callbacks_ = inheritable_callbacks or []
|
||||
@@ -2049,8 +2060,9 @@ def _configure(
|
||||
handlers=inheritable_callbacks_.copy(),
|
||||
inheritable_handlers=inheritable_callbacks_.copy(),
|
||||
parent_run_id=parent_run_id,
|
||||
run_id=run_id,
|
||||
)
|
||||
else:
|
||||
elif isinstance(inheritable_callbacks, BaseCallbackManager):
|
||||
parent_run_id_ = inheritable_callbacks.parent_run_id
|
||||
# Break ties between the external tracing context and inherited context
|
||||
if parent_run_id is not None:
|
||||
@@ -2071,6 +2083,12 @@ def _configure(
|
||||
inheritable_tags=inheritable_callbacks.inheritable_tags.copy(),
|
||||
metadata=inheritable_callbacks.metadata.copy(),
|
||||
inheritable_metadata=inheritable_callbacks.inheritable_metadata.copy(),
|
||||
run_id=run_id,
|
||||
)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"inheritable_callbacks must be a list or CallbackManager."
|
||||
f"Got {type(inheritable_callbacks)}"
|
||||
)
|
||||
local_handlers_ = (
|
||||
local_callbacks
|
||||
@@ -2079,6 +2097,10 @@ def _configure(
|
||||
)
|
||||
for handler in local_handlers_:
|
||||
callback_manager.add_handler(handler, False)
|
||||
else:
|
||||
callback_manager = callback_manager_cls(
|
||||
handlers=[], parent_run_id=parent_run_id
|
||||
)
|
||||
if inheritable_tags or local_tags:
|
||||
callback_manager.add_tags(inheritable_tags or [])
|
||||
callback_manager.add_tags(local_tags or [], False)
|
||||
@@ -2124,6 +2146,7 @@ def _configure(
|
||||
isinstance(handler, LangChainTracer)
|
||||
for handler in callback_manager.handlers
|
||||
):
|
||||
# Add LangChain tracer if it's not already present in the handlers
|
||||
if tracer_v2:
|
||||
callback_manager.add_handler(tracer_v2, True)
|
||||
else:
|
||||
@@ -2142,12 +2165,15 @@ def _configure(
|
||||
)
|
||||
if run_tree is not None:
|
||||
for handler in callback_manager.handlers:
|
||||
# Looks like a hack here? Why is this needed?
|
||||
if isinstance(handler, LangChainTracer):
|
||||
handler.order_map[run_tree.id] = (
|
||||
run_tree.trace_id,
|
||||
run_tree.dotted_order,
|
||||
)
|
||||
handler.run_map[str(run_tree.id)] = cast(Run, run_tree)
|
||||
|
||||
# Look at callback managers exposed as _configure_hooks
|
||||
for var, inheritable, handler_class, env_var in _configure_hooks:
|
||||
create_one = (
|
||||
env_var is not None
|
||||
|
||||
202
libs/core/langchain_core/callbacks/schema.py
Normal file
202
libs/core/langchain_core/callbacks/schema.py
Normal file
@@ -0,0 +1,202 @@
|
||||
import abc
|
||||
from typing import Generic, TypeVar, Dict, Any, Union
|
||||
from typing import TypedDict, NotRequired, List, Optional, Literal, Set, cast
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.callbacks import CallbackManager, BaseCallbackHandler
|
||||
from langchain_core.callbacks.events import Event, ChainStartEvent, ChainEndEvent, ChainErrorEvent
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
CallbackEvent = Literal[
|
||||
"on_chat_model_start",
|
||||
"on_chat_model_end",
|
||||
"on_chat_model_error",
|
||||
"on_llm_start",
|
||||
"on_llm_end",
|
||||
"on_llm_error",
|
||||
"on_chain_start",
|
||||
"on_chain_end",
|
||||
"on_chain_error",
|
||||
"on_tool_start",
|
||||
"on_tool_end",
|
||||
"on_tool_error",
|
||||
"on_retriever_start",
|
||||
"on_retriever_end",
|
||||
"on_retriever_error",
|
||||
# Where do these come from??!
|
||||
"on_prompt_start",
|
||||
"on_prompt_end",
|
||||
# Streaming events are missing for the most part
|
||||
# "on_chain_stream",
|
||||
"on_llm_new_token", # TODO: This should be updated!
|
||||
]
|
||||
|
||||
|
||||
class BaseCallback(TypedDict):
|
||||
"""Base event."""
|
||||
|
||||
# id: str # id for the callback itself
|
||||
run_id: str # id for the run that generated the callback
|
||||
tags: NotRequired[List[str]]
|
||||
metadata: NotRequired[List[str]]
|
||||
parent_id: NotRequired[Optional[str]]
|
||||
type: CallbackEvent
|
||||
|
||||
|
||||
T = TypeVar("T", bound=BaseCallback)
|
||||
|
||||
|
||||
def _convert_event_to_callback(
|
||||
event: Event,
|
||||
*,
|
||||
run_id: str,
|
||||
tags: Optional[str],
|
||||
metadata: Optional,
|
||||
parent_id: Optional[str],
|
||||
) -> BaseCallback:
|
||||
"""Convert an event to a callback."""
|
||||
return cast(
|
||||
BaseCallback,
|
||||
{
|
||||
"run_id": run_id,
|
||||
"tags": tags,
|
||||
"metadata": metadata,
|
||||
"parent_id": parent_id,
|
||||
**event,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class GenericCallbackHandler(abc.ABC, Generic[T]):
|
||||
# @abc.abstractmethod
|
||||
@property
|
||||
def accepts_events(self) -> Optional[Set[CallbackEvent]]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def handle_callback(self, callback: BaseCallback) -> T:
|
||||
"""Handle an event."""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def ahandle_callback(self, callback: BaseCallback) -> T:
|
||||
"""Handle an event asynchronously."""
|
||||
|
||||
|
||||
# TODO(Eugene): This inherits from a bunch of stuff
|
||||
# for backwards compatibility reasons, but prior to merging
|
||||
# we need to clean a bunch of those stuff.
|
||||
class CallbackDispatcher(CallbackManager):
|
||||
"""Interface to dispatch callbacks to all registered handlers."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
handlers: List[Union[BaseCallbackHandler, GenericCallbackHandler]],
|
||||
inheritable_handlers: Optional[
|
||||
List[Union[BaseCallbackHandler, GenericCallbackHandler]]
|
||||
] = None,
|
||||
parent_run_id: Optional[str],
|
||||
run_id: Optional[str],
|
||||
tags: Optional[List[str]] = None,
|
||||
inheritable_tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
inheritable_metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
handlers=handlers,
|
||||
inheritable_handlers=inheritable_handlers,
|
||||
parent_run_id=parent_run_id,
|
||||
tags=tags,
|
||||
inheritable_tags=inheritable_tags,
|
||||
metadata=metadata,
|
||||
inheritable_metadata=inheritable_metadata,
|
||||
)
|
||||
self.run_id = run_id
|
||||
|
||||
def dispatch_event(self, event: Event) -> None:
|
||||
"""Handle an event."""
|
||||
# Delegate to handle event
|
||||
callback_event = _convert_event_to_callback(
|
||||
event,
|
||||
run_id=self.run_id,
|
||||
parent_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
metadata=self.metadata,
|
||||
)
|
||||
# handle_event(
|
||||
# self.handlers,
|
||||
# event_name=event["type"],
|
||||
# ignore_condition_name=None,
|
||||
# _event=callback_event,
|
||||
# )
|
||||
# if isinstance(handler, GenericCallbackHandler):
|
||||
# if callback_event["type"] not in handler.accepts_events:
|
||||
# continue
|
||||
# handler.handle_callback(callback_with_data)
|
||||
# else:
|
||||
# handler.handle_callback(callback_with_data)
|
||||
#
|
||||
|
||||
async def adispatch_event(self, event: Event) -> None:
|
||||
"""Delegate to the handler"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def on_chain_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
inputs: Union[Dict[str, Any], Any],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Handle a chain start event."""
|
||||
event: ChainStartEvent = {
|
||||
"type": "on_chain_start",
|
||||
"serialized": serialized,
|
||||
"inputs": inputs,
|
||||
"kwargs": kwargs,
|
||||
}
|
||||
self.dispatch_event(event)
|
||||
|
||||
def on_chain_end(
|
||||
self,
|
||||
outputs: Dict[str, Any],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
event: ChainEndEvent = {
|
||||
"type": "on_chain_end",
|
||||
"kwargs": kwargs
|
||||
}
|
||||
self.dispatch_event(event)
|
||||
|
||||
def on_chain_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
event: ChainErrorEvent = {
|
||||
"type": "on_chain_error",
|
||||
"error": error,
|
||||
"kwargs": kwargs
|
||||
}
|
||||
self.dispatch_event(event)
|
||||
|
||||
def get_child(
|
||||
self,
|
||||
*,
|
||||
run_id: Optional[str] = None, # Allow overriding the run_id?
|
||||
) -> "CallbackDispatcher":
|
||||
"""Get a child."""
|
||||
# unpack config and populate stuff from it
|
||||
return CallbackDispatcher(
|
||||
parent_run_id=self.run_id,
|
||||
run_id=run_id,
|
||||
inheritable_handlers=self.inheritable_handlers,
|
||||
handlers=[],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def config(cls, *, config: Optional[RunnableConfig] = None) -> "CallbackDispatcher":
|
||||
"""Configure callbacks."""
|
||||
return cls(parent_run_id=None, run_id=config.run_id)
|
||||
@@ -1577,7 +1577,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
with callbacks. Use this method to implement invoke() in subclasses."""
|
||||
config = ensure_config(config)
|
||||
callback_manager = get_callback_manager_for_config(config)
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
input,
|
||||
run_type=run_type,
|
||||
@@ -1585,7 +1585,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
run_id=config.pop("run_id", None),
|
||||
)
|
||||
try:
|
||||
child_config = patch_config(config, callbacks=run_manager.get_child())
|
||||
child_config = patch_config(config, callbacks=callback_manager.get_child())
|
||||
context = copy_context()
|
||||
context.run(_set_config_context, child_config)
|
||||
output = cast(
|
||||
@@ -1595,15 +1595,15 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
func, # type: ignore[arg-type]
|
||||
input, # type: ignore[arg-type]
|
||||
config,
|
||||
run_manager,
|
||||
callback_manager,
|
||||
**kwargs,
|
||||
),
|
||||
)
|
||||
except BaseException as e:
|
||||
run_manager.on_chain_error(e)
|
||||
callback_manager.on_chain_error(e)
|
||||
raise
|
||||
else:
|
||||
run_manager.on_chain_end(output)
|
||||
callback_manager.on_chain_end(output)
|
||||
return output
|
||||
|
||||
async def _acall_with_config(
|
||||
|
||||
@@ -41,6 +41,7 @@ if TYPE_CHECKING:
|
||||
CallbackManager,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain_core.callbacks.schema import CallbackDispatcher
|
||||
else:
|
||||
# Pydantic validates through typed dicts, but
|
||||
# the callbacks need forward refs updated
|
||||
@@ -421,7 +422,7 @@ def acall_func_with_variable_args(
|
||||
return func(input, **kwargs) # type: ignore[call-arg]
|
||||
|
||||
|
||||
def get_callback_manager_for_config(config: RunnableConfig) -> CallbackManager:
|
||||
def get_callback_manager_for_config(config: RunnableConfig) -> CallbackDispatcher:
|
||||
"""Get a callback manager for a config.
|
||||
|
||||
Args:
|
||||
@@ -430,9 +431,9 @@ def get_callback_manager_for_config(config: RunnableConfig) -> CallbackManager:
|
||||
Returns:
|
||||
CallbackManager: The callback manager.
|
||||
"""
|
||||
from langchain_core.callbacks.manager import CallbackManager
|
||||
from langchain_core.callbacks.schema import CallbackDispatcher
|
||||
|
||||
return CallbackManager.configure(
|
||||
return CallbackDispatcher.configure(
|
||||
inheritable_callbacks=config.get("callbacks"),
|
||||
inheritable_tags=config.get("tags"),
|
||||
inheritable_metadata=config.get("metadata"),
|
||||
|
||||
Reference in New Issue
Block a user