Set Context in RunnableSequence & RunnableParallel (#25073)

This commit is contained in:
William FH 2024-08-06 11:10:37 -07:00 committed by GitHub
parent 71c0698ee4
commit 267855b3c1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 234 additions and 37 deletions

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import asyncio
import collections
import functools
import inspect
import threading
from abc import ABC, abstractmethod
@ -68,8 +69,8 @@ from langchain_core.runnables.utils import (
Input,
Output,
accepts_config,
accepts_context,
accepts_run_manager,
asyncio_accepts_context,
create_model,
gather_with_concurrency,
get_function_first_arg_dict_keys,
@ -1830,7 +1831,7 @@ class Runnable(Generic[Input, Output], ABC):
coro = acall_func_with_variable_args(
func, input, config, run_manager, **kwargs
)
if accepts_context(asyncio.create_task):
if asyncio_accepts_context():
output: Output = await asyncio.create_task(coro, context=context) # type: ignore
else:
output = await coro
@ -2156,7 +2157,7 @@ class Runnable(Generic[Input, Output], ABC):
iterator = iterator_
try:
while True:
if accepts_context(asyncio.create_task):
if asyncio_accepts_context():
chunk: Output = await asyncio.create_task( # type: ignore[call-arg]
py_anext(iterator), # type: ignore[arg-type]
context=context,
@ -2869,10 +2870,12 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
config = patch_config(
config, callbacks=run_manager.get_child(f"seq:step:{i+1}")
)
context = copy_context()
context.run(_set_config_context, config)
if i == 0:
input = step.invoke(input, config, **kwargs)
input = context.run(step.invoke, input, config, **kwargs)
else:
input = step.invoke(input, config)
input = context.run(step.invoke, input, config)
# finish the root run
except BaseException as e:
run_manager.on_chain_error(e)
@ -2907,10 +2910,16 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
config = patch_config(
config, callbacks=run_manager.get_child(f"seq:step:{i+1}")
)
context = copy_context()
context.run(_set_config_context, config)
if i == 0:
input = await step.ainvoke(input, config, **kwargs)
part = functools.partial(step.ainvoke, input, config, **kwargs)
else:
input = await step.ainvoke(input, config)
part = functools.partial(step.ainvoke, input, config)
if asyncio_accepts_context():
input = await asyncio.create_task(part(), context=context) # type: ignore
else:
input = await asyncio.create_task(part())
# finish the root run
except BaseException as e:
await run_manager.on_chain_error(e)
@ -3542,21 +3551,30 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
run_id=config.pop("run_id", None),
)
def _invoke_step(
    step: Runnable[Input, Any], input: Input, config: RunnableConfig, key: str
) -> Any:
    """Invoke one step of the parallel map synchronously in its own context.

    Args:
        step: The runnable to invoke for this key.
        input: The input handed to the step.
        config: The parent config; it is patched, not mutated, for the step.
        key: The step's key, used to tag the child callbacks.

    Returns:
        Whatever ``step.invoke`` returns.
    """
    # Mark the step as a child run by giving it child callbacks tagged
    # with its key.
    step_callbacks = run_manager.get_child(f"map:key:{key}")
    step_config = patch_config(config, callbacks=step_callbacks)

    # Apply the config inside a copied context and run the step in that same
    # copy, so the step sees the child config.
    step_context = copy_context()
    step_context.run(_set_config_context, step_config)
    return step_context.run(step.invoke, input, step_config)
# gather results from all steps
try:
# copy to avoid issues from the caller mutating the steps during invoke()
steps = dict(self.steps__)
with get_executor_for_config(config) as executor:
futures = [
executor.submit(
step.invoke,
input,
# mark each step as a child run
patch_config(
config,
callbacks=run_manager.get_child(f"map:key:{key}"),
),
)
executor.submit(_invoke_step, step, input, config, key)
for key, step in steps.items()
]
output = {key: future.result() for key, future in zip(steps, futures)}
@ -3585,18 +3603,34 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
run_id=config.pop("run_id", None),
)
async def _ainvoke_step(
    step: Runnable[Input, Any], input: Input, config: RunnableConfig, key: str
) -> Any:
    """Invoke one step of the parallel map asynchronously as its own task.

    Args:
        step: The runnable to invoke for this key.
        input: The input handed to the step.
        config: The parent config; it is patched, not mutated, for the step.
        key: The step's key, used to tag the child callbacks.

    Returns:
        Whatever ``step.ainvoke`` resolves to.
    """
    step_callbacks = run_manager.get_child(f"map:key:{key}")
    step_config = patch_config(config, callbacks=step_callbacks)

    # Prepare a copied context that has the child config applied to it.
    step_context = copy_context()
    step_context.run(_set_config_context, step_config)

    if not asyncio_accepts_context():
        # create_task() cannot be given a context here, so the prepared
        # context is not handed to the task; the step still receives
        # `step_config` explicitly as an argument.
        return await asyncio.create_task(step.ainvoke(input, step_config))

    task = asyncio.create_task(  # type: ignore
        step.ainvoke(input, step_config), context=step_context
    )
    return await task
# gather results from all steps
try:
# copy to avoid issues from the caller mutating the steps during invoke()
steps = dict(self.steps__)
results = await asyncio.gather(
*(
step.ainvoke(
_ainvoke_step(
step,
input,
# mark each step as a child run
patch_config(
config, callbacks=run_manager.get_child(f"map:key:{key}")
),
config,
key,
)
for key, step in steps.items()
)

View File

@ -348,7 +348,6 @@ def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig:
base["callbacks"] = mngr
else:
# base_callbacks is also a manager
manager = base_callbacks.__class__(
parent_run_id=base_callbacks.parent_run_id
or these_callbacks.parent_run_id,

View File

@ -1,12 +1,12 @@
import asyncio
import inspect
import typing
from contextvars import copy_context
from functools import wraps
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Awaitable,
Dict,
Iterator,
List,
@ -23,6 +23,7 @@ from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.base import Runnable, RunnableSerializable
from langchain_core.runnables.config import (
RunnableConfig,
_set_config_context,
ensure_config,
get_async_callback_manager_for_config,
get_callback_manager_for_config,
@ -33,6 +34,7 @@ from langchain_core.runnables.utils import (
ConfigurableFieldSpec,
Input,
Output,
asyncio_accepts_context,
get_unique_config_specs,
)
from langchain_core.utils.aiter import py_anext
@ -172,9 +174,12 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
try:
if self.exception_key and last_error is not None:
input[self.exception_key] = last_error
output = runnable.invoke(
child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context()
context.run(_set_config_context, child_config)
output = context.run(
runnable.invoke,
input,
patch_config(config, callbacks=run_manager.get_child()),
**kwargs,
)
except self.exceptions_to_handle as e:
@ -220,11 +225,14 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
try:
if self.exception_key and last_error is not None:
input[self.exception_key] = last_error
output = await runnable.ainvoke(
input,
patch_config(config, callbacks=run_manager.get_child()),
**kwargs,
)
child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context()
context.run(_set_config_context, child_config)
coro = runnable.ainvoke(input, child_config, **kwargs)
if asyncio_accepts_context():
output = await asyncio.create_task(coro, context=context) # type: ignore
else:
output = await coro
except self.exceptions_to_handle as e:
if first_error is None:
first_error = e
@ -460,12 +468,15 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
try:
if self.exception_key and last_error is not None:
input[self.exception_key] = last_error
stream = runnable.stream(
child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context()
context.run(_set_config_context, child_config)
stream = context.run(
runnable.stream,
input,
patch_config(config, callbacks=run_manager.get_child()),
**kwargs,
)
chunk = next(stream)
chunk: Output = context.run(next, stream) # type: ignore
except self.exceptions_to_handle as e:
first_error = e if first_error is None else first_error
last_error = e
@ -520,12 +531,21 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
try:
if self.exception_key and last_error is not None:
input[self.exception_key] = last_error
child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context()
context.run(_set_config_context, child_config)
stream = runnable.astream(
input,
patch_config(config, callbacks=run_manager.get_child()),
child_config,
**kwargs,
)
chunk = await cast(Awaitable[Output], py_anext(stream))
if asyncio_accepts_context():
chunk: Output = await asyncio.create_task( # type: ignore[call-arg]
py_anext(stream), # type: ignore[arg-type]
context=context,
)
else:
chunk = cast(Output, await py_anext(stream))
except self.exceptions_to_handle as e:
first_error = e if first_error is None else first_error
last_error = e

View File

@ -118,6 +118,11 @@ def accepts_context(callable: Callable[..., Any]) -> bool:
return False
@lru_cache(maxsize=1)
def asyncio_accepts_context() -> bool:
    """Return the result of ``accepts_context(asyncio.create_task)``.

    Callers use this to decide whether to pass ``context=`` to
    ``asyncio.create_task``. The function takes no arguments, so
    ``lru_cache`` stores the answer after the first call and later calls
    reuse it.
    """
    create_task_takes_context = accepts_context(asyncio.create_task)
    return create_task_takes_context
class IsLocalDict(ast.NodeVisitor):
"""Check if a name is a local dict."""

View File

@ -87,7 +87,7 @@ from langchain_core.runnables.config import (
patch_config,
run_in_executor,
)
from langchain_core.runnables.utils import accepts_context
from langchain_core.runnables.utils import asyncio_accepts_context
from langchain_core.utils.function_calling import (
_parse_google_docstring,
_py_38_safe_origin,
@ -694,7 +694,7 @@ class ChildTool(BaseTool):
tool_kwargs[config_param] = config
coro = context.run(self._arun, *tool_args, **tool_kwargs)
if accepts_context(asyncio.create_task):
if asyncio_accepts_context():
response = await asyncio.create_task(coro, context=context) # type: ignore
else:
response = await coro

View File

@ -1,12 +1,13 @@
import json
import sys
from typing import Any, AsyncGenerator, Generator
from unittest.mock import MagicMock, patch
import pytest
from langsmith import Client, traceable
from langsmith.run_helpers import tracing_context
from langchain_core.runnables.base import RunnableLambda
from langchain_core.runnables.base import RunnableLambda, RunnableParallel
from langchain_core.tracers.langchain import LangChainTracer
@ -199,3 +200,141 @@ def test_tracing_enable_disable(
assert len(mock_posts) == 1
else:
assert not mock_posts
@pytest.mark.parametrize(
    "method", ["invoke", "stream", "batch", "ainvoke", "astream", "abatch"]
)
async def test_runnable_sequence_parallel_trace_nesting(method: str) -> None:
    """Check run nesting for a sequence containing a parallel step.

    A ``parent`` lambda calls ``before | parallel | after`` without passing
    its config along. The test runs ``parent`` through the given ``method``
    and then verifies, from the posts captured by the mocked client, that:

    * every expected run was posted, in the expected order;
    * each run's ``dotted_order`` extends its expected parent's;
    * each run's ``parent_run_id`` matches the expected parent's ``id``.
    """
    if method.startswith("a") and sys.version_info < (3, 11):
        pytest.skip("Asyncio context vars require Python 3.11+")
    mock_session = MagicMock()
    mock_client_ = Client(
        session=mock_session, api_key="test", auto_batch_tracing=False
    )
    tracer = LangChainTracer(client=mock_client_)

    @RunnableLambda
    def my_child_function(a: int) -> int:
        return a + 2

    if method.startswith("a"):

        async def other_thing(a: int) -> AsyncGenerator[int, None]:
            yield 1

    else:

        def other_thing(a: int) -> Generator[int, None, None]:  # type: ignore
            yield 1

    parallel = RunnableParallel(
        chain_result=my_child_function.with_config(tags=["atag"]),
        other_thing=other_thing,
    )

    def before(x: int) -> int:
        return x

    def after(x: dict) -> int:
        return x["chain_result"]

    sequence = before | parallel | after
    if method.startswith("a"):

        @RunnableLambda  # type: ignore
        async def parent(a: int) -> int:
            return await sequence.ainvoke(a)

    else:

        @RunnableLambda
        def parent(a: int) -> int:
            return sequence.invoke(a)

    # Now run the chain and check the resulting posts
    cb = [tracer]
    if method == "invoke":
        res: Any = parent.invoke(1, {"callbacks": cb})  # type: ignore
    elif method == "ainvoke":
        res = await parent.ainvoke(1, {"callbacks": cb})  # type: ignore
    elif method == "stream":
        results = list(parent.stream(1, {"callbacks": cb}))  # type: ignore
        res = results[-1]
    elif method == "astream":
        results = [res async for res in parent.astream(1, {"callbacks": cb})]  # type: ignore
        res = results[-1]
    elif method == "batch":
        res = parent.batch([1], {"callbacks": cb})[0]  # type: ignore
    elif method == "abatch":
        res = (await parent.abatch([1], {"callbacks": cb}))[0]  # type: ignore
    else:
        raise ValueError(f"Unknown method {method}")
    assert res == 3
    posts = _get_posts(mock_client_)

    # Expected post order. A nested list groups runs that execute in
    # parallel, so their relative order within the group is not fixed.
    name_order = [
        "parent",
        "RunnableSequence",
        "before",
        "RunnableParallel<chain_result,other_thing>",
        ["my_child_function", "other_thing"],
        "after",
    ]
    expected_parents = {
        "parent": None,
        "RunnableSequence": "parent",
        "before": "RunnableSequence",
        "RunnableParallel<chain_result,other_thing>": "RunnableSequence",
        "my_child_function": "RunnableParallel<chain_result,other_thing>",
        "other_thing": "RunnableParallel<chain_result,other_thing>",
        "after": "RunnableSequence",
    }
    assert len(posts) == sum(
        1 if isinstance(n, str) else len(n) for n in name_order
    )
    prev_dotted_order = None
    # Name of the run that produced `prev_dotted_order`. Tracked separately
    # because `i` indexes `posts`, not `name_order`, and the two diverge once
    # a parallel group has been consumed.
    prev_name = None
    dotted_order_map = {}
    id_map = {}
    parent_id_map = {}
    i = 0
    for name in name_order:
        if isinstance(name, list):
            # Parallel group: look each member up anywhere in the next
            # len(name) posts.
            group_posts = posts[i : i + len(name)]
            for n in name:
                # Use a default so a missing post fails the assert below
                # instead of raising a bare StopIteration.
                matching_post = next(
                    (p for p in group_posts if p["name"] == n), None
                )
                assert matching_post, f"No post named {n} in parallel group"
                dotted_order = matching_post["dotted_order"]
                if prev_dotted_order is not None:
                    assert dotted_order > prev_dotted_order
                dotted_order_map[n] = dotted_order
                id_map[n] = matching_post["id"]
                parent_id_map[n] = matching_post.get("parent_run_id")
            i += len(name)
            continue

        assert posts[i]["name"] == name
        dotted_order = posts[i]["dotted_order"]
        if prev_dotted_order is not None and not str(
            expected_parents[name]
        ).startswith("RunnableParallel"):
            assert (
                dotted_order > prev_dotted_order
            ), f"{name} not after {prev_name}"
        prev_dotted_order = dotted_order
        prev_name = name
        if name in dotted_order_map:
            raise ValueError(f"Duplicate name {name}")
        dotted_order_map[name] = dotted_order
        id_map[name] = posts[i]["id"]
        parent_id_map[name] = posts[i].get("parent_run_id")
        i += 1

    # Now check the dotted orders
    for name, parent_ in expected_parents.items():
        dotted_order = dotted_order_map[name]
        if parent_ is not None:
            parent_dotted_order = dotted_order_map[parent_]
            assert dotted_order.startswith(
                parent_dotted_order
            ), f"{name}, {parent_dotted_order} not in {dotted_order}"
            assert str(parent_id_map[name]) == str(id_map[parent_])
        else:
            # A root run's dotted order has a single segment.
            assert dotted_order.split(".")[0] == dotted_order