Add support for tags in chain group context manager (#6668)

Lets you specify local and inheritable tags in the group manager.

Also, add more verbose docstrings for our reference docs.
This commit is contained in:
Zander Chase 2023-06-26 10:37:33 -07:00 committed by GitHub
parent d1bcc58beb
commit 9ca3b4645e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -74,7 +74,16 @@ def _get_debug() -> bool:
@contextmanager
def get_openai_callback() -> Generator[OpenAICallbackHandler, None, None]:
"""Get OpenAI callback handler in a context manager."""
"""Get the OpenAI callback handler in a context manager.
which conveniently exposes token and cost information.
Returns:
OpenAICallbackHandler: The OpenAI callback handler.
Example:
>>> with get_openai_callback() as cb:
... # Use the OpenAI callback handler
"""
cb = OpenAICallbackHandler()
openai_callback_var.set(cb)
yield cb
@ -85,7 +94,19 @@ def get_openai_callback() -> Generator[OpenAICallbackHandler, None, None]:
def tracing_enabled(
session_name: str = "default",
) -> Generator[TracerSessionV1, None, None]:
"""Get Tracer in a context manager."""
"""Get the Deprecated LangChainTracer in a context manager.
Args:
session_name (str, optional): The name of the session.
Defaults to "default".
Returns:
TracerSessionV1: The LangChainTracer session.
Example:
>>> with tracing_enabled() as session:
... # Use the LangChainTracer session
"""
cb = LangChainTracerV1()
session = cast(TracerSessionV1, cb.load_session(session_name))
tracing_callback_var.set(cb)
@ -97,7 +118,19 @@ def tracing_enabled(
def wandb_tracing_enabled(
session_name: str = "default",
) -> Generator[None, None, None]:
"""Get WandbTracer in a context manager."""
"""Get the WandbTracer in a context manager.
Args:
session_name (str, optional): The name of the session.
Defaults to "default".
Returns:
None
Example:
>>> with wandb_tracing_enabled() as session:
... # Use the WandbTracer session
"""
cb = WandbTracer()
wandb_tracing_callback_var.set(cb)
yield None
@ -110,7 +143,21 @@ def tracing_v2_enabled(
*,
example_id: Optional[Union[str, UUID]] = None,
) -> Generator[None, None, None]:
"""Get the experimental tracer handler in a context manager."""
"""Instruct LangChain to log all runs in context to LangSmith.
Args:
project_name (str, optional): The name of the project.
Defaults to "default".
example_id (str or UUID, optional): The ID of the example.
Defaults to None.
Returns:
None
Example:
>>> with tracing_v2_enabled():
... # LangChain code will automatically be traced
"""
# Issue a warning that this is experimental
warnings.warn(
"The tracing v2 API is in development. "
@ -133,14 +180,36 @@ def trace_as_chain_group(
*,
project_name: Optional[str] = None,
example_id: Optional[Union[str, UUID]] = None,
tags: Optional[List[str]] = None,
) -> Generator[CallbackManager, None, None]:
"""Get a callback manager for a chain group in a context manager."""
"""Get a callback manager for a chain group in a context manager.
Useful for grouping different calls together as a single run even if
they aren't composed in a single chain.
Args:
group_name (str): The name of the chain group.
project_name (str, optional): The name of the project.
Defaults to None.
example_id (str or UUID, optional): The ID of the example.
Defaults to None.
tags (List[str], optional): The inheritable tags to apply to all runs.
Defaults to None.
Returns:
CallbackManager: The callback manager for the chain group.
Example:
>>> with trace_as_chain_group("group_name") as manager:
... # Use the callback manager for the chain group
... llm.predict("Foo", callbacks=manager)
"""
cb = LangChainTracer(
project_name=project_name,
example_id=example_id,
)
cm = CallbackManager.configure(
inheritable_callbacks=[cb],
inheritable_tags=tags,
)
run_manager = cm.on_chain_start({"name": group_name}, {})
@ -154,14 +223,34 @@ async def atrace_as_chain_group(
*,
project_name: Optional[str] = None,
example_id: Optional[Union[str, UUID]] = None,
tags: Optional[List[str]] = None,
) -> AsyncGenerator[AsyncCallbackManager, None]:
"""Get a callback manager for a chain group in a context manager."""
"""Get an async callback manager for a chain group in a context manager.
Useful for grouping different async calls together as a single run even if
they aren't composed in a single chain.
Args:
group_name (str): The name of the chain group.
project_name (str, optional): The name of the project.
Defaults to None.
example_id (str or UUID, optional): The ID of the example.
Defaults to None.
tags (List[str], optional): The inheritable tags to apply to all runs.
Defaults to None.
Returns:
AsyncCallbackManager: The async callback manager for the chain group.
Example:
>>> async with atrace_as_chain_group("group_name") as manager:
... # Use the async callback manager for the chain group
... await llm.apredict("Foo", callbacks=manager)
"""
cb = LangChainTracer(
project_name=project_name,
example_id=example_id,
)
cm = AsyncCallbackManager.configure(
inheritable_callbacks=[cb],
inheritable_callbacks=[cb], inheritable_tags=tags
)
run_manager = await cm.on_chain_start({"name": group_name}, {})
@ -293,7 +382,18 @@ class BaseRunManager(RunManagerMixin):
tags: List[str],
inheritable_tags: List[str],
) -> None:
"""Initialize run manager."""
"""Initialize the run manager.
Args:
run_id (UUID): The ID of the run.
handlers (List[BaseCallbackHandler]): The list of handlers.
inheritable_handlers (List[BaseCallbackHandler]):
The list of inheritable handlers.
parent_run_id (UUID, optional): The ID of the parent run.
Defaults to None.
tags (List[str]): The list of tags.
inheritable_tags (List[str]): The list of inheritable tags.
"""
self.run_id = run_id
self.handlers = handlers
self.inheritable_handlers = inheritable_handlers
@ -303,7 +403,11 @@ class BaseRunManager(RunManagerMixin):
@classmethod
def get_noop_manager(cls: Type[BRM]) -> BRM:
"""Return a manager that doesn't perform any operations."""
"""Return a manager that doesn't perform any operations.
Returns:
BaseRunManager: The noop manager.
"""
return cls(
run_id=uuid4(),
handlers=[],
@ -321,7 +425,14 @@ class RunManager(BaseRunManager):
text: str,
**kwargs: Any,
) -> Any:
"""Run when text is received."""
"""Run when text is received.
Args:
text (str): The received text.
Returns:
Any: The result of the callback.
"""
_handle_event(
self.handlers,
"on_text",
@ -341,7 +452,14 @@ class AsyncRunManager(BaseRunManager):
text: str,
**kwargs: Any,
) -> Any:
"""Run when text is received."""
"""Run when text is received.
Args:
text (str): The received text.
Returns:
Any: The result of the callback.
"""
await _ahandle_event(
self.handlers,
"on_text",
@ -361,7 +479,11 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
token: str,
**kwargs: Any,
) -> None:
"""Run when LLM generates a new token."""
"""Run when LLM generates a new token.
Args:
token (str): The new token.
"""
_handle_event(
self.handlers,
"on_llm_new_token",
@ -373,7 +495,11 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
)
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running."""
"""Run when LLM ends running.
Args:
response (LLMResult): The LLM result.
"""
_handle_event(
self.handlers,
"on_llm_end",
@ -389,7 +515,11 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
error: Union[Exception, KeyboardInterrupt],
**kwargs: Any,
) -> None:
"""Run when LLM errors."""
"""Run when LLM errors.
Args:
error (Exception or KeyboardInterrupt): The error.
"""
_handle_event(
self.handlers,
"on_llm_error",
@ -409,7 +539,11 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
token: str,
**kwargs: Any,
) -> None:
"""Run when LLM generates a new token."""
"""Run when LLM generates a new token.
Args:
token (str): The new token.
"""
await _ahandle_event(
self.handlers,
"on_llm_new_token",
@ -421,7 +555,11 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
)
async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running."""
"""Run when LLM ends running.
Args:
response (LLMResult): The LLM result.
"""
await _ahandle_event(
self.handlers,
"on_llm_end",
@ -437,7 +575,11 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
error: Union[Exception, KeyboardInterrupt],
**kwargs: Any,
) -> None:
"""Run when LLM errors."""
"""Run when LLM errors.
Args:
error (Exception or KeyboardInterrupt): The error.
"""
await _ahandle_event(
self.handlers,
"on_llm_error",
@ -453,7 +595,15 @@ class CallbackManagerForChainRun(RunManager, ChainManagerMixin):
"""Callback manager for chain run."""
def get_child(self, tag: Optional[str] = None) -> CallbackManager:
"""Get a child callback manager."""
"""Get a child callback manager.
Args:
tag (str, optional): The tag for the child callback manager.
Defaults to None.
Returns:
CallbackManager: The child callback manager.
"""
manager = CallbackManager(handlers=[], parent_run_id=self.run_id)
manager.set_handlers(self.inheritable_handlers)
manager.add_tags(self.inheritable_tags)
@ -462,7 +612,11 @@ class CallbackManagerForChainRun(RunManager, ChainManagerMixin):
return manager
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Run when chain ends running."""
"""Run when chain ends running.
Args:
outputs (Dict[str, Any]): The outputs of the chain.
"""
_handle_event(
self.handlers,
"on_chain_end",
@ -478,7 +632,11 @@ class CallbackManagerForChainRun(RunManager, ChainManagerMixin):
error: Union[Exception, KeyboardInterrupt],
**kwargs: Any,
) -> None:
"""Run when chain errors."""
"""Run when chain errors.
Args:
error (Exception or KeyboardInterrupt): The error.
"""
_handle_event(
self.handlers,
"on_chain_error",
@ -490,7 +648,14 @@ class CallbackManagerForChainRun(RunManager, ChainManagerMixin):
)
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run when agent action is received."""
"""Run when agent action is received.
Args:
action (AgentAction): The agent action.
Returns:
Any: The result of the callback.
"""
_handle_event(
self.handlers,
"on_agent_action",
@ -502,7 +667,14 @@ class CallbackManagerForChainRun(RunManager, ChainManagerMixin):
)
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
"""Run when agent finish is received."""
"""Run when agent finish is received.
Args:
finish (AgentFinish): The agent finish.
Returns:
Any: The result of the callback.
"""
_handle_event(
self.handlers,
"on_agent_finish",
@ -518,7 +690,15 @@ class AsyncCallbackManagerForChainRun(AsyncRunManager, ChainManagerMixin):
"""Async callback manager for chain run."""
def get_child(self, tag: Optional[str] = None) -> AsyncCallbackManager:
"""Get a child callback manager."""
"""Get a child callback manager.
Args:
tag (str, optional): The tag for the child callback manager.
Defaults to None.
Returns:
AsyncCallbackManager: The child callback manager.
"""
manager = AsyncCallbackManager(handlers=[], parent_run_id=self.run_id)
manager.set_handlers(self.inheritable_handlers)
manager.add_tags(self.inheritable_tags)
@ -527,7 +707,11 @@ class AsyncCallbackManagerForChainRun(AsyncRunManager, ChainManagerMixin):
return manager
async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Run when chain ends running."""
"""Run when chain ends running.
Args:
outputs (Dict[str, Any]): The outputs of the chain.
"""
await _ahandle_event(
self.handlers,
"on_chain_end",
@ -543,7 +727,11 @@ class AsyncCallbackManagerForChainRun(AsyncRunManager, ChainManagerMixin):
error: Union[Exception, KeyboardInterrupt],
**kwargs: Any,
) -> None:
"""Run when chain errors."""
"""Run when chain errors.
Args:
error (Exception or KeyboardInterrupt): The error.
"""
await _ahandle_event(
self.handlers,
"on_chain_error",
@ -555,7 +743,14 @@ class AsyncCallbackManagerForChainRun(AsyncRunManager, ChainManagerMixin):
)
async def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run when agent action is received."""
"""Run when agent action is received.
Args:
action (AgentAction): The agent action.
Returns:
Any: The result of the callback.
"""
await _ahandle_event(
self.handlers,
"on_agent_action",
@ -567,7 +762,14 @@ class AsyncCallbackManagerForChainRun(AsyncRunManager, ChainManagerMixin):
)
async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
"""Run when agent finish is received."""
"""Run when agent finish is received.
Args:
finish (AgentFinish): The agent finish.
Returns:
Any: The result of the callback.
"""
await _ahandle_event(
self.handlers,
"on_agent_finish",
@ -583,7 +785,15 @@ class CallbackManagerForToolRun(RunManager, ToolManagerMixin):
"""Callback manager for tool run."""
def get_child(self, tag: Optional[str] = None) -> CallbackManager:
"""Get a child callback manager."""
"""Get a child callback manager.
Args:
tag (str, optional): The tag for the child callback manager.
Defaults to None.
Returns:
CallbackManager: The child callback manager.
"""
manager = CallbackManager(handlers=[], parent_run_id=self.run_id)
manager.set_handlers(self.inheritable_handlers)
manager.add_tags(self.inheritable_tags)
@ -596,7 +806,11 @@ class CallbackManagerForToolRun(RunManager, ToolManagerMixin):
output: str,
**kwargs: Any,
) -> None:
"""Run when tool ends running."""
"""Run when tool ends running.
Args:
output (str): The output of the tool.
"""
_handle_event(
self.handlers,
"on_tool_end",
@ -612,7 +826,11 @@ class CallbackManagerForToolRun(RunManager, ToolManagerMixin):
error: Union[Exception, KeyboardInterrupt],
**kwargs: Any,
) -> None:
"""Run when tool errors."""
"""Run when tool errors.
Args:
error (Exception or KeyboardInterrupt): The error.
"""
_handle_event(
self.handlers,
"on_tool_error",
@ -628,7 +846,15 @@ class AsyncCallbackManagerForToolRun(AsyncRunManager, ToolManagerMixin):
"""Async callback manager for tool run."""
def get_child(self, tag: Optional[str] = None) -> AsyncCallbackManager:
"""Get a child callback manager."""
"""Get a child callback manager.
Args:
tag (str, optional): The tag to add to the child
callback manager. Defaults to None.
Returns:
AsyncCallbackManager: The child callback manager.
"""
manager = AsyncCallbackManager(handlers=[], parent_run_id=self.run_id)
manager.set_handlers(self.inheritable_handlers)
manager.add_tags(self.inheritable_tags)
@ -637,7 +863,11 @@ class AsyncCallbackManagerForToolRun(AsyncRunManager, ToolManagerMixin):
return manager
async def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Run when tool ends running."""
"""Run when tool ends running.
Args:
output (str): The output of the tool.
"""
await _ahandle_event(
self.handlers,
"on_tool_end",
@ -653,7 +883,11 @@ class AsyncCallbackManagerForToolRun(AsyncRunManager, ToolManagerMixin):
error: Union[Exception, KeyboardInterrupt],
**kwargs: Any,
) -> None:
"""Run when tool errors."""
"""Run when tool errors.
Args:
error (Exception or KeyboardInterrupt): The error.
"""
await _ahandle_event(
self.handlers,
"on_tool_error",
@ -674,7 +908,17 @@ class CallbackManager(BaseCallbackManager):
prompts: List[str],
**kwargs: Any,
) -> List[CallbackManagerForLLMRun]:
"""Run when LLM starts running."""
"""Run when LLM starts running.
Args:
serialized (Dict[str, Any]): The serialized LLM.
prompts (List[str]): The list of prompts.
run_id (UUID, optional): The ID of the run. Defaults to None.
Returns:
List[CallbackManagerForLLMRun]: A callback manager for each
prompt as an LLM run.
"""
managers = []
for prompt in prompts:
run_id_ = uuid4()
@ -709,7 +953,17 @@ class CallbackManager(BaseCallbackManager):
messages: List[List[BaseMessage]],
**kwargs: Any,
) -> List[CallbackManagerForLLMRun]:
"""Run when LLM starts running."""
"""Run when LLM starts running.
Args:
serialized (Dict[str, Any]): The serialized LLM.
messages (List[List[BaseMessage]]): The list of messages.
run_id (UUID, optional): The ID of the run. Defaults to None.
Returns:
List[CallbackManagerForLLMRun]: A callback manager for each
list of messages as an LLM run.
"""
managers = []
for message_list in messages:
@ -746,7 +1000,16 @@ class CallbackManager(BaseCallbackManager):
run_id: Optional[UUID] = None,
**kwargs: Any,
) -> CallbackManagerForChainRun:
"""Run when chain starts running."""
"""Run when chain starts running.
Args:
serialized (Dict[str, Any]): The serialized chain.
inputs (Dict[str, Any]): The inputs to the chain.
run_id (UUID, optional): The ID of the run. Defaults to None.
Returns:
CallbackManagerForChainRun: The callback manager for the chain run.
"""
if run_id is None:
run_id = uuid4()
@ -779,7 +1042,17 @@ class CallbackManager(BaseCallbackManager):
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> CallbackManagerForToolRun:
"""Run when tool starts running."""
"""Run when tool starts running.
Args:
serialized (Dict[str, Any]): The serialized tool.
input_str (str): The input to the tool.
run_id (UUID, optional): The ID of the run. Defaults to None.
parent_run_id (UUID, optional): The ID of the parent run. Defaults to None.
Returns:
CallbackManagerForToolRun: The callback manager for the tool run.
"""
if run_id is None:
run_id = uuid4()
@ -813,7 +1086,22 @@ class CallbackManager(BaseCallbackManager):
inheritable_tags: Optional[List[str]] = None,
local_tags: Optional[List[str]] = None,
) -> CallbackManager:
"""Configure the callback manager."""
"""Configure the callback manager.
Args:
inheritable_callbacks (Optional[Callbacks], optional): The inheritable
callbacks. Defaults to None.
local_callbacks (Optional[Callbacks], optional): The local callbacks.
Defaults to None.
verbose (bool, optional): Whether to enable verbose mode. Defaults to False.
inheritable_tags (Optional[List[str]], optional): The inheritable tags.
Defaults to None.
local_tags (Optional[List[str]], optional): The local tags.
Defaults to None.
Returns:
CallbackManager: The configured callback manager.
"""
return _configure(
cls,
inheritable_callbacks,
@ -838,7 +1126,18 @@ class AsyncCallbackManager(BaseCallbackManager):
prompts: List[str],
**kwargs: Any,
) -> List[AsyncCallbackManagerForLLMRun]:
"""Run when LLM starts running."""
"""Run when LLM starts running.
Args:
serialized (Dict[str, Any]): The serialized LLM.
prompts (List[str]): The list of prompts.
run_id (UUID, optional): The ID of the run. Defaults to None.
Returns:
List[AsyncCallbackManagerForLLMRun]: The list of async
callback managers, one for each LLM Run corresponding
to each prompt.
"""
tasks = []
managers = []
@ -881,6 +1180,18 @@ class AsyncCallbackManager(BaseCallbackManager):
messages: List[List[BaseMessage]],
**kwargs: Any,
) -> Any:
"""Run when LLM starts running.
Args:
serialized (Dict[str, Any]): The serialized LLM.
messages (List[List[BaseMessage]]): The list of messages.
run_id (UUID, optional): The ID of the run. Defaults to None.
Returns:
List[AsyncCallbackManagerForLLMRun]: The list of
async callback managers, one for each LLM Run
corresponding to each inner message list.
"""
tasks = []
managers = []
@ -922,7 +1233,17 @@ class AsyncCallbackManager(BaseCallbackManager):
run_id: Optional[UUID] = None,
**kwargs: Any,
) -> AsyncCallbackManagerForChainRun:
"""Run when chain starts running."""
"""Run when chain starts running.
Args:
serialized (Dict[str, Any]): The serialized chain.
inputs (Dict[str, Any]): The inputs to the chain.
run_id (UUID, optional): The ID of the run. Defaults to None.
Returns:
AsyncCallbackManagerForChainRun: The async callback manager
for the chain run.
"""
if run_id is None:
run_id = uuid4()
@ -955,7 +1276,19 @@ class AsyncCallbackManager(BaseCallbackManager):
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> AsyncCallbackManagerForToolRun:
"""Run when tool starts running."""
"""Run when tool starts running.
Args:
serialized (Dict[str, Any]): The serialized tool.
input_str (str): The input to the tool.
run_id (UUID, optional): The ID of the run. Defaults to None.
parent_run_id (UUID, optional): The ID of the parent run.
Defaults to None.
Returns:
AsyncCallbackManagerForToolRun: The async callback manager
for the tool run.
"""
if run_id is None:
run_id = uuid4()
@ -989,7 +1322,22 @@ class AsyncCallbackManager(BaseCallbackManager):
inheritable_tags: Optional[List[str]] = None,
local_tags: Optional[List[str]] = None,
) -> AsyncCallbackManager:
"""Configure the callback manager."""
"""Configure the async callback manager.
Args:
inheritable_callbacks (Optional[Callbacks], optional): The inheritable
callbacks. Defaults to None.
local_callbacks (Optional[Callbacks], optional): The local callbacks.
Defaults to None.
verbose (bool, optional): Whether to enable verbose mode. Defaults to False.
inheritable_tags (Optional[List[str]], optional): The inheritable tags.
Defaults to None.
local_tags (Optional[List[str]], optional): The local tags.
Defaults to None.
Returns:
AsyncCallbackManager: The configured async callback manager.
"""
return _configure(
cls,
inheritable_callbacks,
@ -1004,7 +1352,14 @@ T = TypeVar("T", CallbackManager, AsyncCallbackManager)
def env_var_is_set(env_var: str) -> bool:
"""Check if an environment variable is set."""
"""Check if an environment variable is set.
Args:
env_var (str): The name of the environment variable.
Returns:
bool: True if the environment variable is set, False otherwise.
"""
return env_var in os.environ and os.environ[env_var] not in (
"",
"0",
@ -1021,7 +1376,22 @@ def _configure(
inheritable_tags: Optional[List[str]] = None,
local_tags: Optional[List[str]] = None,
) -> T:
"""Configure the callback manager."""
"""Configure the callback manager.
Args:
callback_manager_cls (Type[T]): The callback manager class.
inheritable_callbacks (Optional[Callbacks], optional): The inheritable
callbacks. Defaults to None.
local_callbacks (Optional[Callbacks], optional): The local callbacks.
Defaults to None.
verbose (bool, optional): Whether to enable verbose mode. Defaults to False.
inheritable_tags (Optional[List[str]], optional): The inheritable tags.
Defaults to None.
local_tags (Optional[List[str]], optional): The local tags. Defaults to None.
Returns:
T: The configured callback manager.
"""
callback_manager = callback_manager_cls(handlers=[])
if inheritable_callbacks or local_callbacks:
if isinstance(inheritable_callbacks, list) or inheritable_callbacks is None: