Set Context in RunnableSequence & RunnableParallel (#25073)

This commit is contained in:
William FH 2024-08-06 11:10:37 -07:00 committed by GitHub
parent 71c0698ee4
commit 267855b3c1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 234 additions and 37 deletions

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import asyncio import asyncio
import collections import collections
import functools
import inspect import inspect
import threading import threading
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
@ -68,8 +69,8 @@ from langchain_core.runnables.utils import (
Input, Input,
Output, Output,
accepts_config, accepts_config,
accepts_context,
accepts_run_manager, accepts_run_manager,
asyncio_accepts_context,
create_model, create_model,
gather_with_concurrency, gather_with_concurrency,
get_function_first_arg_dict_keys, get_function_first_arg_dict_keys,
@ -1830,7 +1831,7 @@ class Runnable(Generic[Input, Output], ABC):
coro = acall_func_with_variable_args( coro = acall_func_with_variable_args(
func, input, config, run_manager, **kwargs func, input, config, run_manager, **kwargs
) )
if accepts_context(asyncio.create_task): if asyncio_accepts_context():
output: Output = await asyncio.create_task(coro, context=context) # type: ignore output: Output = await asyncio.create_task(coro, context=context) # type: ignore
else: else:
output = await coro output = await coro
@ -2156,7 +2157,7 @@ class Runnable(Generic[Input, Output], ABC):
iterator = iterator_ iterator = iterator_
try: try:
while True: while True:
if accepts_context(asyncio.create_task): if asyncio_accepts_context():
chunk: Output = await asyncio.create_task( # type: ignore[call-arg] chunk: Output = await asyncio.create_task( # type: ignore[call-arg]
py_anext(iterator), # type: ignore[arg-type] py_anext(iterator), # type: ignore[arg-type]
context=context, context=context,
@ -2869,10 +2870,12 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
config = patch_config( config = patch_config(
config, callbacks=run_manager.get_child(f"seq:step:{i+1}") config, callbacks=run_manager.get_child(f"seq:step:{i+1}")
) )
context = copy_context()
context.run(_set_config_context, config)
if i == 0: if i == 0:
input = step.invoke(input, config, **kwargs) input = context.run(step.invoke, input, config, **kwargs)
else: else:
input = step.invoke(input, config) input = context.run(step.invoke, input, config)
# finish the root run # finish the root run
except BaseException as e: except BaseException as e:
run_manager.on_chain_error(e) run_manager.on_chain_error(e)
@ -2907,10 +2910,16 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
config = patch_config( config = patch_config(
config, callbacks=run_manager.get_child(f"seq:step:{i+1}") config, callbacks=run_manager.get_child(f"seq:step:{i+1}")
) )
context = copy_context()
context.run(_set_config_context, config)
if i == 0: if i == 0:
input = await step.ainvoke(input, config, **kwargs) part = functools.partial(step.ainvoke, input, config, **kwargs)
else: else:
input = await step.ainvoke(input, config) part = functools.partial(step.ainvoke, input, config)
if asyncio_accepts_context():
input = await asyncio.create_task(part(), context=context) # type: ignore
else:
input = await asyncio.create_task(part())
# finish the root run # finish the root run
except BaseException as e: except BaseException as e:
await run_manager.on_chain_error(e) await run_manager.on_chain_error(e)
@ -3542,21 +3551,30 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
run_id=config.pop("run_id", None), run_id=config.pop("run_id", None),
) )
def _invoke_step(
step: Runnable[Input, Any], input: Input, config: RunnableConfig, key: str
) -> Any:
child_config = patch_config(
config,
# mark each step as a child run
callbacks=run_manager.get_child(f"map:key:{key}"),
)
context = copy_context()
context.run(_set_config_context, child_config)
return context.run(
step.invoke,
input,
child_config,
)
# gather results from all steps # gather results from all steps
try: try:
# copy to avoid issues from the caller mutating the steps during invoke() # copy to avoid issues from the caller mutating the steps during invoke()
steps = dict(self.steps__) steps = dict(self.steps__)
with get_executor_for_config(config) as executor: with get_executor_for_config(config) as executor:
futures = [ futures = [
executor.submit( executor.submit(_invoke_step, step, input, config, key)
step.invoke,
input,
# mark each step as a child run
patch_config(
config,
callbacks=run_manager.get_child(f"map:key:{key}"),
),
)
for key, step in steps.items() for key, step in steps.items()
] ]
output = {key: future.result() for key, future in zip(steps, futures)} output = {key: future.result() for key, future in zip(steps, futures)}
@ -3585,18 +3603,34 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
run_id=config.pop("run_id", None), run_id=config.pop("run_id", None),
) )
async def _ainvoke_step(
step: Runnable[Input, Any], input: Input, config: RunnableConfig, key: str
) -> Any:
child_config = patch_config(
config,
callbacks=run_manager.get_child(f"map:key:{key}"),
)
context = copy_context()
context.run(_set_config_context, child_config)
if asyncio_accepts_context():
return await asyncio.create_task( # type: ignore
step.ainvoke(input, child_config), context=context
)
else:
return await asyncio.create_task(step.ainvoke(input, child_config))
# gather results from all steps # gather results from all steps
try: try:
# copy to avoid issues from the caller mutating the steps during invoke() # copy to avoid issues from the caller mutating the steps during invoke()
steps = dict(self.steps__) steps = dict(self.steps__)
results = await asyncio.gather( results = await asyncio.gather(
*( *(
step.ainvoke( _ainvoke_step(
step,
input, input,
# mark each step as a child run # mark each step as a child run
patch_config( config,
config, callbacks=run_manager.get_child(f"map:key:{key}") key,
),
) )
for key, step in steps.items() for key, step in steps.items()
) )

View File

@ -348,7 +348,6 @@ def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig:
base["callbacks"] = mngr base["callbacks"] = mngr
else: else:
# base_callbacks is also a manager # base_callbacks is also a manager
manager = base_callbacks.__class__( manager = base_callbacks.__class__(
parent_run_id=base_callbacks.parent_run_id parent_run_id=base_callbacks.parent_run_id
or these_callbacks.parent_run_id, or these_callbacks.parent_run_id,

View File

@ -1,12 +1,12 @@
import asyncio import asyncio
import inspect import inspect
import typing import typing
from contextvars import copy_context
from functools import wraps from functools import wraps
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
AsyncIterator, AsyncIterator,
Awaitable,
Dict, Dict,
Iterator, Iterator,
List, List,
@ -23,6 +23,7 @@ from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.base import Runnable, RunnableSerializable from langchain_core.runnables.base import Runnable, RunnableSerializable
from langchain_core.runnables.config import ( from langchain_core.runnables.config import (
RunnableConfig, RunnableConfig,
_set_config_context,
ensure_config, ensure_config,
get_async_callback_manager_for_config, get_async_callback_manager_for_config,
get_callback_manager_for_config, get_callback_manager_for_config,
@ -33,6 +34,7 @@ from langchain_core.runnables.utils import (
ConfigurableFieldSpec, ConfigurableFieldSpec,
Input, Input,
Output, Output,
asyncio_accepts_context,
get_unique_config_specs, get_unique_config_specs,
) )
from langchain_core.utils.aiter import py_anext from langchain_core.utils.aiter import py_anext
@ -172,9 +174,12 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
try: try:
if self.exception_key and last_error is not None: if self.exception_key and last_error is not None:
input[self.exception_key] = last_error input[self.exception_key] = last_error
output = runnable.invoke( child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context()
context.run(_set_config_context, child_config)
output = context.run(
runnable.invoke,
input, input,
patch_config(config, callbacks=run_manager.get_child()),
**kwargs, **kwargs,
) )
except self.exceptions_to_handle as e: except self.exceptions_to_handle as e:
@ -220,11 +225,14 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
try: try:
if self.exception_key and last_error is not None: if self.exception_key and last_error is not None:
input[self.exception_key] = last_error input[self.exception_key] = last_error
output = await runnable.ainvoke( child_config = patch_config(config, callbacks=run_manager.get_child())
input, context = copy_context()
patch_config(config, callbacks=run_manager.get_child()), context.run(_set_config_context, child_config)
**kwargs, coro = runnable.ainvoke(input, child_config, **kwargs)
) if asyncio_accepts_context():
output = await asyncio.create_task(coro, context=context) # type: ignore
else:
output = await coro
except self.exceptions_to_handle as e: except self.exceptions_to_handle as e:
if first_error is None: if first_error is None:
first_error = e first_error = e
@ -460,12 +468,15 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
try: try:
if self.exception_key and last_error is not None: if self.exception_key and last_error is not None:
input[self.exception_key] = last_error input[self.exception_key] = last_error
stream = runnable.stream( child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context()
context.run(_set_config_context, child_config)
stream = context.run(
runnable.stream,
input, input,
patch_config(config, callbacks=run_manager.get_child()),
**kwargs, **kwargs,
) )
chunk = next(stream) chunk: Output = context.run(next, stream) # type: ignore
except self.exceptions_to_handle as e: except self.exceptions_to_handle as e:
first_error = e if first_error is None else first_error first_error = e if first_error is None else first_error
last_error = e last_error = e
@ -520,12 +531,21 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
try: try:
if self.exception_key and last_error is not None: if self.exception_key and last_error is not None:
input[self.exception_key] = last_error input[self.exception_key] = last_error
child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context()
context.run(_set_config_context, child_config)
stream = runnable.astream( stream = runnable.astream(
input, input,
patch_config(config, callbacks=run_manager.get_child()), child_config,
**kwargs, **kwargs,
) )
chunk = await cast(Awaitable[Output], py_anext(stream)) if asyncio_accepts_context():
chunk: Output = await asyncio.create_task( # type: ignore[call-arg]
py_anext(stream), # type: ignore[arg-type]
context=context,
)
else:
chunk = cast(Output, await py_anext(stream))
except self.exceptions_to_handle as e: except self.exceptions_to_handle as e:
first_error = e if first_error is None else first_error first_error = e if first_error is None else first_error
last_error = e last_error = e

View File

@ -118,6 +118,11 @@ def accepts_context(callable: Callable[..., Any]) -> bool:
return False return False
@lru_cache(maxsize=1)
def asyncio_accepts_context() -> bool:
return accepts_context(asyncio.create_task)
class IsLocalDict(ast.NodeVisitor): class IsLocalDict(ast.NodeVisitor):
"""Check if a name is a local dict.""" """Check if a name is a local dict."""

View File

@ -87,7 +87,7 @@ from langchain_core.runnables.config import (
patch_config, patch_config,
run_in_executor, run_in_executor,
) )
from langchain_core.runnables.utils import accepts_context from langchain_core.runnables.utils import asyncio_accepts_context
from langchain_core.utils.function_calling import ( from langchain_core.utils.function_calling import (
_parse_google_docstring, _parse_google_docstring,
_py_38_safe_origin, _py_38_safe_origin,
@ -694,7 +694,7 @@ class ChildTool(BaseTool):
tool_kwargs[config_param] = config tool_kwargs[config_param] = config
coro = context.run(self._arun, *tool_args, **tool_kwargs) coro = context.run(self._arun, *tool_args, **tool_kwargs)
if accepts_context(asyncio.create_task): if asyncio_accepts_context():
response = await asyncio.create_task(coro, context=context) # type: ignore response = await asyncio.create_task(coro, context=context) # type: ignore
else: else:
response = await coro response = await coro

View File

@ -1,12 +1,13 @@
import json import json
import sys import sys
from typing import Any, AsyncGenerator, Generator
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest import pytest
from langsmith import Client, traceable from langsmith import Client, traceable
from langsmith.run_helpers import tracing_context from langsmith.run_helpers import tracing_context
from langchain_core.runnables.base import RunnableLambda from langchain_core.runnables.base import RunnableLambda, RunnableParallel
from langchain_core.tracers.langchain import LangChainTracer from langchain_core.tracers.langchain import LangChainTracer
@ -199,3 +200,141 @@ def test_tracing_enable_disable(
assert len(mock_posts) == 1 assert len(mock_posts) == 1
else: else:
assert not mock_posts assert not mock_posts
@pytest.mark.parametrize(
"method", ["invoke", "stream", "batch", "ainvoke", "astream", "abatch"]
)
async def test_runnable_sequence_parallel_trace_nesting(method: str) -> None:
if method.startswith("a") and sys.version_info < (3, 11):
pytest.skip("Asyncio context vars require Python 3.11+")
mock_session = MagicMock()
mock_client_ = Client(
session=mock_session, api_key="test", auto_batch_tracing=False
)
tracer = LangChainTracer(client=mock_client_)
@RunnableLambda
def my_child_function(a: int) -> int:
return a + 2
if method.startswith("a"):
async def other_thing(a: int) -> AsyncGenerator[int, None]:
yield 1
else:
def other_thing(a: int) -> Generator[int, None, None]: # type: ignore
yield 1
parallel = RunnableParallel(
chain_result=my_child_function.with_config(tags=["atag"]),
other_thing=other_thing,
)
def before(x: int) -> int:
return x
def after(x: dict) -> int:
return x["chain_result"]
sequence = before | parallel | after
if method.startswith("a"):
@RunnableLambda # type: ignore
async def parent(a: int) -> int:
return await sequence.ainvoke(a)
else:
@RunnableLambda
def parent(a: int) -> int:
return sequence.invoke(a)
# Now run the chain and check the resulting posts
cb = [tracer]
if method == "invoke":
res: Any = parent.invoke(1, {"callbacks": cb}) # type: ignore
elif method == "ainvoke":
res = await parent.ainvoke(1, {"callbacks": cb}) # type: ignore
elif method == "stream":
results = list(parent.stream(1, {"callbacks": cb})) # type: ignore
res = results[-1]
elif method == "astream":
results = [res async for res in parent.astream(1, {"callbacks": cb})] # type: ignore
res = results[-1]
elif method == "batch":
res = parent.batch([1], {"callbacks": cb})[0] # type: ignore
elif method == "abatch":
res = (await parent.abatch([1], {"callbacks": cb}))[0] # type: ignore
else:
raise ValueError(f"Unknown method {method}")
assert res == 3
posts = _get_posts(mock_client_)
name_order = [
"parent",
"RunnableSequence",
"before",
"RunnableParallel<chain_result,other_thing>",
["my_child_function", "other_thing"],
"after",
]
expected_parents = {
"parent": None,
"RunnableSequence": "parent",
"before": "RunnableSequence",
"RunnableParallel<chain_result,other_thing>": "RunnableSequence",
"my_child_function": "RunnableParallel<chain_result,other_thing>",
"other_thing": "RunnableParallel<chain_result,other_thing>",
"after": "RunnableSequence",
}
assert len(posts) == sum([1 if isinstance(n, str) else len(n) for n in name_order])
prev_dotted_order = None
dotted_order_map = {}
id_map = {}
parent_id_map = {}
i = 0
for name in name_order:
if isinstance(name, list):
for n in name:
matching_post = next(
p for p in posts[i : i + len(name)] if p["name"] == n
)
assert matching_post
dotted_order = matching_post["dotted_order"]
if prev_dotted_order is not None:
assert dotted_order > prev_dotted_order
dotted_order_map[n] = dotted_order
id_map[n] = matching_post["id"]
parent_id_map[n] = matching_post.get("parent_run_id")
i += len(name)
continue
else:
assert posts[i]["name"] == name
dotted_order = posts[i]["dotted_order"]
if prev_dotted_order is not None and not str(
expected_parents[name]
).startswith("RunnableParallel"):
assert (
dotted_order > prev_dotted_order
), f"{name} not after {name_order[i-1]}"
prev_dotted_order = dotted_order
if name in dotted_order_map:
raise ValueError(f"Duplicate name {name}")
dotted_order_map[name] = dotted_order
id_map[name] = posts[i]["id"]
parent_id_map[name] = posts[i].get("parent_run_id")
i += 1
# Now check the dotted orders
for name, parent_ in expected_parents.items():
dotted_order = dotted_order_map[name]
if parent_ is not None:
parent_dotted_order = dotted_order_map[parent_]
assert dotted_order.startswith(
parent_dotted_order
), f"{name}, {parent_dotted_order} not in {dotted_order}"
assert str(parent_id_map[name]) == str(id_map[parent_])
else:
assert dotted_order.split(".")[0] == dotted_order