mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-15 22:44:36 +00:00
Set Context in RunnableSequence & RunnableParallel (#25073)
This commit is contained in:
@@ -1,12 +1,13 @@
|
||||
import json
|
||||
import sys
|
||||
from typing import Any, AsyncGenerator, Generator
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from langsmith import Client, traceable
|
||||
from langsmith.run_helpers import tracing_context
|
||||
|
||||
from langchain_core.runnables.base import RunnableLambda
|
||||
from langchain_core.runnables.base import RunnableLambda, RunnableParallel
|
||||
from langchain_core.tracers.langchain import LangChainTracer
|
||||
|
||||
|
||||
@@ -199,3 +200,141 @@ def test_tracing_enable_disable(
|
||||
assert len(mock_posts) == 1
|
||||
else:
|
||||
assert not mock_posts
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"method", ["invoke", "stream", "batch", "ainvoke", "astream", "abatch"]
|
||||
)
|
||||
async def test_runnable_sequence_parallel_trace_nesting(method: str) -> None:
|
||||
if method.startswith("a") and sys.version_info < (3, 11):
|
||||
pytest.skip("Asyncio context vars require Python 3.11+")
|
||||
mock_session = MagicMock()
|
||||
mock_client_ = Client(
|
||||
session=mock_session, api_key="test", auto_batch_tracing=False
|
||||
)
|
||||
tracer = LangChainTracer(client=mock_client_)
|
||||
|
||||
@RunnableLambda
|
||||
def my_child_function(a: int) -> int:
|
||||
return a + 2
|
||||
|
||||
if method.startswith("a"):
|
||||
|
||||
async def other_thing(a: int) -> AsyncGenerator[int, None]:
|
||||
yield 1
|
||||
|
||||
else:
|
||||
|
||||
def other_thing(a: int) -> Generator[int, None, None]: # type: ignore
|
||||
yield 1
|
||||
|
||||
parallel = RunnableParallel(
|
||||
chain_result=my_child_function.with_config(tags=["atag"]),
|
||||
other_thing=other_thing,
|
||||
)
|
||||
|
||||
def before(x: int) -> int:
|
||||
return x
|
||||
|
||||
def after(x: dict) -> int:
|
||||
return x["chain_result"]
|
||||
|
||||
sequence = before | parallel | after
|
||||
if method.startswith("a"):
|
||||
|
||||
@RunnableLambda # type: ignore
|
||||
async def parent(a: int) -> int:
|
||||
return await sequence.ainvoke(a)
|
||||
|
||||
else:
|
||||
|
||||
@RunnableLambda
|
||||
def parent(a: int) -> int:
|
||||
return sequence.invoke(a)
|
||||
|
||||
# Now run the chain and check the resulting posts
|
||||
cb = [tracer]
|
||||
if method == "invoke":
|
||||
res: Any = parent.invoke(1, {"callbacks": cb}) # type: ignore
|
||||
elif method == "ainvoke":
|
||||
res = await parent.ainvoke(1, {"callbacks": cb}) # type: ignore
|
||||
elif method == "stream":
|
||||
results = list(parent.stream(1, {"callbacks": cb})) # type: ignore
|
||||
res = results[-1]
|
||||
elif method == "astream":
|
||||
results = [res async for res in parent.astream(1, {"callbacks": cb})] # type: ignore
|
||||
res = results[-1]
|
||||
elif method == "batch":
|
||||
res = parent.batch([1], {"callbacks": cb})[0] # type: ignore
|
||||
elif method == "abatch":
|
||||
res = (await parent.abatch([1], {"callbacks": cb}))[0] # type: ignore
|
||||
else:
|
||||
raise ValueError(f"Unknown method {method}")
|
||||
assert res == 3
|
||||
posts = _get_posts(mock_client_)
|
||||
name_order = [
|
||||
"parent",
|
||||
"RunnableSequence",
|
||||
"before",
|
||||
"RunnableParallel<chain_result,other_thing>",
|
||||
["my_child_function", "other_thing"],
|
||||
"after",
|
||||
]
|
||||
expected_parents = {
|
||||
"parent": None,
|
||||
"RunnableSequence": "parent",
|
||||
"before": "RunnableSequence",
|
||||
"RunnableParallel<chain_result,other_thing>": "RunnableSequence",
|
||||
"my_child_function": "RunnableParallel<chain_result,other_thing>",
|
||||
"other_thing": "RunnableParallel<chain_result,other_thing>",
|
||||
"after": "RunnableSequence",
|
||||
}
|
||||
assert len(posts) == sum([1 if isinstance(n, str) else len(n) for n in name_order])
|
||||
prev_dotted_order = None
|
||||
dotted_order_map = {}
|
||||
id_map = {}
|
||||
parent_id_map = {}
|
||||
i = 0
|
||||
for name in name_order:
|
||||
if isinstance(name, list):
|
||||
for n in name:
|
||||
matching_post = next(
|
||||
p for p in posts[i : i + len(name)] if p["name"] == n
|
||||
)
|
||||
assert matching_post
|
||||
dotted_order = matching_post["dotted_order"]
|
||||
if prev_dotted_order is not None:
|
||||
assert dotted_order > prev_dotted_order
|
||||
dotted_order_map[n] = dotted_order
|
||||
id_map[n] = matching_post["id"]
|
||||
parent_id_map[n] = matching_post.get("parent_run_id")
|
||||
i += len(name)
|
||||
continue
|
||||
else:
|
||||
assert posts[i]["name"] == name
|
||||
dotted_order = posts[i]["dotted_order"]
|
||||
if prev_dotted_order is not None and not str(
|
||||
expected_parents[name]
|
||||
).startswith("RunnableParallel"):
|
||||
assert (
|
||||
dotted_order > prev_dotted_order
|
||||
), f"{name} not after {name_order[i-1]}"
|
||||
prev_dotted_order = dotted_order
|
||||
if name in dotted_order_map:
|
||||
raise ValueError(f"Duplicate name {name}")
|
||||
dotted_order_map[name] = dotted_order
|
||||
id_map[name] = posts[i]["id"]
|
||||
parent_id_map[name] = posts[i].get("parent_run_id")
|
||||
i += 1
|
||||
|
||||
# Now check the dotted orders
|
||||
for name, parent_ in expected_parents.items():
|
||||
dotted_order = dotted_order_map[name]
|
||||
if parent_ is not None:
|
||||
parent_dotted_order = dotted_order_map[parent_]
|
||||
assert dotted_order.startswith(
|
||||
parent_dotted_order
|
||||
), f"{name}, {parent_dotted_order} not in {dotted_order}"
|
||||
assert str(parent_id_map[name]) == str(id_map[parent_])
|
||||
else:
|
||||
assert dotted_order.split(".")[0] == dotted_order
|
||||
|
Reference in New Issue
Block a user