Compare commits

...

4 Commits

Author SHA1 Message Date
Eugene Yurtsev
05c457a994 qxqx 2024-07-08 10:57:29 -04:00
Eugene Yurtsev
956a4489e3 Merge branch 'master' into eugene/tracing_interop2 2024-07-07 14:33:13 -04:00
Eugene Yurtsev
d7f7c96f95 qxqx 2024-07-01 13:52:51 -04:00
Eugene Yurtsev
0384093fcf qxqx 2024-07-01 13:24:10 -04:00
5 changed files with 402 additions and 11 deletions

View File

@@ -0,0 +1,162 @@
from typing import Optional, Dict, Literal, List
from typing import TypedDict, Any, Union, Callable
from tenacity import RetryCallState
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.documents import Document
from langchain_core.messages import BaseMessage
from langchain_core.outputs import GenerationChunk, LLMResult, ChatGenerationChunk
class RetrieverErrorEvent(TypedDict):
type: Literal["on_retriever_error"]
error: BaseException
class RetrieverEndEvent(TypedDict):
type: Literal["on_retriever_end"]
documents: List[Document]
class LLMNewTokenEvent(TypedDict):
type: Literal["on_llm_new_token"]
token: str
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]]
class LLMEndEvent(TypedDict):
type: Literal["on_llm_end"]
response: LLMResult
class LLMErrorEvent(TypedDict):
type: Literal["on_llm_error"]
error: BaseException
class ChainEndEvent(TypedDict):
type: Literal["on_chain_end"]
outputs: Dict[str, Any]
class ChainErrorEvent(TypedDict):
"""Event for a chain error."""
type: Literal["on_chain_error"]
error: BaseException
class AgentActionEvent(TypedDict):
"""Event for an agent action."""
type: Literal["on_agent_action"]
action: AgentAction
class AgentFinishEvent(TypedDict):
"""Event for an agent action."""
type: Literal["on_agent_finish"]
finish: AgentFinish
class ToolEndEvent(TypedDict):
"""Event for a tool end."""
type: Literal["on_tool_end"]
output: Any
class ToolErrorEvent(TypedDict):
"""Event for a tool error."""
type: Literal["on_tool_error"]
error: BaseException
class LLMStartEvent(TypedDict):
type: Literal["on_llm_start"]
serialized: Dict[str, Any]
prompts: List[str]
class ChatModelStartEvent(TypedDict):
type: Literal["on_chat_model_start"]
serialized: Dict[str, Any]
messages: List[List[BaseMessage]]
class AdHocEvent(TypedDict):
"""Ad hoc event."""
type: Literal["on_ad_hoc"]
data: Any
class RetrieverStartEvent(TypedDict):
type: Literal["on_retriever_start"]
serialized: Dict[str, Any]
query: str
class ChainStartEvent(TypedDict):
type: Literal["on_chain_start"]
serialized: Dict[str, Any]
inputs: Dict[str, Any]
class ToolStartEvent(TypedDict):
type: Literal["on_tool_start"]
serialized: Dict[str, Any]
input_str: str
inputs: Optional[Dict[str, Any]]
class TextEvent(TypedDict):
type: Literal["on_text"]
text: str
class RetryEvent(TypedDict):
type: Literal["on_retry"]
retry_state: RetryCallState
# define possible callback events
Event = Union[
RetrieverErrorEvent,
RetrieverEndEvent,
LLMNewTokenEvent,
LLMEndEvent,
LLMErrorEvent,
ChainEndEvent,
ChainErrorEvent,
AgentActionEvent,
AgentFinishEvent,
ToolEndEvent,
ToolErrorEvent,
LLMStartEvent,
ChatModelStartEvent,
RetrieverStartEvent,
ChainStartEvent,
ToolStartEvent,
TextEvent,
RetryEvent,
]
class Handlers(TypedDict, total=False):
on_chain_error: Union[Callable[[ChainErrorEvent], Any]]
on_chain_start: Union[Callable[[ChainStartEvent], Any]]
def func(inputs: Any, callbacks: Optional[Handlers]):
pass
def _on_chain_error(event: ChainErrorEvent):
pass
def zoo(inputs: Any):
func(inputs, callbacks={"on_chain_error": _on_chain_error})

