Set Context in RunnableSequence & RunnableParallel (#25073)

This commit is contained in:
William FH
2024-08-06 11:10:37 -07:00
committed by GitHub
parent 71c0698ee4
commit 267855b3c1
6 changed files with 234 additions and 37 deletions

View File

@@ -1,12 +1,13 @@
import json
import sys
from typing import Any, AsyncGenerator, Generator
from unittest.mock import MagicMock, patch
import pytest
from langsmith import Client, traceable
from langsmith.run_helpers import tracing_context
from langchain_core.runnables.base import RunnableLambda
from langchain_core.runnables.base import RunnableLambda, RunnableParallel
from langchain_core.tracers.langchain import LangChainTracer
@@ -199,3 +200,141 @@ def test_tracing_enable_disable(
assert len(mock_posts) == 1
else:
assert not mock_posts
@pytest.mark.parametrize(
"method", ["invoke", "stream", "batch", "ainvoke", "astream", "abatch"]
)
async def test_runnable_sequence_parallel_trace_nesting(method: str) -> None:
if method.startswith("a") and sys.version_info < (3, 11):
pytest.skip("Asyncio context vars require Python 3.11+")
mock_session = MagicMock()
mock_client_ = Client(
session=mock_session, api_key="test", auto_batch_tracing=False
)
tracer = LangChainTracer(client=mock_client_)
@RunnableLambda
def my_child_function(a: int) -> int:
return a + 2
if method.startswith("a"):
async def other_thing(a: int) -> AsyncGenerator[int, None]:
yield 1
else:
def other_thing(a: int) -> Generator[int, None, None]: # type: ignore
yield 1
parallel = RunnableParallel(
chain_result=my_child_function.with_config(tags=["atag"]),
other_thing=other_thing,
)
def before(x: int) -> int:
return x
def after(x: dict) -> int:
return x["chain_result"]
sequence = before | parallel | after
if method.startswith("a"):
@RunnableLambda # type: ignore
async def parent(a: int) -> int:
return await sequence.ainvoke(a)
else:
@RunnableLambda
def parent(a: int) -> int:
return sequence.invoke(a)
# Now run the chain and check the resulting posts
cb = [tracer]
if method == "invoke":
res: Any = parent.invoke(1, {"callbacks": cb}) # type: ignore
elif method == "ainvoke":
res = await parent.ainvoke(1, {"callbacks": cb}) # type: ignore
elif method == "stream":
results = list(parent.stream(1, {"callbacks": cb})) # type: ignore
res = results[-1]
elif method == "astream":
results = [res async for res in parent.astream(1, {"callbacks": cb})] # type: ignore
res = results[-1]
elif method == "batch":
res = parent.batch([1], {"callbacks": cb})[0] # type: ignore
elif method == "abatch":
res = (await parent.abatch([1], {"callbacks": cb}))[0] # type: ignore
else:
raise ValueError(f"Unknown method {method}")
assert res == 3
posts = _get_posts(mock_client_)
name_order = [
"parent",
"RunnableSequence",
"before",
"RunnableParallel<chain_result,other_thing>",
["my_child_function", "other_thing"],
"after",
]
expected_parents = {
"parent": None,
"RunnableSequence": "parent",
"before": "RunnableSequence",
"RunnableParallel<chain_result,other_thing>": "RunnableSequence",
"my_child_function": "RunnableParallel<chain_result,other_thing>",
"other_thing": "RunnableParallel<chain_result,other_thing>",
"after": "RunnableSequence",
}
assert len(posts) == sum([1 if isinstance(n, str) else len(n) for n in name_order])
prev_dotted_order = None
dotted_order_map = {}
id_map = {}
parent_id_map = {}
i = 0
for name in name_order:
if isinstance(name, list):
for n in name:
matching_post = next(
p for p in posts[i : i + len(name)] if p["name"] == n
)
assert matching_post
dotted_order = matching_post["dotted_order"]
if prev_dotted_order is not None:
assert dotted_order > prev_dotted_order
dotted_order_map[n] = dotted_order
id_map[n] = matching_post["id"]
parent_id_map[n] = matching_post.get("parent_run_id")
i += len(name)
continue
else:
assert posts[i]["name"] == name
dotted_order = posts[i]["dotted_order"]
if prev_dotted_order is not None and not str(
expected_parents[name]
).startswith("RunnableParallel"):
assert (
dotted_order > prev_dotted_order
), f"{name} not after {name_order[i-1]}"
prev_dotted_order = dotted_order
if name in dotted_order_map:
raise ValueError(f"Duplicate name {name}")
dotted_order_map[name] = dotted_order
id_map[name] = posts[i]["id"]
parent_id_map[name] = posts[i].get("parent_run_id")
i += 1
# Now check the dotted orders
for name, parent_ in expected_parents.items():
dotted_order = dotted_order_map[name]
if parent_ is not None:
parent_dotted_order = dotted_order_map[parent_]
assert dotted_order.startswith(
parent_dotted_order
), f"{name}, {parent_dotted_order} not in {dotted_order}"
assert str(parent_id_map[name]) == str(id_map[parent_])
else:
assert dotted_order.split(".")[0] == dotted_order