mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-28 09:28:48 +00:00
Add support for tags in chain group context manager (#6668)
Lets you specify local and inheritable tags in the group manager. Also, add more verbose docstrings for our reference docs.
This commit is contained in:
parent
d1bcc58beb
commit
9ca3b4645e
@ -74,7 +74,16 @@ def _get_debug() -> bool:
|
|||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def get_openai_callback() -> Generator[OpenAICallbackHandler, None, None]:
|
def get_openai_callback() -> Generator[OpenAICallbackHandler, None, None]:
|
||||||
"""Get OpenAI callback handler in a context manager."""
|
"""Get the OpenAI callback handler in a context manager.
|
||||||
|
which conveniently exposes token and cost information.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
OpenAICallbackHandler: The OpenAI callback handler.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> with get_openai_callback() as cb:
|
||||||
|
... # Use the OpenAI callback handler
|
||||||
|
"""
|
||||||
cb = OpenAICallbackHandler()
|
cb = OpenAICallbackHandler()
|
||||||
openai_callback_var.set(cb)
|
openai_callback_var.set(cb)
|
||||||
yield cb
|
yield cb
|
||||||
@ -85,7 +94,19 @@ def get_openai_callback() -> Generator[OpenAICallbackHandler, None, None]:
|
|||||||
def tracing_enabled(
|
def tracing_enabled(
|
||||||
session_name: str = "default",
|
session_name: str = "default",
|
||||||
) -> Generator[TracerSessionV1, None, None]:
|
) -> Generator[TracerSessionV1, None, None]:
|
||||||
"""Get Tracer in a context manager."""
|
"""Get the Deprecated LangChainTracer in a context manager.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_name (str, optional): The name of the session.
|
||||||
|
Defaults to "default".
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TracerSessionV1: The LangChainTracer session.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> with tracing_enabled() as session:
|
||||||
|
... # Use the LangChainTracer session
|
||||||
|
"""
|
||||||
cb = LangChainTracerV1()
|
cb = LangChainTracerV1()
|
||||||
session = cast(TracerSessionV1, cb.load_session(session_name))
|
session = cast(TracerSessionV1, cb.load_session(session_name))
|
||||||
tracing_callback_var.set(cb)
|
tracing_callback_var.set(cb)
|
||||||
@ -97,7 +118,19 @@ def tracing_enabled(
|
|||||||
def wandb_tracing_enabled(
|
def wandb_tracing_enabled(
|
||||||
session_name: str = "default",
|
session_name: str = "default",
|
||||||
) -> Generator[None, None, None]:
|
) -> Generator[None, None, None]:
|
||||||
"""Get WandbTracer in a context manager."""
|
"""Get the WandbTracer in a context manager.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_name (str, optional): The name of the session.
|
||||||
|
Defaults to "default".
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> with wandb_tracing_enabled() as session:
|
||||||
|
... # Use the WandbTracer session
|
||||||
|
"""
|
||||||
cb = WandbTracer()
|
cb = WandbTracer()
|
||||||
wandb_tracing_callback_var.set(cb)
|
wandb_tracing_callback_var.set(cb)
|
||||||
yield None
|
yield None
|
||||||
@ -110,7 +143,21 @@ def tracing_v2_enabled(
|
|||||||
*,
|
*,
|
||||||
example_id: Optional[Union[str, UUID]] = None,
|
example_id: Optional[Union[str, UUID]] = None,
|
||||||
) -> Generator[None, None, None]:
|
) -> Generator[None, None, None]:
|
||||||
"""Get the experimental tracer handler in a context manager."""
|
"""Instruct LangChain to log all runs in context to LangSmith.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_name (str, optional): The name of the project.
|
||||||
|
Defaults to "default".
|
||||||
|
example_id (str or UUID, optional): The ID of the example.
|
||||||
|
Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> with tracing_v2_enabled():
|
||||||
|
... # LangChain code will automatically be traced
|
||||||
|
"""
|
||||||
# Issue a warning that this is experimental
|
# Issue a warning that this is experimental
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"The tracing v2 API is in development. "
|
"The tracing v2 API is in development. "
|
||||||
@ -133,14 +180,36 @@ def trace_as_chain_group(
|
|||||||
*,
|
*,
|
||||||
project_name: Optional[str] = None,
|
project_name: Optional[str] = None,
|
||||||
example_id: Optional[Union[str, UUID]] = None,
|
example_id: Optional[Union[str, UUID]] = None,
|
||||||
|
tags: Optional[List[str]] = None,
|
||||||
) -> Generator[CallbackManager, None, None]:
|
) -> Generator[CallbackManager, None, None]:
|
||||||
"""Get a callback manager for a chain group in a context manager."""
|
"""Get a callback manager for a chain group in a context manager.
|
||||||
|
Useful for grouping different calls together as a single run even if
|
||||||
|
they aren't composed in a single chain.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
group_name (str): The name of the chain group.
|
||||||
|
project_name (str, optional): The name of the project.
|
||||||
|
Defaults to None.
|
||||||
|
example_id (str or UUID, optional): The ID of the example.
|
||||||
|
Defaults to None.
|
||||||
|
tags (List[str], optional): The inheritable tags to apply to all runs.
|
||||||
|
Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CallbackManager: The callback manager for the chain group.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> with trace_as_chain_group("group_name") as manager:
|
||||||
|
... # Use the callback manager for the chain group
|
||||||
|
... llm.predict("Foo", callbacks=manager)
|
||||||
|
"""
|
||||||
cb = LangChainTracer(
|
cb = LangChainTracer(
|
||||||
project_name=project_name,
|
project_name=project_name,
|
||||||
example_id=example_id,
|
example_id=example_id,
|
||||||
)
|
)
|
||||||
cm = CallbackManager.configure(
|
cm = CallbackManager.configure(
|
||||||
inheritable_callbacks=[cb],
|
inheritable_callbacks=[cb],
|
||||||
|
inheritable_tags=tags,
|
||||||
)
|
)
|
||||||
|
|
||||||
run_manager = cm.on_chain_start({"name": group_name}, {})
|
run_manager = cm.on_chain_start({"name": group_name}, {})
|
||||||
@ -154,14 +223,34 @@ async def atrace_as_chain_group(
|
|||||||
*,
|
*,
|
||||||
project_name: Optional[str] = None,
|
project_name: Optional[str] = None,
|
||||||
example_id: Optional[Union[str, UUID]] = None,
|
example_id: Optional[Union[str, UUID]] = None,
|
||||||
|
tags: Optional[List[str]] = None,
|
||||||
) -> AsyncGenerator[AsyncCallbackManager, None]:
|
) -> AsyncGenerator[AsyncCallbackManager, None]:
|
||||||
"""Get a callback manager for a chain group in a context manager."""
|
"""Get an async callback manager for a chain group in a context manager.
|
||||||
|
Useful for grouping different async calls together as a single run even if
|
||||||
|
they aren't composed in a single chain.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
group_name (str): The name of the chain group.
|
||||||
|
project_name (str, optional): The name of the project.
|
||||||
|
Defaults to None.
|
||||||
|
example_id (str or UUID, optional): The ID of the example.
|
||||||
|
Defaults to None.
|
||||||
|
tags (List[str], optional): The inheritable tags to apply to all runs.
|
||||||
|
Defaults to None.
|
||||||
|
Returns:
|
||||||
|
AsyncCallbackManager: The async callback manager for the chain group.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> async with atrace_as_chain_group("group_name") as manager:
|
||||||
|
... # Use the async callback manager for the chain group
|
||||||
|
... await llm.apredict("Foo", callbacks=manager)
|
||||||
|
"""
|
||||||
cb = LangChainTracer(
|
cb = LangChainTracer(
|
||||||
project_name=project_name,
|
project_name=project_name,
|
||||||
example_id=example_id,
|
example_id=example_id,
|
||||||
)
|
)
|
||||||
cm = AsyncCallbackManager.configure(
|
cm = AsyncCallbackManager.configure(
|
||||||
inheritable_callbacks=[cb],
|
inheritable_callbacks=[cb], inheritable_tags=tags
|
||||||
)
|
)
|
||||||
|
|
||||||
run_manager = await cm.on_chain_start({"name": group_name}, {})
|
run_manager = await cm.on_chain_start({"name": group_name}, {})
|
||||||
@ -293,7 +382,18 @@ class BaseRunManager(RunManagerMixin):
|
|||||||
tags: List[str],
|
tags: List[str],
|
||||||
inheritable_tags: List[str],
|
inheritable_tags: List[str],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize run manager."""
|
"""Initialize the run manager.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
run_id (UUID): The ID of the run.
|
||||||
|
handlers (List[BaseCallbackHandler]): The list of handlers.
|
||||||
|
inheritable_handlers (List[BaseCallbackHandler]):
|
||||||
|
The list of inheritable handlers.
|
||||||
|
parent_run_id (UUID, optional): The ID of the parent run.
|
||||||
|
Defaults to None.
|
||||||
|
tags (List[str]): The list of tags.
|
||||||
|
inheritable_tags (List[str]): The list of inheritable tags.
|
||||||
|
"""
|
||||||
self.run_id = run_id
|
self.run_id = run_id
|
||||||
self.handlers = handlers
|
self.handlers = handlers
|
||||||
self.inheritable_handlers = inheritable_handlers
|
self.inheritable_handlers = inheritable_handlers
|
||||||
@ -303,7 +403,11 @@ class BaseRunManager(RunManagerMixin):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_noop_manager(cls: Type[BRM]) -> BRM:
|
def get_noop_manager(cls: Type[BRM]) -> BRM:
|
||||||
"""Return a manager that doesn't perform any operations."""
|
"""Return a manager that doesn't perform any operations.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
BaseRunManager: The noop manager.
|
||||||
|
"""
|
||||||
return cls(
|
return cls(
|
||||||
run_id=uuid4(),
|
run_id=uuid4(),
|
||||||
handlers=[],
|
handlers=[],
|
||||||
@ -321,7 +425,14 @@ class RunManager(BaseRunManager):
|
|||||||
text: str,
|
text: str,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Run when text is received."""
|
"""Run when text is received.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text (str): The received text.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Any: The result of the callback.
|
||||||
|
"""
|
||||||
_handle_event(
|
_handle_event(
|
||||||
self.handlers,
|
self.handlers,
|
||||||
"on_text",
|
"on_text",
|
||||||
@ -341,7 +452,14 @@ class AsyncRunManager(BaseRunManager):
|
|||||||
text: str,
|
text: str,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Run when text is received."""
|
"""Run when text is received.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text (str): The received text.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Any: The result of the callback.
|
||||||
|
"""
|
||||||
await _ahandle_event(
|
await _ahandle_event(
|
||||||
self.handlers,
|
self.handlers,
|
||||||
"on_text",
|
"on_text",
|
||||||
@ -361,7 +479,11 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
|
|||||||
token: str,
|
token: str,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run when LLM generates a new token."""
|
"""Run when LLM generates a new token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token (str): The new token.
|
||||||
|
"""
|
||||||
_handle_event(
|
_handle_event(
|
||||||
self.handlers,
|
self.handlers,
|
||||||
"on_llm_new_token",
|
"on_llm_new_token",
|
||||||
@ -373,7 +495,11 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||||
"""Run when LLM ends running."""
|
"""Run when LLM ends running.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
response (LLMResult): The LLM result.
|
||||||
|
"""
|
||||||
_handle_event(
|
_handle_event(
|
||||||
self.handlers,
|
self.handlers,
|
||||||
"on_llm_end",
|
"on_llm_end",
|
||||||
@ -389,7 +515,11 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
|
|||||||
error: Union[Exception, KeyboardInterrupt],
|
error: Union[Exception, KeyboardInterrupt],
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run when LLM errors."""
|
"""Run when LLM errors.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
error (Exception or KeyboardInterrupt): The error.
|
||||||
|
"""
|
||||||
_handle_event(
|
_handle_event(
|
||||||
self.handlers,
|
self.handlers,
|
||||||
"on_llm_error",
|
"on_llm_error",
|
||||||
@ -409,7 +539,11 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
|
|||||||
token: str,
|
token: str,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run when LLM generates a new token."""
|
"""Run when LLM generates a new token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token (str): The new token.
|
||||||
|
"""
|
||||||
await _ahandle_event(
|
await _ahandle_event(
|
||||||
self.handlers,
|
self.handlers,
|
||||||
"on_llm_new_token",
|
"on_llm_new_token",
|
||||||
@ -421,7 +555,11 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||||
"""Run when LLM ends running."""
|
"""Run when LLM ends running.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
response (LLMResult): The LLM result.
|
||||||
|
"""
|
||||||
await _ahandle_event(
|
await _ahandle_event(
|
||||||
self.handlers,
|
self.handlers,
|
||||||
"on_llm_end",
|
"on_llm_end",
|
||||||
@ -437,7 +575,11 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
|
|||||||
error: Union[Exception, KeyboardInterrupt],
|
error: Union[Exception, KeyboardInterrupt],
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run when LLM errors."""
|
"""Run when LLM errors.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
error (Exception or KeyboardInterrupt): The error.
|
||||||
|
"""
|
||||||
await _ahandle_event(
|
await _ahandle_event(
|
||||||
self.handlers,
|
self.handlers,
|
||||||
"on_llm_error",
|
"on_llm_error",
|
||||||
@ -453,7 +595,15 @@ class CallbackManagerForChainRun(RunManager, ChainManagerMixin):
|
|||||||
"""Callback manager for chain run."""
|
"""Callback manager for chain run."""
|
||||||
|
|
||||||
def get_child(self, tag: Optional[str] = None) -> CallbackManager:
|
def get_child(self, tag: Optional[str] = None) -> CallbackManager:
|
||||||
"""Get a child callback manager."""
|
"""Get a child callback manager.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tag (str, optional): The tag for the child callback manager.
|
||||||
|
Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CallbackManager: The child callback manager.
|
||||||
|
"""
|
||||||
manager = CallbackManager(handlers=[], parent_run_id=self.run_id)
|
manager = CallbackManager(handlers=[], parent_run_id=self.run_id)
|
||||||
manager.set_handlers(self.inheritable_handlers)
|
manager.set_handlers(self.inheritable_handlers)
|
||||||
manager.add_tags(self.inheritable_tags)
|
manager.add_tags(self.inheritable_tags)
|
||||||
@ -462,7 +612,11 @@ class CallbackManagerForChainRun(RunManager, ChainManagerMixin):
|
|||||||
return manager
|
return manager
|
||||||
|
|
||||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||||
"""Run when chain ends running."""
|
"""Run when chain ends running.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
outputs (Dict[str, Any]): The outputs of the chain.
|
||||||
|
"""
|
||||||
_handle_event(
|
_handle_event(
|
||||||
self.handlers,
|
self.handlers,
|
||||||
"on_chain_end",
|
"on_chain_end",
|
||||||
@ -478,7 +632,11 @@ class CallbackManagerForChainRun(RunManager, ChainManagerMixin):
|
|||||||
error: Union[Exception, KeyboardInterrupt],
|
error: Union[Exception, KeyboardInterrupt],
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run when chain errors."""
|
"""Run when chain errors.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
error (Exception or KeyboardInterrupt): The error.
|
||||||
|
"""
|
||||||
_handle_event(
|
_handle_event(
|
||||||
self.handlers,
|
self.handlers,
|
||||||
"on_chain_error",
|
"on_chain_error",
|
||||||
@ -490,7 +648,14 @@ class CallbackManagerForChainRun(RunManager, ChainManagerMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
|
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
|
||||||
"""Run when agent action is received."""
|
"""Run when agent action is received.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
action (AgentAction): The agent action.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Any: The result of the callback.
|
||||||
|
"""
|
||||||
_handle_event(
|
_handle_event(
|
||||||
self.handlers,
|
self.handlers,
|
||||||
"on_agent_action",
|
"on_agent_action",
|
||||||
@ -502,7 +667,14 @@ class CallbackManagerForChainRun(RunManager, ChainManagerMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
|
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
|
||||||
"""Run when agent finish is received."""
|
"""Run when agent finish is received.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
finish (AgentFinish): The agent finish.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Any: The result of the callback.
|
||||||
|
"""
|
||||||
_handle_event(
|
_handle_event(
|
||||||
self.handlers,
|
self.handlers,
|
||||||
"on_agent_finish",
|
"on_agent_finish",
|
||||||
@ -518,7 +690,15 @@ class AsyncCallbackManagerForChainRun(AsyncRunManager, ChainManagerMixin):
|
|||||||
"""Async callback manager for chain run."""
|
"""Async callback manager for chain run."""
|
||||||
|
|
||||||
def get_child(self, tag: Optional[str] = None) -> AsyncCallbackManager:
|
def get_child(self, tag: Optional[str] = None) -> AsyncCallbackManager:
|
||||||
"""Get a child callback manager."""
|
"""Get a child callback manager.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tag (str, optional): The tag for the child callback manager.
|
||||||
|
Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AsyncCallbackManager: The child callback manager.
|
||||||
|
"""
|
||||||
manager = AsyncCallbackManager(handlers=[], parent_run_id=self.run_id)
|
manager = AsyncCallbackManager(handlers=[], parent_run_id=self.run_id)
|
||||||
manager.set_handlers(self.inheritable_handlers)
|
manager.set_handlers(self.inheritable_handlers)
|
||||||
manager.add_tags(self.inheritable_tags)
|
manager.add_tags(self.inheritable_tags)
|
||||||
@ -527,7 +707,11 @@ class AsyncCallbackManagerForChainRun(AsyncRunManager, ChainManagerMixin):
|
|||||||
return manager
|
return manager
|
||||||
|
|
||||||
async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||||
"""Run when chain ends running."""
|
"""Run when chain ends running.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
outputs (Dict[str, Any]): The outputs of the chain.
|
||||||
|
"""
|
||||||
await _ahandle_event(
|
await _ahandle_event(
|
||||||
self.handlers,
|
self.handlers,
|
||||||
"on_chain_end",
|
"on_chain_end",
|
||||||
@ -543,7 +727,11 @@ class AsyncCallbackManagerForChainRun(AsyncRunManager, ChainManagerMixin):
|
|||||||
error: Union[Exception, KeyboardInterrupt],
|
error: Union[Exception, KeyboardInterrupt],
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run when chain errors."""
|
"""Run when chain errors.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
error (Exception or KeyboardInterrupt): The error.
|
||||||
|
"""
|
||||||
await _ahandle_event(
|
await _ahandle_event(
|
||||||
self.handlers,
|
self.handlers,
|
||||||
"on_chain_error",
|
"on_chain_error",
|
||||||
@ -555,7 +743,14 @@ class AsyncCallbackManagerForChainRun(AsyncRunManager, ChainManagerMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
|
async def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
|
||||||
"""Run when agent action is received."""
|
"""Run when agent action is received.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
action (AgentAction): The agent action.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Any: The result of the callback.
|
||||||
|
"""
|
||||||
await _ahandle_event(
|
await _ahandle_event(
|
||||||
self.handlers,
|
self.handlers,
|
||||||
"on_agent_action",
|
"on_agent_action",
|
||||||
@ -567,7 +762,14 @@ class AsyncCallbackManagerForChainRun(AsyncRunManager, ChainManagerMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
|
async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
|
||||||
"""Run when agent finish is received."""
|
"""Run when agent finish is received.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
finish (AgentFinish): The agent finish.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Any: The result of the callback.
|
||||||
|
"""
|
||||||
await _ahandle_event(
|
await _ahandle_event(
|
||||||
self.handlers,
|
self.handlers,
|
||||||
"on_agent_finish",
|
"on_agent_finish",
|
||||||
@ -583,7 +785,15 @@ class CallbackManagerForToolRun(RunManager, ToolManagerMixin):
|
|||||||
"""Callback manager for tool run."""
|
"""Callback manager for tool run."""
|
||||||
|
|
||||||
def get_child(self, tag: Optional[str] = None) -> CallbackManager:
|
def get_child(self, tag: Optional[str] = None) -> CallbackManager:
|
||||||
"""Get a child callback manager."""
|
"""Get a child callback manager.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tag (str, optional): The tag for the child callback manager.
|
||||||
|
Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CallbackManager: The child callback manager.
|
||||||
|
"""
|
||||||
manager = CallbackManager(handlers=[], parent_run_id=self.run_id)
|
manager = CallbackManager(handlers=[], parent_run_id=self.run_id)
|
||||||
manager.set_handlers(self.inheritable_handlers)
|
manager.set_handlers(self.inheritable_handlers)
|
||||||
manager.add_tags(self.inheritable_tags)
|
manager.add_tags(self.inheritable_tags)
|
||||||
@ -596,7 +806,11 @@ class CallbackManagerForToolRun(RunManager, ToolManagerMixin):
|
|||||||
output: str,
|
output: str,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run when tool ends running."""
|
"""Run when tool ends running.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output (str): The output of the tool.
|
||||||
|
"""
|
||||||
_handle_event(
|
_handle_event(
|
||||||
self.handlers,
|
self.handlers,
|
||||||
"on_tool_end",
|
"on_tool_end",
|
||||||
@ -612,7 +826,11 @@ class CallbackManagerForToolRun(RunManager, ToolManagerMixin):
|
|||||||
error: Union[Exception, KeyboardInterrupt],
|
error: Union[Exception, KeyboardInterrupt],
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run when tool errors."""
|
"""Run when tool errors.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
error (Exception or KeyboardInterrupt): The error.
|
||||||
|
"""
|
||||||
_handle_event(
|
_handle_event(
|
||||||
self.handlers,
|
self.handlers,
|
||||||
"on_tool_error",
|
"on_tool_error",
|
||||||
@ -628,7 +846,15 @@ class AsyncCallbackManagerForToolRun(AsyncRunManager, ToolManagerMixin):
|
|||||||
"""Async callback manager for tool run."""
|
"""Async callback manager for tool run."""
|
||||||
|
|
||||||
def get_child(self, tag: Optional[str] = None) -> AsyncCallbackManager:
|
def get_child(self, tag: Optional[str] = None) -> AsyncCallbackManager:
|
||||||
"""Get a child callback manager."""
|
"""Get a child callback manager.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tag (str, optional): The tag to add to the child
|
||||||
|
callback manager. Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AsyncCallbackManager: The child callback manager.
|
||||||
|
"""
|
||||||
manager = AsyncCallbackManager(handlers=[], parent_run_id=self.run_id)
|
manager = AsyncCallbackManager(handlers=[], parent_run_id=self.run_id)
|
||||||
manager.set_handlers(self.inheritable_handlers)
|
manager.set_handlers(self.inheritable_handlers)
|
||||||
manager.add_tags(self.inheritable_tags)
|
manager.add_tags(self.inheritable_tags)
|
||||||
@ -637,7 +863,11 @@ class AsyncCallbackManagerForToolRun(AsyncRunManager, ToolManagerMixin):
|
|||||||
return manager
|
return manager
|
||||||
|
|
||||||
async def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
async def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||||
"""Run when tool ends running."""
|
"""Run when tool ends running.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output (str): The output of the tool.
|
||||||
|
"""
|
||||||
await _ahandle_event(
|
await _ahandle_event(
|
||||||
self.handlers,
|
self.handlers,
|
||||||
"on_tool_end",
|
"on_tool_end",
|
||||||
@ -653,7 +883,11 @@ class AsyncCallbackManagerForToolRun(AsyncRunManager, ToolManagerMixin):
|
|||||||
error: Union[Exception, KeyboardInterrupt],
|
error: Union[Exception, KeyboardInterrupt],
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run when tool errors."""
|
"""Run when tool errors.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
error (Exception or KeyboardInterrupt): The error.
|
||||||
|
"""
|
||||||
await _ahandle_event(
|
await _ahandle_event(
|
||||||
self.handlers,
|
self.handlers,
|
||||||
"on_tool_error",
|
"on_tool_error",
|
||||||
@ -674,7 +908,17 @@ class CallbackManager(BaseCallbackManager):
|
|||||||
prompts: List[str],
|
prompts: List[str],
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> List[CallbackManagerForLLMRun]:
|
) -> List[CallbackManagerForLLMRun]:
|
||||||
"""Run when LLM starts running."""
|
"""Run when LLM starts running.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
serialized (Dict[str, Any]): The serialized LLM.
|
||||||
|
prompts (List[str]): The list of prompts.
|
||||||
|
run_id (UUID, optional): The ID of the run. Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[CallbackManagerForLLMRun]: A callback manager for each
|
||||||
|
prompt as an LLM run.
|
||||||
|
"""
|
||||||
managers = []
|
managers = []
|
||||||
for prompt in prompts:
|
for prompt in prompts:
|
||||||
run_id_ = uuid4()
|
run_id_ = uuid4()
|
||||||
@ -709,7 +953,17 @@ class CallbackManager(BaseCallbackManager):
|
|||||||
messages: List[List[BaseMessage]],
|
messages: List[List[BaseMessage]],
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> List[CallbackManagerForLLMRun]:
|
) -> List[CallbackManagerForLLMRun]:
|
||||||
"""Run when LLM starts running."""
|
"""Run when LLM starts running.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
serialized (Dict[str, Any]): The serialized LLM.
|
||||||
|
messages (List[List[BaseMessage]]): The list of messages.
|
||||||
|
run_id (UUID, optional): The ID of the run. Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[CallbackManagerForLLMRun]: A callback manager for each
|
||||||
|
list of messages as an LLM run.
|
||||||
|
"""
|
||||||
|
|
||||||
managers = []
|
managers = []
|
||||||
for message_list in messages:
|
for message_list in messages:
|
||||||
@ -746,7 +1000,16 @@ class CallbackManager(BaseCallbackManager):
|
|||||||
run_id: Optional[UUID] = None,
|
run_id: Optional[UUID] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> CallbackManagerForChainRun:
|
) -> CallbackManagerForChainRun:
|
||||||
"""Run when chain starts running."""
|
"""Run when chain starts running.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
serialized (Dict[str, Any]): The serialized chain.
|
||||||
|
inputs (Dict[str, Any]): The inputs to the chain.
|
||||||
|
run_id (UUID, optional): The ID of the run. Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CallbackManagerForChainRun: The callback manager for the chain run.
|
||||||
|
"""
|
||||||
if run_id is None:
|
if run_id is None:
|
||||||
run_id = uuid4()
|
run_id = uuid4()
|
||||||
|
|
||||||
@ -779,7 +1042,17 @@ class CallbackManager(BaseCallbackManager):
|
|||||||
parent_run_id: Optional[UUID] = None,
|
parent_run_id: Optional[UUID] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> CallbackManagerForToolRun:
|
) -> CallbackManagerForToolRun:
|
||||||
"""Run when tool starts running."""
|
"""Run when tool starts running.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
serialized (Dict[str, Any]): The serialized tool.
|
||||||
|
input_str (str): The input to the tool.
|
||||||
|
run_id (UUID, optional): The ID of the run. Defaults to None.
|
||||||
|
parent_run_id (UUID, optional): The ID of the parent run. Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CallbackManagerForToolRun: The callback manager for the tool run.
|
||||||
|
"""
|
||||||
if run_id is None:
|
if run_id is None:
|
||||||
run_id = uuid4()
|
run_id = uuid4()
|
||||||
|
|
||||||
@ -813,7 +1086,22 @@ class CallbackManager(BaseCallbackManager):
|
|||||||
inheritable_tags: Optional[List[str]] = None,
|
inheritable_tags: Optional[List[str]] = None,
|
||||||
local_tags: Optional[List[str]] = None,
|
local_tags: Optional[List[str]] = None,
|
||||||
) -> CallbackManager:
|
) -> CallbackManager:
|
||||||
"""Configure the callback manager."""
|
"""Configure the callback manager.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inheritable_callbacks (Optional[Callbacks], optional): The inheritable
|
||||||
|
callbacks. Defaults to None.
|
||||||
|
local_callbacks (Optional[Callbacks], optional): The local callbacks.
|
||||||
|
Defaults to None.
|
||||||
|
verbose (bool, optional): Whether to enable verbose mode. Defaults to False.
|
||||||
|
inheritable_tags (Optional[List[str]], optional): The inheritable tags.
|
||||||
|
Defaults to None.
|
||||||
|
local_tags (Optional[List[str]], optional): The local tags.
|
||||||
|
Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CallbackManager: The configured callback manager.
|
||||||
|
"""
|
||||||
return _configure(
|
return _configure(
|
||||||
cls,
|
cls,
|
||||||
inheritable_callbacks,
|
inheritable_callbacks,
|
||||||
@ -838,7 +1126,18 @@ class AsyncCallbackManager(BaseCallbackManager):
|
|||||||
prompts: List[str],
|
prompts: List[str],
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> List[AsyncCallbackManagerForLLMRun]:
|
) -> List[AsyncCallbackManagerForLLMRun]:
|
||||||
"""Run when LLM starts running."""
|
"""Run when LLM starts running.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
serialized (Dict[str, Any]): The serialized LLM.
|
||||||
|
prompts (List[str]): The list of prompts.
|
||||||
|
run_id (UUID, optional): The ID of the run. Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[AsyncCallbackManagerForLLMRun]: The list of async
|
||||||
|
callback managers, one for each LLM Run corresponding
|
||||||
|
to each prompt.
|
||||||
|
"""
|
||||||
|
|
||||||
tasks = []
|
tasks = []
|
||||||
managers = []
|
managers = []
|
||||||
@ -881,6 +1180,18 @@ class AsyncCallbackManager(BaseCallbackManager):
|
|||||||
messages: List[List[BaseMessage]],
|
messages: List[List[BaseMessage]],
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
|
"""Run when LLM starts running.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
serialized (Dict[str, Any]): The serialized LLM.
|
||||||
|
messages (List[List[BaseMessage]]): The list of messages.
|
||||||
|
run_id (UUID, optional): The ID of the run. Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[AsyncCallbackManagerForLLMRun]: The list of
|
||||||
|
async callback managers, one for each LLM Run
|
||||||
|
corresponding to each inner message list.
|
||||||
|
"""
|
||||||
tasks = []
|
tasks = []
|
||||||
managers = []
|
managers = []
|
||||||
|
|
||||||
@ -922,7 +1233,17 @@ class AsyncCallbackManager(BaseCallbackManager):
|
|||||||
run_id: Optional[UUID] = None,
|
run_id: Optional[UUID] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> AsyncCallbackManagerForChainRun:
|
) -> AsyncCallbackManagerForChainRun:
|
||||||
"""Run when chain starts running."""
|
"""Run when chain starts running.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
serialized (Dict[str, Any]): The serialized chain.
|
||||||
|
inputs (Dict[str, Any]): The inputs to the chain.
|
||||||
|
run_id (UUID, optional): The ID of the run. Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AsyncCallbackManagerForChainRun: The async callback manager
|
||||||
|
for the chain run.
|
||||||
|
"""
|
||||||
if run_id is None:
|
if run_id is None:
|
||||||
run_id = uuid4()
|
run_id = uuid4()
|
||||||
|
|
||||||
@ -955,7 +1276,19 @@ class AsyncCallbackManager(BaseCallbackManager):
|
|||||||
parent_run_id: Optional[UUID] = None,
|
parent_run_id: Optional[UUID] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> AsyncCallbackManagerForToolRun:
|
) -> AsyncCallbackManagerForToolRun:
|
||||||
"""Run when tool starts running."""
|
"""Run when tool starts running.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
serialized (Dict[str, Any]): The serialized tool.
|
||||||
|
input_str (str): The input to the tool.
|
||||||
|
run_id (UUID, optional): The ID of the run. Defaults to None.
|
||||||
|
parent_run_id (UUID, optional): The ID of the parent run.
|
||||||
|
Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AsyncCallbackManagerForToolRun: The async callback manager
|
||||||
|
for the tool run.
|
||||||
|
"""
|
||||||
if run_id is None:
|
if run_id is None:
|
||||||
run_id = uuid4()
|
run_id = uuid4()
|
||||||
|
|
||||||
@ -989,7 +1322,22 @@ class AsyncCallbackManager(BaseCallbackManager):
|
|||||||
inheritable_tags: Optional[List[str]] = None,
|
inheritable_tags: Optional[List[str]] = None,
|
||||||
local_tags: Optional[List[str]] = None,
|
local_tags: Optional[List[str]] = None,
|
||||||
) -> AsyncCallbackManager:
|
) -> AsyncCallbackManager:
|
||||||
"""Configure the callback manager."""
|
"""Configure the async callback manager.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inheritable_callbacks (Optional[Callbacks], optional): The inheritable
|
||||||
|
callbacks. Defaults to None.
|
||||||
|
local_callbacks (Optional[Callbacks], optional): The local callbacks.
|
||||||
|
Defaults to None.
|
||||||
|
verbose (bool, optional): Whether to enable verbose mode. Defaults to False.
|
||||||
|
inheritable_tags (Optional[List[str]], optional): The inheritable tags.
|
||||||
|
Defaults to None.
|
||||||
|
local_tags (Optional[List[str]], optional): The local tags.
|
||||||
|
Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AsyncCallbackManager: The configured async callback manager.
|
||||||
|
"""
|
||||||
return _configure(
|
return _configure(
|
||||||
cls,
|
cls,
|
||||||
inheritable_callbacks,
|
inheritable_callbacks,
|
||||||
@ -1004,7 +1352,14 @@ T = TypeVar("T", CallbackManager, AsyncCallbackManager)
|
|||||||
|
|
||||||
|
|
||||||
def env_var_is_set(env_var: str) -> bool:
|
def env_var_is_set(env_var: str) -> bool:
|
||||||
"""Check if an environment variable is set."""
|
"""Check if an environment variable is set.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env_var (str): The name of the environment variable.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the environment variable is set, False otherwise.
|
||||||
|
"""
|
||||||
return env_var in os.environ and os.environ[env_var] not in (
|
return env_var in os.environ and os.environ[env_var] not in (
|
||||||
"",
|
"",
|
||||||
"0",
|
"0",
|
||||||
@ -1021,7 +1376,22 @@ def _configure(
|
|||||||
inheritable_tags: Optional[List[str]] = None,
|
inheritable_tags: Optional[List[str]] = None,
|
||||||
local_tags: Optional[List[str]] = None,
|
local_tags: Optional[List[str]] = None,
|
||||||
) -> T:
|
) -> T:
|
||||||
"""Configure the callback manager."""
|
"""Configure the callback manager.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
callback_manager_cls (Type[T]): The callback manager class.
|
||||||
|
inheritable_callbacks (Optional[Callbacks], optional): The inheritable
|
||||||
|
callbacks. Defaults to None.
|
||||||
|
local_callbacks (Optional[Callbacks], optional): The local callbacks.
|
||||||
|
Defaults to None.
|
||||||
|
verbose (bool, optional): Whether to enable verbose mode. Defaults to False.
|
||||||
|
inheritable_tags (Optional[List[str]], optional): The inheritable tags.
|
||||||
|
Defaults to None.
|
||||||
|
local_tags (Optional[List[str]], optional): The local tags. Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
T: The configured callback manager.
|
||||||
|
"""
|
||||||
callback_manager = callback_manager_cls(handlers=[])
|
callback_manager = callback_manager_cls(handlers=[])
|
||||||
if inheritable_callbacks or local_callbacks:
|
if inheritable_callbacks or local_callbacks:
|
||||||
if isinstance(inheritable_callbacks, list) or inheritable_callbacks is None:
|
if isinstance(inheritable_callbacks, list) or inheritable_callbacks is None:
|
||||||
|
Loading…
Reference in New Issue
Block a user