mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-06 05:08:20 +00:00
Set Context in RunnableSequence & RunnableParallel (#25073)
This commit is contained in:
parent
71c0698ee4
commit
267855b3c1
@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import collections
|
||||
import functools
|
||||
import inspect
|
||||
import threading
|
||||
from abc import ABC, abstractmethod
|
||||
@ -68,8 +69,8 @@ from langchain_core.runnables.utils import (
|
||||
Input,
|
||||
Output,
|
||||
accepts_config,
|
||||
accepts_context,
|
||||
accepts_run_manager,
|
||||
asyncio_accepts_context,
|
||||
create_model,
|
||||
gather_with_concurrency,
|
||||
get_function_first_arg_dict_keys,
|
||||
@ -1830,7 +1831,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
coro = acall_func_with_variable_args(
|
||||
func, input, config, run_manager, **kwargs
|
||||
)
|
||||
if accepts_context(asyncio.create_task):
|
||||
if asyncio_accepts_context():
|
||||
output: Output = await asyncio.create_task(coro, context=context) # type: ignore
|
||||
else:
|
||||
output = await coro
|
||||
@ -2156,7 +2157,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
iterator = iterator_
|
||||
try:
|
||||
while True:
|
||||
if accepts_context(asyncio.create_task):
|
||||
if asyncio_accepts_context():
|
||||
chunk: Output = await asyncio.create_task( # type: ignore[call-arg]
|
||||
py_anext(iterator), # type: ignore[arg-type]
|
||||
context=context,
|
||||
@ -2869,10 +2870,12 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
config = patch_config(
|
||||
config, callbacks=run_manager.get_child(f"seq:step:{i+1}")
|
||||
)
|
||||
context = copy_context()
|
||||
context.run(_set_config_context, config)
|
||||
if i == 0:
|
||||
input = step.invoke(input, config, **kwargs)
|
||||
input = context.run(step.invoke, input, config, **kwargs)
|
||||
else:
|
||||
input = step.invoke(input, config)
|
||||
input = context.run(step.invoke, input, config)
|
||||
# finish the root run
|
||||
except BaseException as e:
|
||||
run_manager.on_chain_error(e)
|
||||
@ -2907,10 +2910,16 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
config = patch_config(
|
||||
config, callbacks=run_manager.get_child(f"seq:step:{i+1}")
|
||||
)
|
||||
context = copy_context()
|
||||
context.run(_set_config_context, config)
|
||||
if i == 0:
|
||||
input = await step.ainvoke(input, config, **kwargs)
|
||||
part = functools.partial(step.ainvoke, input, config, **kwargs)
|
||||
else:
|
||||
input = await step.ainvoke(input, config)
|
||||
part = functools.partial(step.ainvoke, input, config)
|
||||
if asyncio_accepts_context():
|
||||
input = await asyncio.create_task(part(), context=context) # type: ignore
|
||||
else:
|
||||
input = await asyncio.create_task(part())
|
||||
# finish the root run
|
||||
except BaseException as e:
|
||||
await run_manager.on_chain_error(e)
|
||||
@ -3542,21 +3551,30 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
|
||||
run_id=config.pop("run_id", None),
|
||||
)
|
||||
|
||||
def _invoke_step(
|
||||
step: Runnable[Input, Any], input: Input, config: RunnableConfig, key: str
|
||||
) -> Any:
|
||||
child_config = patch_config(
|
||||
config,
|
||||
# mark each step as a child run
|
||||
callbacks=run_manager.get_child(f"map:key:{key}"),
|
||||
)
|
||||
context = copy_context()
|
||||
context.run(_set_config_context, child_config)
|
||||
return context.run(
|
||||
step.invoke,
|
||||
input,
|
||||
child_config,
|
||||
)
|
||||
|
||||
# gather results from all steps
|
||||
try:
|
||||
# copy to avoid issues from the caller mutating the steps during invoke()
|
||||
steps = dict(self.steps__)
|
||||
|
||||
with get_executor_for_config(config) as executor:
|
||||
futures = [
|
||||
executor.submit(
|
||||
step.invoke,
|
||||
input,
|
||||
# mark each step as a child run
|
||||
patch_config(
|
||||
config,
|
||||
callbacks=run_manager.get_child(f"map:key:{key}"),
|
||||
),
|
||||
)
|
||||
executor.submit(_invoke_step, step, input, config, key)
|
||||
for key, step in steps.items()
|
||||
]
|
||||
output = {key: future.result() for key, future in zip(steps, futures)}
|
||||
@ -3585,18 +3603,34 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
|
||||
run_id=config.pop("run_id", None),
|
||||
)
|
||||
|
||||
async def _ainvoke_step(
|
||||
step: Runnable[Input, Any], input: Input, config: RunnableConfig, key: str
|
||||
) -> Any:
|
||||
child_config = patch_config(
|
||||
config,
|
||||
callbacks=run_manager.get_child(f"map:key:{key}"),
|
||||
)
|
||||
context = copy_context()
|
||||
context.run(_set_config_context, child_config)
|
||||
if asyncio_accepts_context():
|
||||
return await asyncio.create_task( # type: ignore
|
||||
step.ainvoke(input, child_config), context=context
|
||||
)
|
||||
else:
|
||||
return await asyncio.create_task(step.ainvoke(input, child_config))
|
||||
|
||||
# gather results from all steps
|
||||
try:
|
||||
# copy to avoid issues from the caller mutating the steps during invoke()
|
||||
steps = dict(self.steps__)
|
||||
results = await asyncio.gather(
|
||||
*(
|
||||
step.ainvoke(
|
||||
_ainvoke_step(
|
||||
step,
|
||||
input,
|
||||
# mark each step as a child run
|
||||
patch_config(
|
||||
config, callbacks=run_manager.get_child(f"map:key:{key}")
|
||||
),
|
||||
config,
|
||||
key,
|
||||
)
|
||||
for key, step in steps.items()
|
||||
)
|
||||
|
@ -348,7 +348,6 @@ def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig:
|
||||
base["callbacks"] = mngr
|
||||
else:
|
||||
# base_callbacks is also a manager
|
||||
|
||||
manager = base_callbacks.__class__(
|
||||
parent_run_id=base_callbacks.parent_run_id
|
||||
or these_callbacks.parent_run_id,
|
||||
|
@ -1,12 +1,12 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
import typing
|
||||
from contextvars import copy_context
|
||||
from functools import wraps
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Awaitable,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
@ -23,6 +23,7 @@ from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.runnables.base import Runnable, RunnableSerializable
|
||||
from langchain_core.runnables.config import (
|
||||
RunnableConfig,
|
||||
_set_config_context,
|
||||
ensure_config,
|
||||
get_async_callback_manager_for_config,
|
||||
get_callback_manager_for_config,
|
||||
@ -33,6 +34,7 @@ from langchain_core.runnables.utils import (
|
||||
ConfigurableFieldSpec,
|
||||
Input,
|
||||
Output,
|
||||
asyncio_accepts_context,
|
||||
get_unique_config_specs,
|
||||
)
|
||||
from langchain_core.utils.aiter import py_anext
|
||||
@ -172,9 +174,12 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
try:
|
||||
if self.exception_key and last_error is not None:
|
||||
input[self.exception_key] = last_error
|
||||
output = runnable.invoke(
|
||||
child_config = patch_config(config, callbacks=run_manager.get_child())
|
||||
context = copy_context()
|
||||
context.run(_set_config_context, child_config)
|
||||
output = context.run(
|
||||
runnable.invoke,
|
||||
input,
|
||||
patch_config(config, callbacks=run_manager.get_child()),
|
||||
**kwargs,
|
||||
)
|
||||
except self.exceptions_to_handle as e:
|
||||
@ -220,11 +225,14 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
try:
|
||||
if self.exception_key and last_error is not None:
|
||||
input[self.exception_key] = last_error
|
||||
output = await runnable.ainvoke(
|
||||
input,
|
||||
patch_config(config, callbacks=run_manager.get_child()),
|
||||
**kwargs,
|
||||
)
|
||||
child_config = patch_config(config, callbacks=run_manager.get_child())
|
||||
context = copy_context()
|
||||
context.run(_set_config_context, child_config)
|
||||
coro = runnable.ainvoke(input, child_config, **kwargs)
|
||||
if asyncio_accepts_context():
|
||||
output = await asyncio.create_task(coro, context=context) # type: ignore
|
||||
else:
|
||||
output = await coro
|
||||
except self.exceptions_to_handle as e:
|
||||
if first_error is None:
|
||||
first_error = e
|
||||
@ -460,12 +468,15 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
try:
|
||||
if self.exception_key and last_error is not None:
|
||||
input[self.exception_key] = last_error
|
||||
stream = runnable.stream(
|
||||
child_config = patch_config(config, callbacks=run_manager.get_child())
|
||||
context = copy_context()
|
||||
context.run(_set_config_context, child_config)
|
||||
stream = context.run(
|
||||
runnable.stream,
|
||||
input,
|
||||
patch_config(config, callbacks=run_manager.get_child()),
|
||||
**kwargs,
|
||||
)
|
||||
chunk = next(stream)
|
||||
chunk: Output = context.run(next, stream) # type: ignore
|
||||
except self.exceptions_to_handle as e:
|
||||
first_error = e if first_error is None else first_error
|
||||
last_error = e
|
||||
@ -520,12 +531,21 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
try:
|
||||
if self.exception_key and last_error is not None:
|
||||
input[self.exception_key] = last_error
|
||||
child_config = patch_config(config, callbacks=run_manager.get_child())
|
||||
context = copy_context()
|
||||
context.run(_set_config_context, child_config)
|
||||
stream = runnable.astream(
|
||||
input,
|
||||
patch_config(config, callbacks=run_manager.get_child()),
|
||||
child_config,
|
||||
**kwargs,
|
||||
)
|
||||
chunk = await cast(Awaitable[Output], py_anext(stream))
|
||||
if asyncio_accepts_context():
|
||||
chunk: Output = await asyncio.create_task( # type: ignore[call-arg]
|
||||
py_anext(stream), # type: ignore[arg-type]
|
||||
context=context,
|
||||
)
|
||||
else:
|
||||
chunk = cast(Output, await py_anext(stream))
|
||||
except self.exceptions_to_handle as e:
|
||||
first_error = e if first_error is None else first_error
|
||||
last_error = e
|
||||
|
@ -118,6 +118,11 @@ def accepts_context(callable: Callable[..., Any]) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def asyncio_accepts_context() -> bool:
|
||||
return accepts_context(asyncio.create_task)
|
||||
|
||||
|
||||
class IsLocalDict(ast.NodeVisitor):
|
||||
"""Check if a name is a local dict."""
|
||||
|
||||
|
@ -87,7 +87,7 @@ from langchain_core.runnables.config import (
|
||||
patch_config,
|
||||
run_in_executor,
|
||||
)
|
||||
from langchain_core.runnables.utils import accepts_context
|
||||
from langchain_core.runnables.utils import asyncio_accepts_context
|
||||
from langchain_core.utils.function_calling import (
|
||||
_parse_google_docstring,
|
||||
_py_38_safe_origin,
|
||||
@ -694,7 +694,7 @@ class ChildTool(BaseTool):
|
||||
tool_kwargs[config_param] = config
|
||||
|
||||
coro = context.run(self._arun, *tool_args, **tool_kwargs)
|
||||
if accepts_context(asyncio.create_task):
|
||||
if asyncio_accepts_context():
|
||||
response = await asyncio.create_task(coro, context=context) # type: ignore
|
||||
else:
|
||||
response = await coro
|
||||
|
@ -1,12 +1,13 @@
|
||||
import json
|
||||
import sys
|
||||
from typing import Any, AsyncGenerator, Generator
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from langsmith import Client, traceable
|
||||
from langsmith.run_helpers import tracing_context
|
||||
|
||||
from langchain_core.runnables.base import RunnableLambda
|
||||
from langchain_core.runnables.base import RunnableLambda, RunnableParallel
|
||||
from langchain_core.tracers.langchain import LangChainTracer
|
||||
|
||||
|
||||
@ -199,3 +200,141 @@ def test_tracing_enable_disable(
|
||||
assert len(mock_posts) == 1
|
||||
else:
|
||||
assert not mock_posts
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"method", ["invoke", "stream", "batch", "ainvoke", "astream", "abatch"]
|
||||
)
|
||||
async def test_runnable_sequence_parallel_trace_nesting(method: str) -> None:
|
||||
if method.startswith("a") and sys.version_info < (3, 11):
|
||||
pytest.skip("Asyncio context vars require Python 3.11+")
|
||||
mock_session = MagicMock()
|
||||
mock_client_ = Client(
|
||||
session=mock_session, api_key="test", auto_batch_tracing=False
|
||||
)
|
||||
tracer = LangChainTracer(client=mock_client_)
|
||||
|
||||
@RunnableLambda
|
||||
def my_child_function(a: int) -> int:
|
||||
return a + 2
|
||||
|
||||
if method.startswith("a"):
|
||||
|
||||
async def other_thing(a: int) -> AsyncGenerator[int, None]:
|
||||
yield 1
|
||||
|
||||
else:
|
||||
|
||||
def other_thing(a: int) -> Generator[int, None, None]: # type: ignore
|
||||
yield 1
|
||||
|
||||
parallel = RunnableParallel(
|
||||
chain_result=my_child_function.with_config(tags=["atag"]),
|
||||
other_thing=other_thing,
|
||||
)
|
||||
|
||||
def before(x: int) -> int:
|
||||
return x
|
||||
|
||||
def after(x: dict) -> int:
|
||||
return x["chain_result"]
|
||||
|
||||
sequence = before | parallel | after
|
||||
if method.startswith("a"):
|
||||
|
||||
@RunnableLambda # type: ignore
|
||||
async def parent(a: int) -> int:
|
||||
return await sequence.ainvoke(a)
|
||||
|
||||
else:
|
||||
|
||||
@RunnableLambda
|
||||
def parent(a: int) -> int:
|
||||
return sequence.invoke(a)
|
||||
|
||||
# Now run the chain and check the resulting posts
|
||||
cb = [tracer]
|
||||
if method == "invoke":
|
||||
res: Any = parent.invoke(1, {"callbacks": cb}) # type: ignore
|
||||
elif method == "ainvoke":
|
||||
res = await parent.ainvoke(1, {"callbacks": cb}) # type: ignore
|
||||
elif method == "stream":
|
||||
results = list(parent.stream(1, {"callbacks": cb})) # type: ignore
|
||||
res = results[-1]
|
||||
elif method == "astream":
|
||||
results = [res async for res in parent.astream(1, {"callbacks": cb})] # type: ignore
|
||||
res = results[-1]
|
||||
elif method == "batch":
|
||||
res = parent.batch([1], {"callbacks": cb})[0] # type: ignore
|
||||
elif method == "abatch":
|
||||
res = (await parent.abatch([1], {"callbacks": cb}))[0] # type: ignore
|
||||
else:
|
||||
raise ValueError(f"Unknown method {method}")
|
||||
assert res == 3
|
||||
posts = _get_posts(mock_client_)
|
||||
name_order = [
|
||||
"parent",
|
||||
"RunnableSequence",
|
||||
"before",
|
||||
"RunnableParallel<chain_result,other_thing>",
|
||||
["my_child_function", "other_thing"],
|
||||
"after",
|
||||
]
|
||||
expected_parents = {
|
||||
"parent": None,
|
||||
"RunnableSequence": "parent",
|
||||
"before": "RunnableSequence",
|
||||
"RunnableParallel<chain_result,other_thing>": "RunnableSequence",
|
||||
"my_child_function": "RunnableParallel<chain_result,other_thing>",
|
||||
"other_thing": "RunnableParallel<chain_result,other_thing>",
|
||||
"after": "RunnableSequence",
|
||||
}
|
||||
assert len(posts) == sum([1 if isinstance(n, str) else len(n) for n in name_order])
|
||||
prev_dotted_order = None
|
||||
dotted_order_map = {}
|
||||
id_map = {}
|
||||
parent_id_map = {}
|
||||
i = 0
|
||||
for name in name_order:
|
||||
if isinstance(name, list):
|
||||
for n in name:
|
||||
matching_post = next(
|
||||
p for p in posts[i : i + len(name)] if p["name"] == n
|
||||
)
|
||||
assert matching_post
|
||||
dotted_order = matching_post["dotted_order"]
|
||||
if prev_dotted_order is not None:
|
||||
assert dotted_order > prev_dotted_order
|
||||
dotted_order_map[n] = dotted_order
|
||||
id_map[n] = matching_post["id"]
|
||||
parent_id_map[n] = matching_post.get("parent_run_id")
|
||||
i += len(name)
|
||||
continue
|
||||
else:
|
||||
assert posts[i]["name"] == name
|
||||
dotted_order = posts[i]["dotted_order"]
|
||||
if prev_dotted_order is not None and not str(
|
||||
expected_parents[name]
|
||||
).startswith("RunnableParallel"):
|
||||
assert (
|
||||
dotted_order > prev_dotted_order
|
||||
), f"{name} not after {name_order[i-1]}"
|
||||
prev_dotted_order = dotted_order
|
||||
if name in dotted_order_map:
|
||||
raise ValueError(f"Duplicate name {name}")
|
||||
dotted_order_map[name] = dotted_order
|
||||
id_map[name] = posts[i]["id"]
|
||||
parent_id_map[name] = posts[i].get("parent_run_id")
|
||||
i += 1
|
||||
|
||||
# Now check the dotted orders
|
||||
for name, parent_ in expected_parents.items():
|
||||
dotted_order = dotted_order_map[name]
|
||||
if parent_ is not None:
|
||||
parent_dotted_order = dotted_order_map[parent_]
|
||||
assert dotted_order.startswith(
|
||||
parent_dotted_order
|
||||
), f"{name}, {parent_dotted_order} not in {dotted_order}"
|
||||
assert str(parent_id_map[name]) == str(id_map[parent_])
|
||||
else:
|
||||
assert dotted_order.split(".")[0] == dotted_order
|
||||
|
Loading…
Reference in New Issue
Block a user