From 9fa07076da34b16ceecccee0f92dfe7c67201d76 Mon Sep 17 00:00:00 2001 From: William FH <13333726+hinthornw@users.noreply.github.com> Date: Wed, 7 Feb 2024 09:42:44 -0800 Subject: [PATCH] Add trace_as_chain_group metadata (#17187) --- libs/core/langchain_core/callbacks/manager.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/libs/core/langchain_core/callbacks/manager.py b/libs/core/langchain_core/callbacks/manager.py index e216f479832..b1f103871f1 100644 --- a/libs/core/langchain_core/callbacks/manager.py +++ b/libs/core/langchain_core/callbacks/manager.py @@ -67,6 +67,7 @@ def trace_as_chain_group( example_id: Optional[Union[str, UUID]] = None, run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, ) -> Generator[CallbackManagerForChainGroup, None, None]: """Get a callback manager for a chain group in a context manager. Useful for grouping different calls together as a single run even if @@ -83,6 +84,8 @@ def trace_as_chain_group( run_id (UUID, optional): The ID of the run. tags (List[str], optional): The inheritable tags to apply to all runs. Defaults to None. + metadata (Dict[str, Any], optional): The metadata to apply to all runs. + Defaults to None. Note: must have LANGCHAIN_TRACING_V2 env var set to true to see the trace in LangSmith. @@ -95,7 +98,7 @@ def trace_as_chain_group( llm_input = "Foo" with trace_as_chain_group("group_name", inputs={"input": llm_input}) as manager: # Use the callback manager for the chain group - res = llm.predict(llm_input, callbacks=manager) + res = llm.invoke(llm_input, {"callbacks": manager}) manager.on_chain_end({"output": res}) """ # noqa: E501 from langchain_core.tracers.context import _get_trace_callbacks @@ -106,6 +109,7 @@ def trace_as_chain_group( cm = CallbackManager.configure( inheritable_callbacks=cb, inheritable_tags=tags, + inheritable_metadata=metadata, ) run_manager = cm.on_chain_start({"name": group_name}, inputs or {}, run_id=run_id) @@ -141,6 +145,7 @@ async def atrace_as_chain_group( example_id: Optional[Union[str, UUID]] = None, run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, ) -> AsyncGenerator[AsyncCallbackManagerForChainGroup, None]: """Get an async callback manager for a chain group in a context manager. Useful for grouping different async calls together as a single run even if @@ -157,6 +162,8 @@ async def atrace_as_chain_group( run_id (UUID, optional): The ID of the run. tags (List[str], optional): The inheritable tags to apply to all runs. Defaults to None. + metadata (Dict[str, Any], optional): The metadata to apply to all runs. + Defaults to None. Returns: AsyncCallbackManager: The async callback manager for the chain group. @@ -168,7 +175,7 @@ async def atrace_as_chain_group( llm_input = "Foo" async with atrace_as_chain_group("group_name", inputs={"input": llm_input}) as manager: # Use the async callback manager for the chain group - res = await llm.apredict(llm_input, callbacks=manager) + res = await llm.ainvoke(llm_input, {"callbacks": manager}) await manager.on_chain_end({"output": res}) """ # noqa: E501 from langchain_core.tracers.context import _get_trace_callbacks @@ -176,7 +183,9 @@ async def atrace_as_chain_group( cb = _get_trace_callbacks( project_name, example_id, callback_manager=callback_manager ) - cm = AsyncCallbackManager.configure(inheritable_callbacks=cb, inheritable_tags=tags) + cm = AsyncCallbackManager.configure( + inheritable_callbacks=cb, inheritable_tags=tags, inheritable_metadata=metadata + ) run_manager = await cm.on_chain_start( {"name": group_name}, inputs or {}, run_id=run_id