mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-29 01:48:57 +00:00
Add support for showing IO to chain group (#10510)
As well as error propagation
This commit is contained in:
parent
2c957de2fc
commit
c5078fb13c
@ -213,17 +213,20 @@ def trace_as_chain_group(
|
||||
group_name: str,
|
||||
callback_manager: Optional[CallbackManager] = None,
|
||||
*,
|
||||
inputs: Optional[Dict[str, Any]] = None,
|
||||
project_name: Optional[str] = None,
|
||||
example_id: Optional[Union[str, UUID]] = None,
|
||||
run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
) -> Generator[CallbackManager, None, None]:
|
||||
) -> Generator[CallbackManagerForChainGroup, None, None]:
|
||||
"""Get a callback manager for a chain group in a context manager.
|
||||
Useful for grouping different calls together as a single run even if
|
||||
they aren't composed in a single chain.
|
||||
|
||||
Args:
|
||||
group_name (str): The name of the chain group.
|
||||
callback_manager (CallbackManager, optional): The callback manager to use.
|
||||
inputs (Dict[str, Any], optional): The inputs to the chain group.
|
||||
project_name (str, optional): The name of the project.
|
||||
Defaults to None.
|
||||
example_id (str or UUID, optional): The ID of the example.
|
||||
@ -233,13 +236,17 @@ def trace_as_chain_group(
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
CallbackManager: The callback manager for the chain group.
|
||||
CallbackManagerForChainGroup: The callback manager for the chain group.
|
||||
|
||||
Example:
|
||||
>>> with trace_as_chain_group("group_name") as manager:
|
||||
... # Use the callback manager for the chain group
|
||||
... llm.predict("Foo", callbacks=manager)
|
||||
"""
|
||||
.. code-block:: python
|
||||
|
||||
llm_input = "Foo"
|
||||
with trace_as_chain_group("group_name", inputs={"input": llm_input}) as manager:
|
||||
# Use the callback manager for the chain group
|
||||
res = llm.predict(llm_input, callbacks=manager)
|
||||
manager.on_chain_end({"output": res})
|
||||
""" # noqa: E501
|
||||
cb = cast(
|
||||
Callbacks,
|
||||
[
|
||||
@ -256,9 +263,27 @@ def trace_as_chain_group(
|
||||
inheritable_tags=tags,
|
||||
)
|
||||
|
||||
run_manager = cm.on_chain_start({"name": group_name}, {}, run_id=run_id)
|
||||
yield run_manager.get_child()
|
||||
run_manager.on_chain_end({})
|
||||
run_manager = cm.on_chain_start({"name": group_name}, inputs or {}, run_id=run_id)
|
||||
child_cm = run_manager.get_child()
|
||||
group_cm = CallbackManagerForChainGroup(
|
||||
child_cm.handlers,
|
||||
child_cm.inheritable_handlers,
|
||||
child_cm.parent_run_id,
|
||||
parent_run_manager=run_manager,
|
||||
tags=child_cm.tags,
|
||||
inheritable_tags=child_cm.inheritable_tags,
|
||||
metadata=child_cm.metadata,
|
||||
inheritable_metadata=child_cm.inheritable_metadata,
|
||||
)
|
||||
try:
|
||||
yield group_cm
|
||||
except Exception as e:
|
||||
if not group_cm.ended:
|
||||
run_manager.on_chain_error(e)
|
||||
raise e
|
||||
else:
|
||||
if not group_cm.ended:
|
||||
run_manager.on_chain_end({})
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
@ -266,17 +291,20 @@ async def atrace_as_chain_group(
|
||||
group_name: str,
|
||||
callback_manager: Optional[AsyncCallbackManager] = None,
|
||||
*,
|
||||
inputs: Optional[Dict[str, Any]] = None,
|
||||
project_name: Optional[str] = None,
|
||||
example_id: Optional[Union[str, UUID]] = None,
|
||||
run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
) -> AsyncGenerator[AsyncCallbackManager, None]:
|
||||
) -> AsyncGenerator[AsyncCallbackManagerForChainGroup, None]:
|
||||
"""Get an async callback manager for a chain group in a context manager.
|
||||
Useful for grouping different async calls together as a single run even if
|
||||
they aren't composed in a single chain.
|
||||
|
||||
Args:
|
||||
group_name (str): The name of the chain group.
|
||||
callback_manager (AsyncCallbackManager, optional): The async callback manager to use,
|
||||
which manages tracing and other callback behavior.
|
||||
project_name (str, optional): The name of the project.
|
||||
Defaults to None.
|
||||
example_id (str or UUID, optional): The ID of the example.
|
||||
@ -288,10 +316,14 @@ async def atrace_as_chain_group(
|
||||
AsyncCallbackManager: The async callback manager for the chain group.
|
||||
|
||||
Example:
|
||||
>>> async with atrace_as_chain_group("group_name") as manager:
|
||||
... # Use the async callback manager for the chain group
|
||||
... await llm.apredict("Foo", callbacks=manager)
|
||||
"""
|
||||
.. code-block:: python
|
||||
|
||||
llm_input = "Foo"
|
||||
async with atrace_as_chain_group("group_name", inputs={"input": llm_input}) as manager:
|
||||
# Use the async callback manager for the chain group
|
||||
res = await llm.apredict(llm_input, callbacks=manager)
|
||||
await manager.on_chain_end({"output": res})
|
||||
""" # noqa: E501
|
||||
cb = cast(
|
||||
Callbacks,
|
||||
[
|
||||
@ -305,11 +337,29 @@ async def atrace_as_chain_group(
|
||||
)
|
||||
cm = AsyncCallbackManager.configure(inheritable_callbacks=cb, inheritable_tags=tags)
|
||||
|
||||
run_manager = await cm.on_chain_start({"name": group_name}, {}, run_id=run_id)
|
||||
run_manager = await cm.on_chain_start(
|
||||
{"name": group_name}, inputs or {}, run_id=run_id
|
||||
)
|
||||
child_cm = run_manager.get_child()
|
||||
group_cm = AsyncCallbackManagerForChainGroup(
|
||||
child_cm.handlers,
|
||||
child_cm.inheritable_handlers,
|
||||
child_cm.parent_run_id,
|
||||
parent_run_manager=run_manager,
|
||||
tags=child_cm.tags,
|
||||
inheritable_tags=child_cm.inheritable_tags,
|
||||
metadata=child_cm.metadata,
|
||||
inheritable_metadata=child_cm.inheritable_metadata,
|
||||
)
|
||||
try:
|
||||
yield run_manager.get_child()
|
||||
finally:
|
||||
await run_manager.on_chain_end({})
|
||||
yield group_cm
|
||||
except Exception as e:
|
||||
if not group_cm.ended:
|
||||
await run_manager.on_chain_error(e)
|
||||
raise e
|
||||
else:
|
||||
if not group_cm.ended:
|
||||
await run_manager.on_chain_end({})
|
||||
|
||||
|
||||
def _handle_event(
|
||||
@ -1342,6 +1392,48 @@ class CallbackManager(BaseCallbackManager):
|
||||
)
|
||||
|
||||
|
||||
class CallbackManagerForChainGroup(CallbackManager):
|
||||
def __init__(
|
||||
self,
|
||||
handlers: List[BaseCallbackHandler],
|
||||
inheritable_handlers: List[BaseCallbackHandler] | None = None,
|
||||
parent_run_id: UUID | None = None,
|
||||
*,
|
||||
parent_run_manager: CallbackManagerForChainRun,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
handlers,
|
||||
inheritable_handlers,
|
||||
parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
self.parent_run_manager = parent_run_manager
|
||||
self.ended = False
|
||||
|
||||
def on_chain_end(self, outputs: Union[Dict[str, Any], Any], **kwargs: Any) -> None:
|
||||
"""Run when traced chain group ends.
|
||||
|
||||
Args:
|
||||
outputs (Union[Dict[str, Any], Any]): The outputs of the chain.
|
||||
"""
|
||||
self.ended = True
|
||||
return self.parent_run_manager.on_chain_end(outputs, **kwargs)
|
||||
|
||||
def on_chain_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when chain errors.
|
||||
|
||||
Args:
|
||||
error (Exception or KeyboardInterrupt): The error.
|
||||
"""
|
||||
self.ended = True
|
||||
return self.parent_run_manager.on_chain_error(error, **kwargs)
|
||||
|
||||
|
||||
class AsyncCallbackManager(BaseCallbackManager):
|
||||
"""Async callback manager that handles callbacks from LangChain."""
|
||||
|
||||
@ -1634,6 +1726,50 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
)
|
||||
|
||||
|
||||
class AsyncCallbackManagerForChainGroup(AsyncCallbackManager):
|
||||
def __init__(
|
||||
self,
|
||||
handlers: List[BaseCallbackHandler],
|
||||
inheritable_handlers: List[BaseCallbackHandler] | None = None,
|
||||
parent_run_id: UUID | None = None,
|
||||
*,
|
||||
parent_run_manager: AsyncCallbackManagerForChainRun,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
handlers,
|
||||
inheritable_handlers,
|
||||
parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
self.parent_run_manager = parent_run_manager
|
||||
self.ended = False
|
||||
|
||||
async def on_chain_end(
|
||||
self, outputs: Union[Dict[str, Any], Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when traced chain group ends.
|
||||
|
||||
Args:
|
||||
outputs (Union[Dict[str, Any], Any]): The outputs of the chain.
|
||||
"""
|
||||
self.ended = True
|
||||
await self.parent_run_manager.on_chain_end(outputs, **kwargs)
|
||||
|
||||
async def on_chain_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when chain errors.
|
||||
|
||||
Args:
|
||||
error (Exception or KeyboardInterrupt): The error.
|
||||
"""
|
||||
self.ended = True
|
||||
await self.parent_run_manager.on_chain_error(error, **kwargs)
|
||||
|
||||
|
||||
T = TypeVar("T", CallbackManager, AsyncCallbackManager)
|
||||
|
||||
|
||||
|
@ -222,13 +222,15 @@ def test_trace_as_group() -> None:
|
||||
template="What is a good name for a company that makes {product}?",
|
||||
)
|
||||
chain = LLMChain(llm=llm, prompt=prompt)
|
||||
with trace_as_chain_group("my_group") as group_manager:
|
||||
with trace_as_chain_group("my_group", inputs={"input": "cars"}) as group_manager:
|
||||
chain.run(product="cars", callbacks=group_manager)
|
||||
chain.run(product="computers", callbacks=group_manager)
|
||||
chain.run(product="toys", callbacks=group_manager)
|
||||
final_res = chain.run(product="toys", callbacks=group_manager)
|
||||
group_manager.on_chain_end({"output": final_res})
|
||||
|
||||
with trace_as_chain_group("my_group_2") as group_manager:
|
||||
chain.run(product="toys", callbacks=group_manager)
|
||||
with trace_as_chain_group("my_group_2", inputs={"input": "toys"}) as group_manager:
|
||||
final_res = chain.run(product="toys", callbacks=group_manager)
|
||||
group_manager.on_chain_end({"output": final_res})
|
||||
|
||||
|
||||
def test_trace_as_group_with_env_set() -> None:
|
||||
@ -239,13 +241,19 @@ def test_trace_as_group_with_env_set() -> None:
|
||||
template="What is a good name for a company that makes {product}?",
|
||||
)
|
||||
chain = LLMChain(llm=llm, prompt=prompt)
|
||||
with trace_as_chain_group("my_group") as group_manager:
|
||||
with trace_as_chain_group(
|
||||
"my_group_env_set", inputs={"input": "cars"}
|
||||
) as group_manager:
|
||||
chain.run(product="cars", callbacks=group_manager)
|
||||
chain.run(product="computers", callbacks=group_manager)
|
||||
chain.run(product="toys", callbacks=group_manager)
|
||||
final_res = chain.run(product="toys", callbacks=group_manager)
|
||||
group_manager.on_chain_end({"output": final_res})
|
||||
|
||||
with trace_as_chain_group("my_group_2") as group_manager:
|
||||
chain.run(product="toys", callbacks=group_manager)
|
||||
with trace_as_chain_group(
|
||||
"my_group_2_env_set", inputs={"input": "toys"}
|
||||
) as group_manager:
|
||||
final_res = chain.run(product="toys", callbacks=group_manager)
|
||||
group_manager.on_chain_end({"output": final_res})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -256,16 +264,19 @@ async def test_trace_as_group_async() -> None:
|
||||
template="What is a good name for a company that makes {product}?",
|
||||
)
|
||||
chain = LLMChain(llm=llm, prompt=prompt)
|
||||
async with atrace_as_chain_group("my_group") as group_manager:
|
||||
async with atrace_as_chain_group("my_async_group") as group_manager:
|
||||
await chain.arun(product="cars", callbacks=group_manager)
|
||||
await chain.arun(product="computers", callbacks=group_manager)
|
||||
await chain.arun(product="toys", callbacks=group_manager)
|
||||
|
||||
async with atrace_as_chain_group("my_group_2") as group_manager:
|
||||
await asyncio.gather(
|
||||
async with atrace_as_chain_group(
|
||||
"my_async_group_2", inputs={"input": "toys"}
|
||||
) as group_manager:
|
||||
res = await asyncio.gather(
|
||||
*[
|
||||
chain.arun(product="toys", callbacks=group_manager),
|
||||
chain.arun(product="computers", callbacks=group_manager),
|
||||
chain.arun(product="cars", callbacks=group_manager),
|
||||
]
|
||||
)
|
||||
await group_manager.on_chain_end({"output": res})
|
||||
|
Loading…
Reference in New Issue
Block a user