View File

@@ -259,6 +259,11 @@ def handle_event(
"""
coros: List[Coroutine[Any, Any, Any]] = []
# if '_event' in kwargs:
# event = kwargs['_event']
#
#
try:
message_strings: Optional[List[str]] = None
for handler in handlers:
@@ -353,6 +358,10 @@ async def _ahandle_event_for_handler(
*args: Any,
**kwargs: Any,
) -> None:
from langchain_core.callbacks.schema import GenericCallbackHandler
if isinstance(handler, GenericCallbackHandler):
raise ValueError("here")
try:
if ignore_condition_name is None or not getattr(handler, ignore_condition_name):
event = getattr(handler, event_name)
@@ -2011,6 +2020,7 @@ def _configure(
local_tags: Optional[List[str]] = None,
inheritable_metadata: Optional[Dict[str, Any]] = None,
local_metadata: Optional[Dict[str, Any]] = None,
run_id: Optional[UUID] = None,
) -> T:
"""Configure the callback manager.
@@ -2038,10 +2048,11 @@ def _configure(
_tracing_v2_is_enabled,
tracing_v2_callback_var,
)
run_id = run_id or uuid.uuid4()
run_tree = get_run_tree_context()
# This can pick up run tree context from trace(), but not traceable?
run_tree = get_run_tree_context() # Why is this here?
parent_run_id = None if run_tree is None else run_tree.id
callback_manager = callback_manager_cls(handlers=[], parent_run_id=parent_run_id)
if inheritable_callbacks or local_callbacks:
if isinstance(inheritable_callbacks, list) or inheritable_callbacks is None:
inheritable_callbacks_ = inheritable_callbacks or []
@@ -2049,8 +2060,9 @@ def _configure(
handlers=inheritable_callbacks_.copy(),
inheritable_handlers=inheritable_callbacks_.copy(),
parent_run_id=parent_run_id,
run_id=run_id,
)
else:
elif isinstance(inheritable_callbacks, BaseCallbackManager):
parent_run_id_ = inheritable_callbacks.parent_run_id
# Break ties between the external tracing context and inherited context
if parent_run_id is not None:
@@ -2071,6 +2083,12 @@ def _configure(
inheritable_tags=inheritable_callbacks.inheritable_tags.copy(),
metadata=inheritable_callbacks.metadata.copy(),
inheritable_metadata=inheritable_callbacks.inheritable_metadata.copy(),
run_id=run_id,
)
else:
raise TypeError(
f"inheritable_callbacks must be a list or CallbackManager."
f"Got {type(inheritable_callbacks)}"
)
local_handlers_ = (
local_callbacks
@@ -2079,6 +2097,10 @@ def _configure(
)
for handler in local_handlers_:
callback_manager.add_handler(handler, False)
else:
callback_manager = callback_manager_cls(
handlers=[], parent_run_id=parent_run_id
)
if inheritable_tags or local_tags:
callback_manager.add_tags(inheritable_tags or [])
callback_manager.add_tags(local_tags or [], False)
@@ -2124,6 +2146,7 @@ def _configure(
isinstance(handler, LangChainTracer)
for handler in callback_manager.handlers
):
# Add LangChain tracer if it's not already present in the handlers
if tracer_v2:
callback_manager.add_handler(tracer_v2, True)
else:
@@ -2142,12 +2165,15 @@ def _configure(
)
if run_tree is not None:
for handler in callback_manager.handlers:
# Looks like a hack here? Why is this needed?
if isinstance(handler, LangChainTracer):
handler.order_map[run_tree.id] = (
run_tree.trace_id,
run_tree.dotted_order,
)
handler.run_map[str(run_tree.id)] = cast(Run, run_tree)
# Look at callback managers exposed as _configure_hooks
for var, inheritable, handler_class, env_var in _configure_hooks:
create_one = (
env_var is not None

View File

@@ -0,0 +1,202 @@
import abc
from typing import Generic, TypeVar, Dict, Any, Union
from typing import TypedDict, NotRequired, List, Optional, Literal, Set, cast
from uuid import UUID
from langchain_core.callbacks import CallbackManager, BaseCallbackHandler
from langchain_core.callbacks.events import Event, ChainStartEvent, ChainEndEvent, ChainErrorEvent
from langchain_core.runnables import RunnableConfig
CallbackEvent = Literal[
"on_chat_model_start",
"on_chat_model_end",
"on_chat_model_error",
"on_llm_start",
"on_llm_end",
"on_llm_error",
"on_chain_start",
"on_chain_end",
"on_chain_error",
"on_tool_start",
"on_tool_end",
"on_tool_error",
"on_retriever_start",
"on_retriever_end",
"on_retriever_error",
# Where do these come from??!
"on_prompt_start",
"on_prompt_end",
# Streaming events are missing for the most part
# "on_chain_stream",
"on_llm_new_token", # TODO: This should be updated!
]
class BaseCallback(TypedDict):
"""Base event."""
# id: str # id for the callback itself
run_id: str # id for the run that generated the callback
tags: NotRequired[List[str]]
metadata: NotRequired[List[str]]
parent_id: NotRequired[Optional[str]]
type: CallbackEvent
T = TypeVar("T", bound=BaseCallback)
def _convert_event_to_callback(
event: Event,
*,
run_id: str,
tags: Optional[str],
metadata: Optional,
parent_id: Optional[str],
) -> BaseCallback:
"""Convert an event to a callback."""
return cast(
BaseCallback,
{
"run_id": run_id,
"tags": tags,
"metadata": metadata,
"parent_id": parent_id,
**event,
},
)
T = TypeVar("T")
class GenericCallbackHandler(abc.ABC, Generic[T]):
# @abc.abstractmethod
@property
def accepts_events(self) -> Optional[Set[CallbackEvent]]:
raise NotImplementedError()
@abc.abstractmethod
def handle_callback(self, callback: BaseCallback) -> T:
"""Handle an event."""
@abc.abstractmethod
async def ahandle_callback(self, callback: BaseCallback) -> T:
"""Handle an event asynchronously."""
# TODO(Eugene): This inherits from a bunch of stuff
# for backwards compatibility reasons, but prior to merging
# we need to clean a bunch of those stuff.
class CallbackDispatcher(CallbackManager):
"""Interface to dispatch callbacks to all registered handlers."""
def __init__(
self,
*,
handlers: List[Union[BaseCallbackHandler, GenericCallbackHandler]],
inheritable_handlers: Optional[
List[Union[BaseCallbackHandler, GenericCallbackHandler]]
] = None,
parent_run_id: Optional[str],
run_id: Optional[str],
tags: Optional[List[str]] = None,
inheritable_tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
inheritable_metadata: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__(
handlers=handlers,
inheritable_handlers=inheritable_handlers,
parent_run_id=parent_run_id,
tags=tags,
inheritable_tags=inheritable_tags,
metadata=metadata,
inheritable_metadata=inheritable_metadata,
)
self.run_id = run_id
def dispatch_event(self, event: Event) -> None:
"""Handle an event."""
# Delegate to handle event
callback_event = _convert_event_to_callback(
event,
run_id=self.run_id,
parent_id=self.parent_run_id,
tags=self.tags,
metadata=self.metadata,
)
# handle_event(
# self.handlers,
# event_name=event["type"],
# ignore_condition_name=None,
# _event=callback_event,
# )
# if isinstance(handler, GenericCallbackHandler):
# if callback_event["type"] not in handler.accepts_events:
# continue
# handler.handle_callback(callback_with_data)
# else:
# handler.handle_callback(callback_with_data)
#
async def adispatch_event(self, event: Event) -> None:
"""Delegate to the handler"""
raise NotImplementedError()
def on_chain_start(
self,
serialized: Dict[str, Any],
inputs: Union[Dict[str, Any], Any],
**kwargs: Any,
) -> None:
"""Handle a chain start event."""
event: ChainStartEvent = {
"type": "on_chain_start",
"serialized": serialized,
"inputs": inputs,
"kwargs": kwargs,
}
self.dispatch_event(event)
def on_chain_end(
self,
outputs: Dict[str, Any],
**kwargs: Any,
) -> None:
event: ChainEndEvent = {
"type": "on_chain_end",
"kwargs": kwargs
}
self.dispatch_event(event)
def on_chain_error(
self,
error: BaseException,
**kwargs: Any,
) -> None:
event: ChainErrorEvent = {
"type": "on_chain_error",
"error": error,
"kwargs": kwargs
}
self.dispatch_event(event)
def get_child(
self,
*,
run_id: Optional[str] = None, # Allow overriding the run_id?
) -> "CallbackDispatcher":
"""Get a child."""
# unpack config and populate stuff from it
return CallbackDispatcher(
parent_run_id=self.run_id,
run_id=run_id,
inheritable_handlers=self.inheritable_handlers,
handlers=[],
)
@classmethod
def config(cls, *, config: Optional[RunnableConfig] = None) -> "CallbackDispatcher":
"""Configure callbacks."""
return cls(parent_run_id=None, run_id=config.run_id)

View File

@@ -1577,7 +1577,7 @@ class Runnable(Generic[Input, Output], ABC):
with callbacks. Use this method to implement invoke() in subclasses."""
config = ensure_config(config)
callback_manager = get_callback_manager_for_config(config)
run_manager = callback_manager.on_chain_start(
callback_manager.on_chain_start(
dumpd(self),
input,
run_type=run_type,
@@ -1585,7 +1585,7 @@ class Runnable(Generic[Input, Output], ABC):
run_id=config.pop("run_id", None),
)
try:
child_config = patch_config(config, callbacks=run_manager.get_child())
child_config = patch_config(config, callbacks=callback_manager.get_child())
context = copy_context()
context.run(_set_config_context, child_config)
output = cast(
@@ -1595,15 +1595,15 @@ class Runnable(Generic[Input, Output], ABC):
func, # type: ignore[arg-type]
input, # type: ignore[arg-type]
config,
run_manager,
callback_manager,
**kwargs,
),
)
except BaseException as e:
run_manager.on_chain_error(e)
callback_manager.on_chain_error(e)
raise
else:
run_manager.on_chain_end(output)
callback_manager.on_chain_end(output)
return output
async def _acall_with_config(

View File

@@ -41,6 +41,7 @@ if TYPE_CHECKING:
CallbackManager,
CallbackManagerForChainRun,
)
from langchain_core.callbacks.schema import CallbackDispatcher
else:
# Pydantic validates through typed dicts, but
# the callbacks need forward refs updated
@@ -421,7 +422,7 @@ def acall_func_with_variable_args(
return func(input, **kwargs) # type: ignore[call-arg]
def get_callback_manager_for_config(config: RunnableConfig) -> CallbackManager:
def get_callback_manager_for_config(config: RunnableConfig) -> CallbackDispatcher:
"""Get a callback manager for a config.
Args:
@@ -430,9 +431,9 @@ def get_callback_manager_for_config(config: RunnableConfig) -> CallbackManager:
Returns:
CallbackManager: The callback manager.
"""
from langchain_core.callbacks.manager import CallbackManager
from langchain_core.callbacks.schema import CallbackDispatcher
return CallbackManager.configure(
return CallbackDispatcher.configure(
inheritable_callbacks=config.get("callbacks"),
inheritable_tags=config.get("tags"),
inheritable_metadata=config.get("metadata"),