Fetch runnable config from context var inside runnable lambda and runnable generator (#15334)

- makes it easier to write custom logic/loops with automatic tracing
- if you don't want streaming support, write a regular function and pass
it to `RunnableLambda`
- if you do want streaming, write a generator and pass it to
`RunnableGenerator`

```py
import json
from typing import AsyncIterator

from langchain_core.messages import BaseMessage, FunctionMessage, HumanMessage
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import Runnable, RunnableGenerator, RunnablePassthrough
from langchain_core.tools import BaseTool

from langchain.agents.output_parsers import OpenAIFunctionsAgentOutputParser
from langchain.chat_models import ChatOpenAI
from langchain.tools.render import format_tool_to_openai_function


def _get_tavily():
    from langchain.tools.tavily_search import TavilySearchResults
    from langchain.utilities.tavily_search import TavilySearchAPIWrapper

    tavily_search = TavilySearchAPIWrapper()
    return TavilySearchResults(api_wrapper=tavily_search)


async def _agent_executor_generator(
    input: AsyncIterator[list[BaseMessage]],
    *,
    max_iterations: int = 10,
    tools: dict[str, BaseTool],
    agent: Runnable[list[BaseMessage], BaseMessage],
    parser: Runnable[BaseMessage, AgentAction | AgentFinish],
) -> AsyncIterator[BaseMessage]:
    messages = [m async for mm in input for m in mm]
    for _ in range(max_iterations):
        next_message = await agent.ainvoke(messages)
        yield next_message
        messages.append(next_message)

        parsed = await parser.ainvoke(next_message)
        if isinstance(parsed, AgentAction):
            result = await tools[parsed.tool].ainvoke(parsed.tool_input)
            next_message = FunctionMessage(name=parsed.tool, content=json.dumps(result))
            yield next_message
            messages.append(next_message)
        elif isinstance(parsed, AgentFinish):
            return


def get_agent_executor(tools: list[BaseTool], system_message: str):
    llm = ChatOpenAI(model="gpt-4-1106-preview", temperature=0, streaming=True)
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_message),
            MessagesPlaceholder(variable_name="messages"),
        ]
    )
    llm_with_tools = llm.bind(
        functions=[format_tool_to_openai_function(t) for t in tools]
    )

    agent = {"messages": RunnablePassthrough()} | prompt | llm_with_tools
    parser = OpenAIFunctionsAgentOutputParser()
    executor = RunnableGenerator(_agent_executor_generator)
    return executor.bind(
        tools={tool.name: tool for tool in tools}, agent=agent, parser=parser
    )


agent = get_agent_executor([_get_tavily()], "You are a very nice agent!")


async def main():
    async for message in agent.astream(
        [HumanMessage(content="whats the weather in sf tomorrow?")]
    ):
        print(message)


if __name__ == "__main__":
    import asyncio

    asyncio.run(main())
```

Running the example produces this trace:
https://smith.langchain.com/public/fa17f05d-9724-4d08-8fa1-750f8fcd051b/r
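
The same machinery covers the non-streaming bullet above: wrap a plain function in `RunnableLambda`, or use the `chain` decorator this commit adds, and any runnables invoked inside are traced as child runs, because they now read their config from the context var. A minimal sketch (mirroring `test_runnable_lambda_context_config` in the diff below):

```py
from langchain_core.runnables import RunnableLambda, chain

inner = RunnableLambda(len)


# A regular function: single return value, no streaming.
# `inner` is traced as a child run automatically.
def total_length(x: str) -> int:
    return inner.invoke(x) + inner.invoke(x * 2)


runnable = RunnableLambda(total_length)
assert runnable.invoke("a") == 3


# Equivalent, using the decorator added in this commit:
@chain
def total_length_2(x: str) -> int:
    return inner.invoke(x) + inner.invoke(x * 2)


assert total_length_2.invoke("a") == 3
```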
Commit 9cbf14dec2 by Nuno Campos, committed via GitHub on 2024-01-02 12:16:39 -08:00 (parent 8e0d5813c2).
7 changed files with 656 additions and 94 deletions

libs/core/langchain_core/runnables/__init__.py:

```diff
@@ -23,6 +23,7 @@ from langchain_core.runnables.base import (
     RunnableParallel,
     RunnableSequence,
     RunnableSerializable,
+    chain,
 )
 from langchain_core.runnables.branch import RunnableBranch
 from langchain_core.runnables.config import (
@@ -50,6 +51,7 @@ from langchain_core.runnables.utils import (
 )
 
 __all__ = [
+    "chain",
     "AddableDict",
     "ConfigurableField",
     "ConfigurableFieldSingleOption",
```

libs/core/langchain_core/runnables/base.py:

```diff
@@ -1,10 +1,12 @@
 from __future__ import annotations
 
 import asyncio
+import collections
 import inspect
 import threading
 from abc import ABC, abstractmethod
 from concurrent.futures import FIRST_COMPLETED, wait
+from contextvars import copy_context
 from copy import deepcopy
 from functools import wraps
 from itertools import groupby, tee
@@ -15,6 +17,7 @@ from typing import (
     AsyncIterator,
     Awaitable,
     Callable,
+    Coroutine,
     Dict,
     Generic,
     Iterator,
@@ -48,6 +51,7 @@ from langchain_core.runnables.config import (
     merge_configs,
     patch_config,
     run_in_executor,
+    var_child_runnable_config,
 )
 from langchain_core.runnables.graph import Graph
 from langchain_core.runnables.utils import (
@@ -58,6 +62,7 @@ from langchain_core.runnables.utils import (
     Input,
     Output,
     accepts_config,
+    accepts_context,
     accepts_run_manager,
     gather_with_concurrency,
     get_function_first_arg_dict_keys,
@@ -950,8 +955,19 @@ class Runnable(Generic[Input, Output], ABC):
             name=config.get("run_name") or self.get_name(),
         )
         try:
-            output = call_func_with_variable_args(
-                func, input, config, run_manager, **kwargs
+            child_config = patch_config(config, callbacks=run_manager.get_child())
+            context = copy_context()
+            context.run(var_child_runnable_config.set, child_config)
+            output = cast(
+                Output,
+                context.run(
+                    call_func_with_variable_args,
+                    func,  # type: ignore[arg-type]
+                    input,  # type: ignore[arg-type]
+                    config,
+                    run_manager,
+                    **kwargs,
+                ),
             )
         except BaseException as e:
             run_manager.on_chain_error(e)
@@ -986,9 +1002,16 @@ class Runnable(Generic[Input, Output], ABC):
             name=config.get("run_name") or self.get_name(),
         )
         try:
-            output = await acall_func_with_variable_args(
+            child_config = patch_config(config, callbacks=run_manager.get_child())
+            context = copy_context()
+            context.run(var_child_runnable_config.set, child_config)
+            coro = acall_func_with_variable_args(
                 func, input, config, run_manager, **kwargs
             )
+            if accepts_context(asyncio.create_task):
+                output: Output = await asyncio.create_task(coro, context=context)  # type: ignore
+            else:
+                output = await coro
         except BaseException as e:
             await run_manager.on_chain_error(e)
             raise
@@ -1178,24 +1201,29 @@ class Runnable(Generic[Input, Output], ABC):
             name=config.get("run_name") or self.get_name(),
         )
         try:
+            child_config = patch_config(config, callbacks=run_manager.get_child())
             if accepts_config(transformer):
-                kwargs["config"] = patch_config(
-                    config, callbacks=run_manager.get_child()
-                )
+                kwargs["config"] = child_config
             if accepts_run_manager(transformer):
                 kwargs["run_manager"] = run_manager
-            iterator = transformer(input_for_transform, **kwargs)  # type: ignore[call-arg]
-            for chunk in iterator:
-                yield chunk
-                if final_output_supported:
-                    if final_output is None:
-                        final_output = chunk
-                    else:
-                        try:
-                            final_output = final_output + chunk  # type: ignore
-                        except TypeError:
-                            final_output = None
-                            final_output_supported = False
+            context = copy_context()
+            context.run(var_child_runnable_config.set, child_config)
+            iterator = context.run(transformer, input_for_transform, **kwargs)  # type: ignore[arg-type]
+            try:
+                while True:
+                    chunk: Output = context.run(next, iterator)  # type: ignore
+                    yield chunk
+                    if final_output_supported:
+                        if final_output is None:
+                            final_output = chunk
+                        else:
+                            try:
+                                final_output = final_output + chunk  # type: ignore
+                            except TypeError:
+                                final_output = None
+                                final_output_supported = False
+            except StopIteration:
+                pass
             for ichunk in input_for_tracing:
                 if final_input_supported:
                     if final_input is None:
@@ -1254,24 +1282,35 @@ class Runnable(Generic[Input, Output], ABC):
             name=config.get("run_name") or self.get_name(),
         )
         try:
+            child_config = patch_config(config, callbacks=run_manager.get_child())
             if accepts_config(transformer):
-                kwargs["config"] = patch_config(
-                    config, callbacks=run_manager.get_child()
-                )
+                kwargs["config"] = child_config
             if accepts_run_manager(transformer):
                 kwargs["run_manager"] = run_manager
-            iterator = transformer(input_for_transform, **kwargs)  # type: ignore[call-arg]
-            async for chunk in iterator:
-                yield chunk
-                if final_output_supported:
-                    if final_output is None:
-                        final_output = chunk
-                    else:
-                        try:
-                            final_output = final_output + chunk  # type: ignore
-                        except TypeError:
-                            final_output = None
-                            final_output_supported = False
+            context = copy_context()
+            context.run(var_child_runnable_config.set, child_config)
+            iterator = context.run(transformer, input_for_transform, **kwargs)  # type: ignore[arg-type]
+            try:
+                while True:
+                    if accepts_context(asyncio.create_task):
+                        chunk: Output = await asyncio.create_task(  # type: ignore[call-arg]
+                            py_anext(iterator),  # type: ignore[arg-type]
+                            context=context,
+                        )
+                    else:
+                        chunk = cast(Output, await py_anext(iterator))
+                    yield chunk
+                    if final_output_supported:
+                        if final_output is None:
+                            final_output = chunk
+                        else:
+                            try:
+                                final_output = final_output + chunk  # type: ignore
+                            except TypeError:
+                                final_output = None
+                                final_output_supported = False
+            except StopAsyncIteration:
+                pass
             async for ichunk in input_for_tracing:
                 if final_input_supported:
                     if final_input is None:
@@ -1472,7 +1511,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
         .. code-block:: python
 
             from langchain_core.output_parsers.json import SimpleJsonOutputParser
-            from langchain_core.chat_models.openai import ChatOpenAI
+            from langchain.chat_models.openai import ChatOpenAI
 
             prompt = PromptTemplate.from_template(
                 'In JSON format, give me a list of {topic} and their '
@@ -2482,17 +2521,25 @@ class RunnableGenerator(Runnable[Input, Output]):
     ) -> None:
         if atransform is not None:
             self._atransform = atransform
+            func_for_name: Callable = atransform
 
         if inspect.isasyncgenfunction(transform):
            self._atransform = transform
+            func_for_name = transform
         elif inspect.isgeneratorfunction(transform):
             self._transform = transform
+            func_for_name = transform
         else:
             raise TypeError(
                 "Expected a generator function type for `transform`."
                 f"Instead got an unsupported type: {type(transform)}"
             )
 
+        try:
+            self.name = func_for_name.__name__
+        except AttributeError:
+            pass
+
     @property
     def InputType(self) -> Any:
         func = getattr(self, "_transform", None) or getattr(self, "_atransform")
@@ -2646,12 +2693,14 @@ class RunnableLambda(Runnable[Input, Output]):
         func: Union[
             Union[
                 Callable[[Input], Output],
+                Callable[[Input], Iterator[Output]],
                 Callable[[Input, RunnableConfig], Output],
                 Callable[[Input, CallbackManagerForChainRun], Output],
                 Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output],
             ],
             Union[
                 Callable[[Input], Awaitable[Output]],
+                Callable[[Input], AsyncIterator[Output]],
                 Callable[[Input, RunnableConfig], Awaitable[Output]],
                 Callable[[Input, AsyncCallbackManagerForChainRun], Awaitable[Output]],
                 Callable[
@@ -2663,6 +2712,7 @@ class RunnableLambda(Runnable[Input, Output]):
         afunc: Optional[
             Union[
                 Callable[[Input], Awaitable[Output]],
+                Callable[[Input], AsyncIterator[Output]],
                 Callable[[Input, RunnableConfig], Awaitable[Output]],
                 Callable[[Input, AsyncCallbackManagerForChainRun], Awaitable[Output]],
                 Callable[
@@ -2685,7 +2735,7 @@ class RunnableLambda(Runnable[Input, Output]):
             self.afunc = afunc
             func_for_name: Callable = afunc
 
-        if inspect.iscoroutinefunction(func):
+        if inspect.iscoroutinefunction(func) or inspect.isasyncgenfunction(func):
             if afunc is not None:
                 raise TypeError(
                     "Func was provided as a coroutine function, but afunc was "
@@ -2767,11 +2817,16 @@ class RunnableLambda(Runnable[Input, Output]):
         func = getattr(self, "func", None) or getattr(self, "afunc")
         try:
             sig = inspect.signature(func)
-            return (
-                sig.return_annotation
-                if sig.return_annotation != inspect.Signature.empty
-                else Any
-            )
+            if sig.return_annotation != inspect.Signature.empty:
+                # unwrap iterator types
+                if getattr(sig.return_annotation, "__origin__", None) in (
+                    collections.abc.Iterator,
+                    collections.abc.AsyncIterator,
+                ):
+                    return getattr(sig.return_annotation, "__args__", (Any,))[0]
+                return sig.return_annotation
+            else:
+                return Any
         except ValueError:
             return Any
@@ -2848,9 +2903,26 @@ class RunnableLambda(Runnable[Input, Output]):
         config: RunnableConfig,
         **kwargs: Any,
     ) -> Output:
-        output = call_func_with_variable_args(
-            self.func, input, config, run_manager, **kwargs
-        )
+        if inspect.isgeneratorfunction(self.func):
+            output: Optional[Output] = None
+            for chunk in call_func_with_variable_args(
+                cast(Callable[[Input], Iterator[Output]], self.func),
+                input,
+                config,
+                run_manager,
+                **kwargs,
+            ):
+                if output is None:
+                    output = chunk
+                else:
+                    try:
+                        output = output + chunk  # type: ignore[operator]
+                    except TypeError:
+                        output = chunk
+        else:
+            output = call_func_with_variable_args(
+                self.func, input, config, run_manager, **kwargs
+            )
         # If the output is a runnable, invoke it
         if isinstance(output, Runnable):
             recursion_limit = config["recursion_limit"]
@@ -2866,7 +2938,7 @@ class RunnableLambda(Runnable[Input, Output]):
                     recursion_limit=recursion_limit - 1,
                 ),
             )
-        return output
+        return cast(Output, output)
 
     async def _ainvoke(
         self,
@@ -2878,16 +2950,69 @@ class RunnableLambda(Runnable[Input, Output]):
         if hasattr(self, "afunc"):
             afunc = self.afunc
         else:
+            if inspect.isgeneratorfunction(self.func):
 
-            @wraps(self.func)
+                def func(
+                    input: Input,
+                    run_manager: AsyncCallbackManagerForChainRun,
+                    config: RunnableConfig,
+                ) -> Output:
+                    output: Optional[Output] = None
+                    for chunk in call_func_with_variable_args(
+                        cast(Callable[[Input], Iterator[Output]], self.func),
+                        input,
+                        config,
+                        run_manager.get_sync(),
+                        **kwargs,
+                    ):
+                        if output is None:
+                            output = chunk
+                        else:
+                            try:
+                                output = output + chunk  # type: ignore[operator]
+                            except TypeError:
+                                output = chunk
+                    return cast(Output, output)
+
+            else:
+
+                def func(
+                    input: Input,
+                    run_manager: AsyncCallbackManagerForChainRun,
+                    config: RunnableConfig,
+                ) -> Output:
+                    return call_func_with_variable_args(
+                        self.func, input, config, run_manager.get_sync(), **kwargs
+                    )
+
+            @wraps(func)
             async def f(*args, **kwargs):  # type: ignore[no-untyped-def]
-                return await run_in_executor(config, self.func, *args, **kwargs)
+                return await run_in_executor(config, func, *args, **kwargs)
 
             afunc = f
 
-        output = await acall_func_with_variable_args(
-            afunc, input, config, run_manager, **kwargs
-        )
+        if inspect.isasyncgenfunction(afunc):
+            output: Optional[Output] = None
+            async for chunk in cast(
+                AsyncIterator[Output],
+                acall_func_with_variable_args(
+                    cast(Callable, afunc),
+                    input,
+                    config,
+                    run_manager,
+                    **kwargs,
+                ),
+            ):
+                if output is None:
+                    output = chunk
+                else:
+                    try:
+                        output = output + chunk  # type: ignore[operator]
+                    except TypeError:
+                        output = chunk
+        else:
+            output = await acall_func_with_variable_args(
+                cast(Callable, afunc), input, config, run_manager, **kwargs
+            )
         # If the output is a runnable, invoke it
         if isinstance(output, Runnable):
             recursion_limit = config["recursion_limit"]
@@ -2903,7 +3028,7 @@ class RunnableLambda(Runnable[Input, Output]):
                     recursion_limit=recursion_limit - 1,
                 ),
             )
-        return output
+        return cast(Output, output)
 
     def _config(
         self, config: Optional[RunnableConfig], callable: Callable[..., Any]
@@ -2972,9 +3097,23 @@ class RunnableLambda(Runnable[Input, Output]):
             except TypeError:
                 final = ichunk
 
-        output = call_func_with_variable_args(
-            self.func, cast(Input, final), config, run_manager, **kwargs
-        )
+        if inspect.isgeneratorfunction(self.func):
+            output: Optional[Output] = None
+            for chunk in call_func_with_variable_args(
+                self.func, cast(Input, final), config, run_manager, **kwargs
+            ):
+                yield chunk
+                if output is None:
+                    output = chunk
+                else:
+                    try:
+                        output = output + chunk
+                    except TypeError:
+                        output = chunk
+        else:
+            output = call_func_with_variable_args(
+                self.func, cast(Input, final), config, run_manager, **kwargs
+            )
 
         # If the output is a runnable, use its stream output
         if isinstance(output, Runnable):
@@ -2993,9 +3132,9 @@ class RunnableLambda(Runnable[Input, Output]):
                 ),
             ):
                 yield chunk
-        else:
+        elif not inspect.isgeneratorfunction(self.func):
             # Otherwise, just yield it
-            yield output
+            yield cast(Output, output)
 
     def transform(
         self,
@@ -3030,6 +3169,7 @@ class RunnableLambda(Runnable[Input, Output]):
         input: AsyncIterator[Input],
         run_manager: AsyncCallbackManagerForChainRun,
         config: RunnableConfig,
+        **kwargs: Any,
     ) -> AsyncIterator[Output]:
         final: Optional[Input] = None
         async for ichunk in input:
@@ -3044,16 +3184,51 @@ class RunnableLambda(Runnable[Input, Output]):
         if hasattr(self, "afunc"):
             afunc = self.afunc
         else:
+            if inspect.isgeneratorfunction(self.func):
+                raise TypeError(
+                    "Cannot stream from a generator function asynchronously."
+                    "Use .stream() instead."
+                )
 
-            @wraps(self.func)
+            def func(
+                input: Input,
+                run_manager: AsyncCallbackManagerForChainRun,
+                config: RunnableConfig,
+            ) -> Output:
+                return call_func_with_variable_args(
+                    self.func, input, config, run_manager.get_sync(), **kwargs
+                )
+
+            @wraps(func)
             async def f(*args, **kwargs):  # type: ignore[no-untyped-def]
-                return await run_in_executor(config, self.func, *args, **kwargs)
+                return await run_in_executor(config, func, *args, **kwargs)
 
             afunc = f
 
-        output = await acall_func_with_variable_args(
-            afunc, cast(Input, final), config, run_manager
-        )
+        if inspect.isasyncgenfunction(afunc):
+            output: Optional[Output] = None
+            async for chunk in cast(
+                AsyncIterator[Output],
+                acall_func_with_variable_args(
+                    cast(Callable, afunc),
+                    cast(Input, final),
+                    config,
+                    run_manager,
+                    **kwargs,
+                ),
+            ):
+                yield chunk
+                if output is None:
+                    output = chunk
+                else:
+                    try:
+                        output = output + chunk  # type: ignore[operator]
+                    except TypeError:
+                        output = chunk
+        else:
+            output = await acall_func_with_variable_args(
+                cast(Callable, afunc), cast(Input, final), config, run_manager, **kwargs
+            )
 
         # If the output is a runnable, use its astream output
         if isinstance(output, Runnable):
@@ -3072,9 +3247,9 @@ class RunnableLambda(Runnable[Input, Output]):
                 ),
             ):
                 yield chunk
-        else:
+        elif not inspect.isasyncgenfunction(afunc):
             # Otherwise, just yield it
-            yield output
+            yield cast(Output, output)
 
     async def atransform(
         self,
@@ -3699,3 +3874,69 @@ def coerce_to_runnable(thing: RunnableLike) -> Runnable[Input, Output]:
             f"Expected a Runnable, callable or dict."
             f"Instead got an unsupported type: {type(thing)}"
         )
+
+
+@overload
+def chain(
+    func: Callable[[Input], Coroutine[Any, Any, Output]],
+) -> Runnable[Input, Output]:
+    ...
+
+
+@overload
+def chain(
+    func: Callable[[Input], Iterator[Output]],
+) -> Runnable[Input, Output]:
+    ...
+
+
+@overload
+def chain(
+    func: Callable[[Input], AsyncIterator[Output]],
+) -> Runnable[Input, Output]:
+    ...
+
+
+@overload
+def chain(
+    func: Callable[[Input], Output],
+) -> Runnable[Input, Output]:
+    ...
+
+
+def chain(
+    func: Union[
+        Callable[[Input], Output],
+        Callable[[Input], Iterator[Output]],
+        Callable[[Input], Coroutine[Any, Any, Output]],
+        Callable[[Input], AsyncIterator[Output]],
+    ],
+) -> Runnable[Input, Output]:
+    """Decorate a function to make it a Runnable.
+
+    Sets the name of the runnable to the name of the function.
+    Any runnables called by the function will be traced as dependencies.
+
+    Args:
+        func: A callable.
+
+    Returns:
+        A Runnable.
+
+    Example:
+
+    .. code-block:: python
+
+        from langchain_core.runnables import chain
+        from langchain_core.prompts import PromptTemplate
+        from langchain.llms import OpenAI
+
+        @chain
+        def my_func(fields):
+            prompt = PromptTemplate("Hello, {name}!")
+            llm = OpenAI()
+            formatted = prompt.invoke(**fields)
+
+            for chunk in llm.stream(formatted):
+                yield chunk
+    """
+    return RunnableLambda(func)
```

libs/core/langchain_core/runnables/config.py:

```diff
@@ -323,7 +323,7 @@ def call_func_with_variable_args(
     return func(input, **kwargs)  # type: ignore[call-arg]
 
 
-async def acall_func_with_variable_args(
+def acall_func_with_variable_args(
     func: Union[
         Callable[[Input], Awaitable[Output]],
         Callable[[Input, RunnableConfig], Awaitable[Output]],
@@ -337,7 +337,7 @@ def acall_func_with_variable_args(
     config: RunnableConfig,
     run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
     **kwargs: Any,
-) -> Output:
+) -> Awaitable[Output]:
     """Call function that may optionally accept a run_manager and/or config.
 
     Args:
@@ -361,7 +361,7 @@ def acall_func_with_variable_args(
         kwargs["config"] = config
     if run_manager is not None and accepts_run_manager(func):
         kwargs["run_manager"] = run_manager
-    return await func(input, **kwargs)  # type: ignore[call-arg]
+    return func(input, **kwargs)  # type: ignore[call-arg]
 
 
 def get_callback_manager_for_config(config: RunnableConfig) -> CallbackManager:
```

libs/core/langchain_core/runnables/utils.py:

```diff
@@ -68,6 +68,14 @@ def accepts_config(callable: Callable[..., Any]) -> bool:
         return False
 
 
+def accepts_context(callable: Callable[..., Any]) -> bool:
+    """Check if a callable accepts a context argument."""
+    try:
+        return signature(callable).parameters.get("context") is not None
+    except ValueError:
+        return False
+
+
 class IsLocalDict(ast.NodeVisitor):
     """Check if a name is a local dict."""
```

libs/core/tests/unit_tests/runnables/test_context.py:

```diff
@@ -12,7 +12,7 @@ from langchain_core.runnables.utils import aadd, add
 from tests.unit_tests.fake.llm import FakeListLLM, FakeStreamingListLLM
 
 
-class TestCase(NamedTuple):
+class _TestCase(NamedTuple):
     input: Any
     output: Any
 
@@ -102,22 +102,22 @@ test_cases = [
     (
         Context.setter("foo") | Context.getter("foo"),
         (
-            TestCase("foo", "foo"),
-            TestCase("bar", "bar"),
+            _TestCase("foo", "foo"),
+            _TestCase("bar", "bar"),
         ),
     ),
     (
         Context.setter("input") | {"bar": Context.getter("input")},
         (
-            TestCase("foo", {"bar": "foo"}),
-            TestCase("bar", {"bar": "bar"}),
+            _TestCase("foo", {"bar": "foo"}),
+            _TestCase("bar", {"bar": "bar"}),
         ),
     ),
     (
         {"bar": Context.setter("input")} | Context.getter("input"),
         (
-            TestCase("foo", "foo"),
-            TestCase("bar", "bar"),
+            _TestCase("foo", "foo"),
+            _TestCase("bar", "bar"),
         ),
     ),
     (
@@ -132,11 +132,11 @@ test_cases = [
             }
         ),
         (
-            TestCase(
+            _TestCase(
                 {"foo": "foo", "bar": "bar"},
                 {"response": "hello", "prompt": StringPromptValue(text="foo bar")},
             ),
-            TestCase(
+            _TestCase(
                 {"foo": "bar", "bar": "foo"},
                 {"response": "hello", "prompt": StringPromptValue(text="bar foo")},
             ),
@@ -155,7 +155,7 @@ test_cases = [
             }
         ),
         (
-            TestCase(
+            _TestCase(
                 {"foo": "foo", "bar": "bar"},
                 {
                     "response": "hello",
@@ -163,7 +163,7 @@ test_cases = [
                     "prompt_str": "foo bar",
                 },
             ),
-            TestCase(
+            _TestCase(
                 {"foo": "bar", "bar": "foo"},
                 {
                     "response": "hello",
@@ -185,11 +185,11 @@ test_cases = [
             }
         ),
         (
-            TestCase(
+            _TestCase(
                 {"foo": "foo", "bar": "bar"},
                 {"response": "hello", "prompt_str": "foo bar"},
             ),
-            TestCase(
+            _TestCase(
                 {"foo": "bar", "bar": "foo"},
                 {"response": "hello", "prompt_str": "bar foo"},
             ),
@@ -207,11 +207,11 @@ test_cases = [
             }
         ),
         (
-            TestCase(
+            _TestCase(
                 {"foo": "foo", "bar": "bar"},
                 {"response": "hello", "prompt_str": "foo bar"},
             ),
-            TestCase(
+            _TestCase(
                 {"foo": "bar", "bar": "foo"},
                 {"response": "hello", "prompt_str": "bar foo"},
             ),
@@ -229,11 +229,11 @@ test_cases = [
             }
         ),
         (
-            TestCase(
+            _TestCase(
                 {"foo": "foo", "bar": "bar"},
                 {"response": "hello", "prompt": StringPromptValue(text="foo bar")},
             ),
-            TestCase(
+            _TestCase(
                 {"foo": "bar", "bar": "foo"},
                 {"response": "hello", "prompt": StringPromptValue(text="bar foo")},
             ),
@@ -242,7 +242,7 @@ test_cases = [
     (
         seq_naive_rag,
         (
-            TestCase(
+            _TestCase(
                 "What up",
                 {
                     "result": "hello",
@@ -254,7 +254,7 @@ test_cases = [
                     "input": "What up",
                 },
             ),
-            TestCase(
+            _TestCase(
                 "Howdy",
                 {
                     "result": "hello",
@@ -271,7 +271,7 @@ test_cases = [
     (
         seq_naive_rag_alt,
         (
-            TestCase(
+            _TestCase(
                 "What up",
                 {
                     "result": "hello",
@@ -283,7 +283,7 @@ test_cases = [
                     "input": "What up",
                 },
             ),
-            TestCase(
+            _TestCase(
                 "Howdy",
                 {
                     "result": "hello",
@@ -300,7 +300,7 @@ test_cases = [
     (
         seq_naive_rag_scoped,
         (
-            TestCase(
+            _TestCase(
                 "What up",
                 {
                     "result": "hello",
@@ -312,7 +312,7 @@ test_cases = [
                     "input": "What up",
                 },
             ),
-            TestCase(
+            _TestCase(
                 "Howdy",
                 {
                     "result": "hello",
@@ -331,7 +331,7 @@ test_cases = [
 
 @pytest.mark.parametrize("runnable, cases", test_cases)
 async def test_context_runnables(
-    runnable: Union[Runnable, Callable[[], Runnable]], cases: List[TestCase]
+    runnable: Union[Runnable, Callable[[], Runnable]], cases: List[_TestCase]
 ) -> None:
     runnable = runnable if isinstance(runnable, Runnable) else runnable()
     assert runnable.invoke(cases[0].input) == cases[0].output
```

libs/core/tests/unit_tests/runnables/test_imports.py:

```diff
@@ -1,6 +1,7 @@
 from langchain_core.runnables import __all__
 
 EXPECTED_ALL = [
+    "chain",
     "AddableDict",
     "ConfigurableField",
     "ConfigurableFieldSingleOption",
```

libs/core/tests/unit_tests/runnables/test_runnable.py:

```diff
@@ -68,6 +68,7 @@ from langchain_core.runnables import (
     RunnableSequence,
     RunnableWithFallbacks,
     add,
+    chain,
 )
 from langchain_core.tools import BaseTool, tool
 from langchain_core.tracers import (
@@ -4388,9 +4389,9 @@ async def test_runnable_gen() -> None:
 
     runnable = RunnableGenerator(gen)
 
-    assert runnable.input_schema.schema() == {"title": "RunnableGeneratorInput"}
+    assert runnable.input_schema.schema() == {"title": "gen_input"}
     assert runnable.output_schema.schema() == {
-        "title": "RunnableGeneratorOutput",
+        "title": "gen_output",
         "type": "integer",
     }
@@ -4410,6 +4411,315 @@ async def test_runnable_gen() -> None:
     assert await arunnable.abatch([None, None]) == [6, 6]
 
 
+async def test_runnable_gen_context_config() -> None:
+    """Test that a generator can call other runnables with config
+    propagated from the context."""
+
+    fake = RunnableLambda(len)
+
+    def gen(input: Iterator[Any]) -> Iterator[int]:
+        yield fake.invoke("a")
+        yield fake.invoke("aa")
+        yield fake.invoke("aaa")
+
+    runnable = RunnableGenerator(gen)
+
+    assert runnable.input_schema.schema() == {"title": "gen_input"}
+    assert runnable.output_schema.schema() == {
+        "title": "gen_output",
+        "type": "integer",
+    }
+
+    tracer = FakeTracer()
+    assert runnable.invoke(None, {"callbacks": [tracer]}) == 6
+    assert len(tracer.runs) == 1
+    assert tracer.runs[0].outputs == {"output": 6}
+    assert len(tracer.runs[0].child_runs) == 3
+    assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
+    assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
+    tracer.runs.clear()
+
+    assert list(runnable.stream(None)) == [1, 2, 3]
+    assert len(tracer.runs) == 0, "callbacks doesn't persist from previous call"
+
+    tracer = FakeTracer()
+    assert list(runnable.stream(None, {"callbacks": [tracer]})) == [1, 2, 3]
+    assert len(tracer.runs) == 1
+    assert tracer.runs[0].outputs == {"output": 6}
+    assert len(tracer.runs[0].child_runs) == 3
+    assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
+    assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
+
+    tracer = FakeTracer()
+    assert runnable.batch([None, None], {"callbacks": [tracer]}) == [6, 6]
+    assert len(tracer.runs) == 2
+    assert tracer.runs[0].outputs == {"output": 6}
+    assert tracer.runs[1].outputs == {"output": 6}
+    assert len(tracer.runs[0].child_runs) == 3
+    assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
+    assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
+    assert len(tracer.runs[1].child_runs) == 3
+    assert [r.inputs["input"] for r in tracer.runs[1].child_runs] == ["a", "aa", "aaa"]
+    assert [(r.outputs or {})["output"] for r in tracer.runs[1].child_runs] == [1, 2, 3]
+
+    if sys.version_info < (3, 11):
+        # Python 3.10 and below don't support running async tasks in a specific context
+        return
+
+    async def agen(input: AsyncIterator[Any]) -> AsyncIterator[int]:
+        yield await fake.ainvoke("a")
+        yield await fake.ainvoke("aa")
+        yield await fake.ainvoke("aaa")
+
+    arunnable = RunnableGenerator(agen)
+
+    tracer = FakeTracer()
+    assert await arunnable.ainvoke(None, {"callbacks": [tracer]}) == 6
+    assert len(tracer.runs) == 1
+    assert tracer.runs[0].outputs == {"output": 6}
+    assert len(tracer.runs[0].child_runs) == 3
+    assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
+    assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
+    tracer.runs.clear()
+
+    assert [p async for p in arunnable.astream(None)] == [1, 2, 3]
+    assert len(tracer.runs) == 0, "callbacks doesn't persist from previous call"
+
+    tracer = FakeTracer()
+    assert [p async for p in arunnable.astream(None, {"callbacks": [tracer]})] == [
+        1,
+        2,
+        3,
+    ]
+    assert len(tracer.runs) == 1
+    assert tracer.runs[0].outputs == {"output": 6}
+    assert len(tracer.runs[0].child_runs) == 3
+    assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
+    assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
+
+    tracer = FakeTracer()
+    assert await arunnable.abatch([None, None], {"callbacks": [tracer]}) == [6, 6]
+    assert len(tracer.runs) == 2
+    assert tracer.runs[0].outputs == {"output": 6}
+    assert tracer.runs[1].outputs == {"output": 6}
+    assert len(tracer.runs[0].child_runs) == 3
+    assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
+    assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
+    assert len(tracer.runs[1].child_runs) == 3
+    assert [r.inputs["input"] for r in tracer.runs[1].child_runs] == ["a", "aa", "aaa"]
+    assert [(r.outputs or {})["output"] for r in tracer.runs[1].child_runs] == [1, 2, 3]
+
+
+async def test_runnable_iter_context_config() -> None:
+    """Test that a generator can call other runnables with config
+    propagated from the context."""
+
+    fake = RunnableLambda(len)
+
+    @chain
+    def gen(input: str) -> Iterator[int]:
+        yield fake.invoke(input)
+        yield fake.invoke(input * 2)
+        yield fake.invoke(input * 3)
+
+    assert gen.input_schema.schema() == {
+        "title": "gen_input",
+        "type": "string",
+    }
+    assert gen.output_schema.schema() == {
+        "title": "gen_output",
+        "type": "integer",
+    }
+
+    tracer = FakeTracer()
+    assert gen.invoke("a", {"callbacks": [tracer]}) == 6
+    assert len(tracer.runs) == 1
+    assert tracer.runs[0].outputs == {"output": 6}
+    assert len(tracer.runs[0].child_runs) == 3
+    assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
+    assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
+    tracer.runs.clear()
+
+    assert list(gen.stream("a")) == [1, 2, 3]
+    assert len(tracer.runs) == 0, "callbacks doesn't persist from previous call"
+
+    tracer = FakeTracer()
+    assert list(gen.stream("a", {"callbacks": [tracer]})) == [1, 2, 3]
+    assert len(tracer.runs) == 1
+    assert tracer.runs[0].outputs == {"output": 6}
+    assert len(tracer.runs[0].child_runs) == 3
+    assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
+    assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
+
+    tracer = FakeTracer()
+    assert gen.batch(["a", "a"], {"callbacks": [tracer]}) == [6, 6]
+    assert len(tracer.runs) == 2
+    assert tracer.runs[0].outputs == {"output": 6}
+    assert tracer.runs[1].outputs == {"output": 6}
+    assert len(tracer.runs[0].child_runs) == 3
+    assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
+    assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
+    assert len(tracer.runs[1].child_runs) == 3
+    assert [r.inputs["input"] for r in tracer.runs[1].child_runs] == ["a", "aa", "aaa"]
+    assert [(r.outputs or {})["output"] for r in tracer.runs[1].child_runs] == [1, 2, 3]
+
+    if sys.version_info < (3, 11):
+        # Python 3.10 and below don't support running async tasks in a specific context
+        return
+
+    @chain
+    async def agen(input: str) -> AsyncIterator[int]:
+        yield await fake.ainvoke(input)
+        yield await fake.ainvoke(input * 2)
+        yield await fake.ainvoke(input * 3)
+
+    assert agen.input_schema.schema() == {
+        "title": "agen_input",
+        "type": "string",
+    }
+    assert agen.output_schema.schema() == {
+        "title": "agen_output",
+        "type": "integer",
+    }
+
+    tracer = FakeTracer()
+    assert await agen.ainvoke("a", {"callbacks": [tracer]}) == 6
+    assert len(tracer.runs) == 1
+    assert tracer.runs[0].outputs == {"output": 6}
+    assert len(tracer.runs[0].child_runs) == 3
+    assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
+    assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
+    tracer.runs.clear()
+
+    assert [p async for p in agen.astream("a")] == [1, 2, 3]
+    assert len(tracer.runs) == 0, "callbacks doesn't persist from previous call"
+
+    tracer = FakeTracer()
+    assert [p async for p in agen.astream("a", {"callbacks": [tracer]})] == [
+        1,
+        2,
+        3,
+    ]
+    assert len(tracer.runs) == 1
+    assert tracer.runs[0].outputs == {"output": 6}
+    assert len(tracer.runs[0].child_runs) == 3
+    assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
+    assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
+
+    tracer = FakeTracer()
+    assert await agen.abatch(["a", "a"], {"callbacks": [tracer]}) == [6, 6]
+    assert len(tracer.runs) == 2
+    assert tracer.runs[0].outputs == {"output": 6}
+    assert tracer.runs[1].outputs == {"output": 6}
+    assert len(tracer.runs[0].child_runs) == 3
+    assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
+    assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
+    assert len(tracer.runs[1].child_runs) == 3
+    assert [r.inputs["input"] for r in tracer.runs[1].child_runs] == ["a", "aa", "aaa"]
+    assert [(r.outputs or {})["output"] for r in tracer.runs[1].child_runs] == [1, 2, 3]
+
+
+async def test_runnable_lambda_context_config() -> None:
+    """Test that a function can call other runnables with config
+    propagated from the context."""
+
+    fake = RunnableLambda(len)
+
+    @chain
+    def fun(input: str) -> int:
+        output = fake.invoke(input)
+        output += fake.invoke(input * 2)
+        output += fake.invoke(input * 3)
+        return output
+
+    assert fun.input_schema.schema() == {"title": "fun_input", "type": "string"}
+    assert fun.output_schema.schema() == {
+        "title": "fun_output",
+        "type": "integer",
+    }
+
+    tracer = FakeTracer()
+    assert fun.invoke("a", {"callbacks": [tracer]}) == 6
+    assert len(tracer.runs) == 1
+    assert tracer.runs[0].outputs == {"output": 6}
+    assert len(tracer.runs[0].child_runs) == 3
+    assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
+    assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
+    tracer.runs.clear()
+
+    assert list(fun.stream("a")) == [6]
+    assert len(tracer.runs) == 0, "callbacks doesn't persist from previous call"
+
+    tracer = FakeTracer()
+    assert list(fun.stream("a", {"callbacks": [tracer]})) == [6]
+    assert len(tracer.runs) == 1
+    assert tracer.runs[0].outputs == {"output": 6}
+    assert len(tracer.runs[0].child_runs) == 3
+    assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
+    assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
+
+    tracer = FakeTracer()
+    assert fun.batch(["a", "a"], {"callbacks": [tracer]}) == [6, 6]
+    assert len(tracer.runs) == 2
+    assert tracer.runs[0].outputs == {"output": 6}
+    assert tracer.runs[1].outputs == {"output": 6}
+    assert len(tracer.runs[0].child_runs) == 3
+    assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
+    assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
+    assert len(tracer.runs[1].child_runs) == 3
+    assert [r.inputs["input"] for r in tracer.runs[1].child_runs] == ["a", "aa", "aaa"]
+    assert [(r.outputs or {})["output"] for r in tracer.runs[1].child_runs] == [1, 2, 3]
+
+    if sys.version_info < (3, 11):
+        # Python 3.10 and below don't support running async tasks in a specific context
+        return
+
+    @chain
+    async def afun(input: str) -> int:
+        output = await fake.ainvoke(input)
+        output += await fake.ainvoke(input * 2)
+        output += await fake.ainvoke(input * 3)
+        return output
+
+    assert afun.input_schema.schema() == {"title": "afun_input", "type": "string"}
+    assert afun.output_schema.schema() == {
+        "title": "afun_output",
+        "type": "integer",
+    }
+
+    tracer = FakeTracer()
+    assert await afun.ainvoke("a", {"callbacks": [tracer]}) == 6
+    assert len(tracer.runs) == 1
+    assert tracer.runs[0].outputs == {"output": 6}
+    assert len(tracer.runs[0].child_runs) == 3
+    assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
+    assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
+    tracer.runs.clear()
+
+    assert [p async for p in afun.astream("a")] == [6]
+    assert len(tracer.runs) == 0, "callbacks doesn't persist from previous call"
+
+    tracer = FakeTracer()
+    assert [p async for p in afun.astream("a", {"callbacks": [tracer]})] == [6]
+    assert len(tracer.runs) == 1
+    assert tracer.runs[0].outputs == {"output": 6}
+    assert len(tracer.runs[0].child_runs) == 3
+    assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
+    assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
+
+    tracer = FakeTracer()
+    assert await afun.abatch(["a", "a"], {"callbacks": [tracer]}) == [6, 6]
+    assert len(tracer.runs) == 2
+    assert tracer.runs[0].outputs == {"output": 6}
+    assert tracer.runs[1].outputs == {"output": 6}
+    assert len(tracer.runs[0].child_runs) == 3
+    assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
+    assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
+    assert len(tracer.runs[1].child_runs) == 3
+    assert [r.inputs["input"] for r in tracer.runs[1].child_runs] == ["a", "aa", "aaa"]
+    assert [(r.outputs or {})["output"] for r in tracer.runs[1].child_runs] == [1, 2, 3]
+
+
 async def test_runnable_gen_transform() -> None:
     """Test that a generator can be used as a runnable."""
@@ -4434,19 +4744,19 @@ async def test_runnable_gen_transform() -> None:
     achain = RunnableGenerator(gen_indexes, agen_indexes) | aplus_one
 
     assert chain.input_schema.schema() == {
-        "title": "RunnableGeneratorInput",
+        "title": "gen_indexes_input",
         "type": "integer",
     }
     assert chain.output_schema.schema() == {
-        "title": "RunnableGeneratorOutput",
+        "title": "plus_one_output",
         "type": "integer",
     }
     assert achain.input_schema.schema() == {
-        "title": "RunnableGeneratorInput",
+        "title": "gen_indexes_input",
         "type": "integer",
     }
     assert achain.output_schema.schema() == {
-        "title": "RunnableGeneratorOutput",
+        "title": "aplus_one_output",
         "type": "integer",
     }
```