Add trace_as_chain_group metadata (#17187)

This commit is contained in:
William FH 2024-02-07 09:42:44 -08:00 committed by GitHub
parent 5ceaf784f3
commit 9fa07076da
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -67,6 +67,7 @@ def trace_as_chain_group(
example_id: Optional[Union[str, UUID]] = None, example_id: Optional[Union[str, UUID]] = None,
run_id: Optional[UUID] = None, run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None, tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> Generator[CallbackManagerForChainGroup, None, None]: ) -> Generator[CallbackManagerForChainGroup, None, None]:
"""Get a callback manager for a chain group in a context manager. """Get a callback manager for a chain group in a context manager.
Useful for grouping different calls together as a single run even if Useful for grouping different calls together as a single run even if
@ -83,6 +84,8 @@ def trace_as_chain_group(
run_id (UUID, optional): The ID of the run. run_id (UUID, optional): The ID of the run.
tags (List[str], optional): The inheritable tags to apply to all runs. tags (List[str], optional): The inheritable tags to apply to all runs.
Defaults to None. Defaults to None.
metadata (Dict[str, Any], optional): The metadata to apply to all runs.
Defaults to None.
Note: must have LANGCHAIN_TRACING_V2 env var set to true to see the trace in LangSmith. Note: must have LANGCHAIN_TRACING_V2 env var set to true to see the trace in LangSmith.
@ -95,7 +98,7 @@ def trace_as_chain_group(
llm_input = "Foo" llm_input = "Foo"
with trace_as_chain_group("group_name", inputs={"input": llm_input}) as manager: with trace_as_chain_group("group_name", inputs={"input": llm_input}) as manager:
# Use the callback manager for the chain group # Use the callback manager for the chain group
res = llm.predict(llm_input, callbacks=manager) res = llm.invoke(llm_input, {"callbacks": manager})
manager.on_chain_end({"output": res}) manager.on_chain_end({"output": res})
""" # noqa: E501 """ # noqa: E501
from langchain_core.tracers.context import _get_trace_callbacks from langchain_core.tracers.context import _get_trace_callbacks
@ -106,6 +109,7 @@ def trace_as_chain_group(
cm = CallbackManager.configure( cm = CallbackManager.configure(
inheritable_callbacks=cb, inheritable_callbacks=cb,
inheritable_tags=tags, inheritable_tags=tags,
inheritable_metadata=metadata,
) )
run_manager = cm.on_chain_start({"name": group_name}, inputs or {}, run_id=run_id) run_manager = cm.on_chain_start({"name": group_name}, inputs or {}, run_id=run_id)
@ -141,6 +145,7 @@ async def atrace_as_chain_group(
example_id: Optional[Union[str, UUID]] = None, example_id: Optional[Union[str, UUID]] = None,
run_id: Optional[UUID] = None, run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None, tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> AsyncGenerator[AsyncCallbackManagerForChainGroup, None]: ) -> AsyncGenerator[AsyncCallbackManagerForChainGroup, None]:
"""Get an async callback manager for a chain group in a context manager. """Get an async callback manager for a chain group in a context manager.
Useful for grouping different async calls together as a single run even if Useful for grouping different async calls together as a single run even if
@ -157,6 +162,8 @@ async def atrace_as_chain_group(
run_id (UUID, optional): The ID of the run. run_id (UUID, optional): The ID of the run.
tags (List[str], optional): The inheritable tags to apply to all runs. tags (List[str], optional): The inheritable tags to apply to all runs.
Defaults to None. Defaults to None.
metadata (Dict[str, Any], optional): The metadata to apply to all runs.
Defaults to None.
Returns: Returns:
AsyncCallbackManager: The async callback manager for the chain group. AsyncCallbackManager: The async callback manager for the chain group.
@ -168,7 +175,7 @@ async def atrace_as_chain_group(
llm_input = "Foo" llm_input = "Foo"
async with atrace_as_chain_group("group_name", inputs={"input": llm_input}) as manager: async with atrace_as_chain_group("group_name", inputs={"input": llm_input}) as manager:
# Use the async callback manager for the chain group # Use the async callback manager for the chain group
res = await llm.apredict(llm_input, callbacks=manager) res = await llm.ainvoke(llm_input, {"callbacks": manager})
await manager.on_chain_end({"output": res}) await manager.on_chain_end({"output": res})
""" # noqa: E501 """ # noqa: E501
from langchain_core.tracers.context import _get_trace_callbacks from langchain_core.tracers.context import _get_trace_callbacks
@ -176,7 +183,9 @@ async def atrace_as_chain_group(
cb = _get_trace_callbacks( cb = _get_trace_callbacks(
project_name, example_id, callback_manager=callback_manager project_name, example_id, callback_manager=callback_manager
) )
cm = AsyncCallbackManager.configure(inheritable_callbacks=cb, inheritable_tags=tags) cm = AsyncCallbackManager.configure(
inheritable_callbacks=cb, inheritable_tags=tags, inheritable_metadata=metadata
)
run_manager = await cm.on_chain_start( run_manager = await cm.on_chain_start(
{"name": group_name}, inputs or {}, run_id=run_id {"name": group_name}, inputs or {}, run_id=run_id