Add support for showing IO to chain group (#10510)

As well as error propagation
This commit is contained in:
William FH 2023-09-17 00:47:51 -07:00 committed by GitHub
parent 2c957de2fc
commit c5078fb13c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 176 additions and 29 deletions

View File

@ -213,17 +213,20 @@ def trace_as_chain_group(
group_name: str, group_name: str,
callback_manager: Optional[CallbackManager] = None, callback_manager: Optional[CallbackManager] = None,
*, *,
inputs: Optional[Dict[str, Any]] = None,
project_name: Optional[str] = None, project_name: Optional[str] = None,
example_id: Optional[Union[str, UUID]] = None, example_id: Optional[Union[str, UUID]] = None,
run_id: Optional[UUID] = None, run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None, tags: Optional[List[str]] = None,
) -> Generator[CallbackManager, None, None]: ) -> Generator[CallbackManagerForChainGroup, None, None]:
"""Get a callback manager for a chain group in a context manager. """Get a callback manager for a chain group in a context manager.
Useful for grouping different calls together as a single run even if Useful for grouping different calls together as a single run even if
they aren't composed in a single chain. they aren't composed in a single chain.
Args: Args:
group_name (str): The name of the chain group. group_name (str): The name of the chain group.
callback_manager (CallbackManager, optional): The callback manager to use.
inputs (Dict[str, Any], optional): The inputs to the chain group.
project_name (str, optional): The name of the project. project_name (str, optional): The name of the project.
Defaults to None. Defaults to None.
example_id (str or UUID, optional): The ID of the example. example_id (str or UUID, optional): The ID of the example.
@ -233,13 +236,17 @@ def trace_as_chain_group(
Defaults to None. Defaults to None.
Returns: Returns:
CallbackManager: The callback manager for the chain group. CallbackManagerForChainGroup: The callback manager for the chain group.
Example: Example:
>>> with trace_as_chain_group("group_name") as manager: .. code-block:: python
... # Use the callback manager for the chain group
... llm.predict("Foo", callbacks=manager) llm_input = "Foo"
""" with trace_as_chain_group("group_name", inputs={"input": llm_input}) as manager:
# Use the callback manager for the chain group
res = llm.predict(llm_input, callbacks=manager)
manager.on_chain_end({"output": res})
""" # noqa: E501
cb = cast( cb = cast(
Callbacks, Callbacks,
[ [
@ -256,9 +263,27 @@ def trace_as_chain_group(
inheritable_tags=tags, inheritable_tags=tags,
) )
run_manager = cm.on_chain_start({"name": group_name}, {}, run_id=run_id) run_manager = cm.on_chain_start({"name": group_name}, inputs or {}, run_id=run_id)
yield run_manager.get_child() child_cm = run_manager.get_child()
run_manager.on_chain_end({}) group_cm = CallbackManagerForChainGroup(
child_cm.handlers,
child_cm.inheritable_handlers,
child_cm.parent_run_id,
parent_run_manager=run_manager,
tags=child_cm.tags,
inheritable_tags=child_cm.inheritable_tags,
metadata=child_cm.metadata,
inheritable_metadata=child_cm.inheritable_metadata,
)
try:
yield group_cm
except Exception as e:
if not group_cm.ended:
run_manager.on_chain_error(e)
raise e
else:
if not group_cm.ended:
run_manager.on_chain_end({})
@asynccontextmanager @asynccontextmanager
@ -266,17 +291,20 @@ async def atrace_as_chain_group(
group_name: str, group_name: str,
callback_manager: Optional[AsyncCallbackManager] = None, callback_manager: Optional[AsyncCallbackManager] = None,
*, *,
inputs: Optional[Dict[str, Any]] = None,
project_name: Optional[str] = None, project_name: Optional[str] = None,
example_id: Optional[Union[str, UUID]] = None, example_id: Optional[Union[str, UUID]] = None,
run_id: Optional[UUID] = None, run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None, tags: Optional[List[str]] = None,
) -> AsyncGenerator[AsyncCallbackManager, None]: ) -> AsyncGenerator[AsyncCallbackManagerForChainGroup, None]:
"""Get an async callback manager for a chain group in a context manager. """Get an async callback manager for a chain group in a context manager.
Useful for grouping different async calls together as a single run even if Useful for grouping different async calls together as a single run even if
they aren't composed in a single chain. they aren't composed in a single chain.
Args: Args:
group_name (str): The name of the chain group. group_name (str): The name of the chain group.
callback_manager (AsyncCallbackManager, optional): The async callback manager to use,
which manages tracing and other callback behavior.
project_name (str, optional): The name of the project. project_name (str, optional): The name of the project.
Defaults to None. Defaults to None.
example_id (str or UUID, optional): The ID of the example. example_id (str or UUID, optional): The ID of the example.
@ -288,10 +316,14 @@ async def atrace_as_chain_group(
AsyncCallbackManager: The async callback manager for the chain group. AsyncCallbackManager: The async callback manager for the chain group.
Example: Example:
>>> async with atrace_as_chain_group("group_name") as manager: .. code-block:: python
... # Use the async callback manager for the chain group
... await llm.apredict("Foo", callbacks=manager) llm_input = "Foo"
""" async with atrace_as_chain_group("group_name", inputs={"input": llm_input}) as manager:
# Use the async callback manager for the chain group
res = await llm.apredict(llm_input, callbacks=manager)
await manager.on_chain_end({"output": res})
""" # noqa: E501
cb = cast( cb = cast(
Callbacks, Callbacks,
[ [
@ -305,11 +337,29 @@ async def atrace_as_chain_group(
) )
cm = AsyncCallbackManager.configure(inheritable_callbacks=cb, inheritable_tags=tags) cm = AsyncCallbackManager.configure(inheritable_callbacks=cb, inheritable_tags=tags)
run_manager = await cm.on_chain_start({"name": group_name}, {}, run_id=run_id) run_manager = await cm.on_chain_start(
{"name": group_name}, inputs or {}, run_id=run_id
)
child_cm = run_manager.get_child()
group_cm = AsyncCallbackManagerForChainGroup(
child_cm.handlers,
child_cm.inheritable_handlers,
child_cm.parent_run_id,
parent_run_manager=run_manager,
tags=child_cm.tags,
inheritable_tags=child_cm.inheritable_tags,
metadata=child_cm.metadata,
inheritable_metadata=child_cm.inheritable_metadata,
)
try: try:
yield run_manager.get_child() yield group_cm
finally: except Exception as e:
await run_manager.on_chain_end({}) if not group_cm.ended:
await run_manager.on_chain_error(e)
raise e
else:
if not group_cm.ended:
await run_manager.on_chain_end({})
def _handle_event( def _handle_event(
@ -1342,6 +1392,48 @@ class CallbackManager(BaseCallbackManager):
) )
class CallbackManagerForChainGroup(CallbackManager):
def __init__(
self,
handlers: List[BaseCallbackHandler],
inheritable_handlers: List[BaseCallbackHandler] | None = None,
parent_run_id: UUID | None = None,
*,
parent_run_manager: CallbackManagerForChainRun,
**kwargs: Any,
) -> None:
super().__init__(
handlers,
inheritable_handlers,
parent_run_id,
**kwargs,
)
self.parent_run_manager = parent_run_manager
self.ended = False
def on_chain_end(self, outputs: Union[Dict[str, Any], Any], **kwargs: Any) -> None:
"""Run when traced chain group ends.
Args:
outputs (Union[Dict[str, Any], Any]): The outputs of the chain.
"""
self.ended = True
return self.parent_run_manager.on_chain_end(outputs, **kwargs)
def on_chain_error(
self,
error: BaseException,
**kwargs: Any,
) -> None:
"""Run when chain errors.
Args:
error (Exception or KeyboardInterrupt): The error.
"""
self.ended = True
return self.parent_run_manager.on_chain_error(error, **kwargs)
class AsyncCallbackManager(BaseCallbackManager): class AsyncCallbackManager(BaseCallbackManager):
"""Async callback manager that handles callbacks from LangChain.""" """Async callback manager that handles callbacks from LangChain."""
@ -1634,6 +1726,50 @@ class AsyncCallbackManager(BaseCallbackManager):
) )
class AsyncCallbackManagerForChainGroup(AsyncCallbackManager):
def __init__(
self,
handlers: List[BaseCallbackHandler],
inheritable_handlers: List[BaseCallbackHandler] | None = None,
parent_run_id: UUID | None = None,
*,
parent_run_manager: AsyncCallbackManagerForChainRun,
**kwargs: Any,
) -> None:
super().__init__(
handlers,
inheritable_handlers,
parent_run_id,
**kwargs,
)
self.parent_run_manager = parent_run_manager
self.ended = False
async def on_chain_end(
self, outputs: Union[Dict[str, Any], Any], **kwargs: Any
) -> None:
"""Run when traced chain group ends.
Args:
outputs (Union[Dict[str, Any], Any]): The outputs of the chain.
"""
self.ended = True
await self.parent_run_manager.on_chain_end(outputs, **kwargs)
async def on_chain_error(
self,
error: BaseException,
**kwargs: Any,
) -> None:
"""Run when chain errors.
Args:
error (Exception or KeyboardInterrupt): The error.
"""
self.ended = True
await self.parent_run_manager.on_chain_error(error, **kwargs)
T = TypeVar("T", CallbackManager, AsyncCallbackManager) T = TypeVar("T", CallbackManager, AsyncCallbackManager)

View File

@ -222,13 +222,15 @@ def test_trace_as_group() -> None:
template="What is a good name for a company that makes {product}?", template="What is a good name for a company that makes {product}?",
) )
chain = LLMChain(llm=llm, prompt=prompt) chain = LLMChain(llm=llm, prompt=prompt)
with trace_as_chain_group("my_group") as group_manager: with trace_as_chain_group("my_group", inputs={"input": "cars"}) as group_manager:
chain.run(product="cars", callbacks=group_manager) chain.run(product="cars", callbacks=group_manager)
chain.run(product="computers", callbacks=group_manager) chain.run(product="computers", callbacks=group_manager)
chain.run(product="toys", callbacks=group_manager) final_res = chain.run(product="toys", callbacks=group_manager)
group_manager.on_chain_end({"output": final_res})
with trace_as_chain_group("my_group_2") as group_manager: with trace_as_chain_group("my_group_2", inputs={"input": "toys"}) as group_manager:
chain.run(product="toys", callbacks=group_manager) final_res = chain.run(product="toys", callbacks=group_manager)
group_manager.on_chain_end({"output": final_res})
def test_trace_as_group_with_env_set() -> None: def test_trace_as_group_with_env_set() -> None:
@ -239,13 +241,19 @@ def test_trace_as_group_with_env_set() -> None:
template="What is a good name for a company that makes {product}?", template="What is a good name for a company that makes {product}?",
) )
chain = LLMChain(llm=llm, prompt=prompt) chain = LLMChain(llm=llm, prompt=prompt)
with trace_as_chain_group("my_group") as group_manager: with trace_as_chain_group(
"my_group_env_set", inputs={"input": "cars"}
) as group_manager:
chain.run(product="cars", callbacks=group_manager) chain.run(product="cars", callbacks=group_manager)
chain.run(product="computers", callbacks=group_manager) chain.run(product="computers", callbacks=group_manager)
chain.run(product="toys", callbacks=group_manager) final_res = chain.run(product="toys", callbacks=group_manager)
group_manager.on_chain_end({"output": final_res})
with trace_as_chain_group("my_group_2") as group_manager: with trace_as_chain_group(
chain.run(product="toys", callbacks=group_manager) "my_group_2_env_set", inputs={"input": "toys"}
) as group_manager:
final_res = chain.run(product="toys", callbacks=group_manager)
group_manager.on_chain_end({"output": final_res})
@pytest.mark.asyncio @pytest.mark.asyncio
@ -256,16 +264,19 @@ async def test_trace_as_group_async() -> None:
template="What is a good name for a company that makes {product}?", template="What is a good name for a company that makes {product}?",
) )
chain = LLMChain(llm=llm, prompt=prompt) chain = LLMChain(llm=llm, prompt=prompt)
async with atrace_as_chain_group("my_group") as group_manager: async with atrace_as_chain_group("my_async_group") as group_manager:
await chain.arun(product="cars", callbacks=group_manager) await chain.arun(product="cars", callbacks=group_manager)
await chain.arun(product="computers", callbacks=group_manager) await chain.arun(product="computers", callbacks=group_manager)
await chain.arun(product="toys", callbacks=group_manager) await chain.arun(product="toys", callbacks=group_manager)
async with atrace_as_chain_group("my_group_2") as group_manager: async with atrace_as_chain_group(
await asyncio.gather( "my_async_group_2", inputs={"input": "toys"}
) as group_manager:
res = await asyncio.gather(
*[ *[
chain.arun(product="toys", callbacks=group_manager), chain.arun(product="toys", callbacks=group_manager),
chain.arun(product="computers", callbacks=group_manager), chain.arun(product="computers", callbacks=group_manager),
chain.arun(product="cars", callbacks=group_manager), chain.arun(product="cars", callbacks=group_manager),
] ]
) )
await group_manager.on_chain_end({"output": res})