Compare commits

...

2 Commits

Author SHA1 Message Date
William Fu-Hinthorn
f3168d4d70 Add support for invoking with a parent 2024-07-16 16:42:03 -07:00
William Fu-Hinthorn
6d59f2e069 Parent 2024-07-16 15:32:18 -07:00
7 changed files with 69 additions and 7 deletions

View File

@@ -1545,6 +1545,8 @@ class CallbackManager(BaseCallbackManager):
local_tags: Optional[List[str]] = None,
inheritable_metadata: Optional[Dict[str, Any]] = None,
local_metadata: Optional[Dict[str, Any]] = None,
*,
parent: Optional[str] = None,
) -> CallbackManager:
"""Configure the callback manager.
@@ -1562,6 +1564,8 @@ class CallbackManager(BaseCallbackManager):
metadata. Defaults to None.
local_metadata (Optional[Dict[str, Any]], optional): The local metadata.
Defaults to None.
parent (Optional[str], optional): The parent run's dotted order.
Defaults to None.
Returns:
CallbackManager: The configured callback manager.
@@ -1575,6 +1579,7 @@ class CallbackManager(BaseCallbackManager):
local_tags,
inheritable_metadata,
local_metadata,
parent=parent,
)
@@ -1972,6 +1977,8 @@ class AsyncCallbackManager(BaseCallbackManager):
local_tags: Optional[List[str]] = None,
inheritable_metadata: Optional[Dict[str, Any]] = None,
local_metadata: Optional[Dict[str, Any]] = None,
*,
parent: Optional[str] = None,
) -> AsyncCallbackManager:
"""Configure the async callback manager.
@@ -1989,6 +1996,8 @@ class AsyncCallbackManager(BaseCallbackManager):
metadata. Defaults to None.
local_metadata (Optional[Dict[str, Any]], optional): The local metadata.
Defaults to None.
parent (Optional[str], optional): The parent run's dotted order.
Defaults to None.
Returns:
AsyncCallbackManager: The configured async callback manager.
@@ -2002,6 +2011,7 @@ class AsyncCallbackManager(BaseCallbackManager):
local_tags,
inheritable_metadata,
local_metadata,
parent,
)
@@ -2092,6 +2102,7 @@ def _configure(
local_tags: Optional[List[str]] = None,
inheritable_metadata: Optional[Dict[str, Any]] = None,
local_metadata: Optional[Dict[str, Any]] = None,
parent: Optional[str] = None,
) -> T:
"""Configure the callback manager.
@@ -2113,6 +2124,8 @@ def _configure(
Returns:
T: The configured callback manager.
"""
from langsmith.run_trees import RunTree
from langchain_core.tracers.context import (
_configure_hooks,
_get_tracer_project,
@@ -2121,7 +2134,13 @@ def _configure(
)
run_tree = get_run_tree_context()
parent_run_id = None if run_tree is None else run_tree.id
if run_tree is not None:
parent_run_id = None if run_tree is None else run_tree.id
elif parent is not None:
run_tree = RunTree.from_dotted_order(parent)
parent_run_id = run_tree.id
else:
parent_run_id = None
callback_manager = callback_manager_cls(handlers=[], parent_run_id=parent_run_id)
if inheritable_callbacks or local_callbacks:
if isinstance(inheritable_callbacks, list) or inheritable_callbacks is None:

View File

@@ -201,6 +201,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
local_tags=self.tags,
inheritable_metadata=config.get("metadata"),
local_metadata=self.metadata,
parent=config.get("parent"),
)
run_manager = callback_manager.on_retriever_start(
dumpd(self),
@@ -260,6 +261,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
local_tags=self.tags,
inheritable_metadata=config.get("metadata"),
local_metadata=self.metadata,
parent=config.get("parent"),
)
run_manager = await callback_manager.on_retriever_start(
dumpd(self),

View File

@@ -2896,6 +2896,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
parent=config.get("parent"),
)
for config in configs
]
@@ -3022,6 +3023,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
parent=config.get("parent"),
)
for config in configs
]
@@ -3482,6 +3484,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
parent=config.get("parent"),
)
# start the root run
run_manager = callback_manager.on_chain_start(

View File

@@ -104,6 +104,11 @@ class RunnableConfig(TypedDict, total=False):
will be generated.
"""
parent: Optional[str]
"""
The parent dotted order in the trace. If not provided, the parent will be inferred
from the tracing context."""
var_child_runnable_config = ContextVar(
"child_runnable_config", default=RunnableConfig()
@@ -144,6 +149,7 @@ def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
metadata={},
callbacks=None,
recursion_limit=25,
parent=None,
)
if var_config := var_child_runnable_config.get():
empty.update(
@@ -435,6 +441,7 @@ def get_callback_manager_for_config(config: RunnableConfig) -> CallbackManager:
inheritable_callbacks=config.get("callbacks"),
inheritable_tags=config.get("tags"),
inheritable_metadata=config.get("metadata"),
parent=config.get("parent"),
)
@@ -455,6 +462,7 @@ def get_async_callback_manager_for_config(
inheritable_callbacks=config.get("callbacks"),
inheritable_tags=config.get("tags"),
inheritable_metadata=config.get("metadata"),
parent=config.get("parent"),
)

View File

@@ -272,6 +272,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
parent=config.get("parent"),
)
for config in configs
]
@@ -364,6 +365,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
parent=config.get("parent"),
)
for config in configs
]

View File

@@ -576,6 +576,7 @@ class ChildTool(BaseTool):
self.tags,
metadata,
self.metadata,
parent=config.get("parent"),
)
run_manager = callback_manager.on_tool_start(
@@ -683,6 +684,7 @@ class ChildTool(BaseTool):
self.tags,
metadata,
self.metadata,
parent=config.get("parent"),
)
run_manager = await callback_manager.on_tool_start(
{"name": self.name, "description": self.description},

View File

@@ -1,5 +1,6 @@
import json
import sys
from typing import Optional
from unittest.mock import MagicMock, patch
import pytest
@@ -29,7 +30,15 @@ def _get_posts(client: Client) -> list:
return posts
def test_config_traceable_handoff() -> None:
@pytest.mark.parametrize(
"dotted_order",
[
None,
"20240716T222505213101Z6bc70600-21d8-41d0-b54b-175f80b02130.20240716T222505315288Z180de013-f114-439a-86e2-ac15d1393f5a",
"20240716T222505213101Z6bc70600-21d8-41d0-b54b-175f80b02130",
],
)
def test_config_traceable_handoff(dotted_order: Optional[str]) -> None:
mock_session = MagicMock()
mock_client_ = Client(
session=mock_session, api_key="test", auto_batch_tracing=False
@@ -61,7 +70,10 @@ def test_config_traceable_handoff() -> None:
my_parent_runnable = RunnableLambda(my_parent_function)
assert my_parent_runnable.invoke(1, {"callbacks": [tracer]}) == 6
assert (
my_parent_runnable.invoke(1, {"callbacks": [tracer], "parent": dotted_order})
== 6
)
posts = _get_posts(mock_client_)
# There should have been 6 runs created,
# one for each function invocation
@@ -78,10 +90,12 @@ def test_config_traceable_handoff() -> None:
trace_id = posts[0]["trace_id"]
last_dotted_order = None
parent_run_id = None
if dotted_order is not None:
parent_run_id = dotted_order.split(".")[-1].split("Z")[-1]
for name in ordered_names:
id_ = name_to_body[name]["id"]
parent_run_id_ = name_to_body[name]["parent_run_id"]
if parent_run_id_ is not None:
if parent_run_id is not None:
assert parent_run_id == parent_run_id_
assert name in name_to_body
# All within the same trace
@@ -101,7 +115,15 @@ def test_config_traceable_handoff() -> None:
@pytest.mark.skipif(
sys.version_info < (3, 11), reason="Asyncio context vars require Python 3.11+"
)
async def test_config_traceable_async_handoff() -> None:
@pytest.mark.parametrize(
"dotted_order",
[
None,
"20240716T222505213101Z6bc70600-21d8-41d0-b54b-175f80b02130.20240716T222505315288Z180de013-f114-439a-86e2-ac15d1393f5a",
"20240716T222505213101Z6bc70600-21d8-41d0-b54b-175f80b02130",
],
)
async def test_config_traceable_async_handoff(dotted_order: Optional[str]) -> None:
mock_session = MagicMock()
mock_client_ = Client(
session=mock_session, api_key="test", auto_batch_tracing=False
@@ -132,7 +154,9 @@ async def test_config_traceable_async_handoff() -> None:
return await my_function(a)
my_parent_runnable = RunnableLambda(my_parent_function) # type: ignore
result = await my_parent_runnable.ainvoke(1, {"callbacks": [tracer]})
result = await my_parent_runnable.ainvoke(
1, {"callbacks": [tracer], "parent": dotted_order}
)
assert result == 6
posts = _get_posts(mock_client_)
# There should have been 6 runs created,
@@ -150,10 +174,12 @@ async def test_config_traceable_async_handoff() -> None:
trace_id = posts[0]["trace_id"]
last_dotted_order = None
parent_run_id = None
if dotted_order is not None:
parent_run_id = dotted_order.split(".")[-1].split("Z")[-1]
for name in ordered_names:
id_ = name_to_body[name]["id"]
parent_run_id_ = name_to_body[name]["parent_run_id"]
if parent_run_id_ is not None:
if parent_run_id is not None:
assert parent_run_id == parent_run_id_
assert name in name_to_body
# All within the same trace