mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-01 04:29:09 +00:00
Fetch runnable config from context var inside runnable lambda and runnable generator (#15334)
- easier to write custom logic/loops with automatic tracing - if you don't want to streaming support write a regular function and pass to RunnableLambda - if you do want streaming write a generator and pass it to RunnableGenerator ```py import json from typing import AsyncIterator from langchain_core.messages import BaseMessage, FunctionMessage, HumanMessage from langchain_core.agents import AgentAction, AgentFinish from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.runnables import Runnable, RunnableGenerator, RunnablePassthrough from langchain_core.tools import BaseTool from langchain.agents.output_parsers import OpenAIFunctionsAgentOutputParser from langchain.chat_models import ChatOpenAI from langchain.tools.render import format_tool_to_openai_function def _get_tavily(): from langchain.tools.tavily_search import TavilySearchResults from langchain.utilities.tavily_search import TavilySearchAPIWrapper tavily_search = TavilySearchAPIWrapper() return TavilySearchResults(api_wrapper=tavily_search) async def _agent_executor_generator( input: AsyncIterator[list[BaseMessage]], *, max_iterations: int = 10, tools: dict[str, BaseTool], agent: Runnable[list[BaseMessage], BaseMessage], parser: Runnable[BaseMessage, AgentAction | AgentFinish], ) -> AsyncIterator[BaseMessage]: messages = [m async for mm in input for m in mm] for _ in range(max_iterations): next_message = await agent.ainvoke(messages) yield next_message messages.append(next_message) parsed = await parser.ainvoke(next_message) if isinstance(parsed, AgentAction): result = await tools[parsed.tool].ainvoke(parsed.tool_input) next_message = FunctionMessage(name=parsed.tool, content=json.dumps(result)) yield next_message messages.append(next_message) elif isinstance(parsed, AgentFinish): return def get_agent_executor(tools: list[BaseTool], system_message: str): llm = ChatOpenAI(model="gpt-4-1106-preview", temperature=0, streaming=True) prompt = ChatPromptTemplate.from_messages( [ ("system", system_message), MessagesPlaceholder(variable_name="messages"), ] ) llm_with_tools = llm.bind( functions=[format_tool_to_openai_function(t) for t in tools] ) agent = {"messages": RunnablePassthrough()} | prompt | llm_with_tools parser = OpenAIFunctionsAgentOutputParser() executor = RunnableGenerator(_agent_executor_generator) return executor.bind( tools={tool.name for tool in tools}, agent=agent, parser=parser ) agent = get_agent_executor([_get_tavily()], "You are a very nice agent!") async def main(): async for message in agent.astream( [HumanMessage(content="whats the weather in sf tomorrow?")] ): print(message) if __name__ == "__main__": import asyncio asyncio.run(main()) ``` results in this trace https://smith.langchain.com/public/fa17f05d-9724-4d08-8fa1-750f8fcd051b/r
This commit is contained in:
parent
8e0d5813c2
commit
9cbf14dec2
@ -23,6 +23,7 @@ from langchain_core.runnables.base import (
|
|||||||
RunnableParallel,
|
RunnableParallel,
|
||||||
RunnableSequence,
|
RunnableSequence,
|
||||||
RunnableSerializable,
|
RunnableSerializable,
|
||||||
|
chain,
|
||||||
)
|
)
|
||||||
from langchain_core.runnables.branch import RunnableBranch
|
from langchain_core.runnables.branch import RunnableBranch
|
||||||
from langchain_core.runnables.config import (
|
from langchain_core.runnables.config import (
|
||||||
@ -50,6 +51,7 @@ from langchain_core.runnables.utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"chain",
|
||||||
"AddableDict",
|
"AddableDict",
|
||||||
"ConfigurableField",
|
"ConfigurableField",
|
||||||
"ConfigurableFieldSingleOption",
|
"ConfigurableFieldSingleOption",
|
||||||
|
@ -1,10 +1,12 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import collections
|
||||||
import inspect
|
import inspect
|
||||||
import threading
|
import threading
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from concurrent.futures import FIRST_COMPLETED, wait
|
from concurrent.futures import FIRST_COMPLETED, wait
|
||||||
|
from contextvars import copy_context
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from itertools import groupby, tee
|
from itertools import groupby, tee
|
||||||
@ -15,6 +17,7 @@ from typing import (
|
|||||||
AsyncIterator,
|
AsyncIterator,
|
||||||
Awaitable,
|
Awaitable,
|
||||||
Callable,
|
Callable,
|
||||||
|
Coroutine,
|
||||||
Dict,
|
Dict,
|
||||||
Generic,
|
Generic,
|
||||||
Iterator,
|
Iterator,
|
||||||
@ -48,6 +51,7 @@ from langchain_core.runnables.config import (
|
|||||||
merge_configs,
|
merge_configs,
|
||||||
patch_config,
|
patch_config,
|
||||||
run_in_executor,
|
run_in_executor,
|
||||||
|
var_child_runnable_config,
|
||||||
)
|
)
|
||||||
from langchain_core.runnables.graph import Graph
|
from langchain_core.runnables.graph import Graph
|
||||||
from langchain_core.runnables.utils import (
|
from langchain_core.runnables.utils import (
|
||||||
@ -58,6 +62,7 @@ from langchain_core.runnables.utils import (
|
|||||||
Input,
|
Input,
|
||||||
Output,
|
Output,
|
||||||
accepts_config,
|
accepts_config,
|
||||||
|
accepts_context,
|
||||||
accepts_run_manager,
|
accepts_run_manager,
|
||||||
gather_with_concurrency,
|
gather_with_concurrency,
|
||||||
get_function_first_arg_dict_keys,
|
get_function_first_arg_dict_keys,
|
||||||
@ -950,8 +955,19 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
name=config.get("run_name") or self.get_name(),
|
name=config.get("run_name") or self.get_name(),
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
output = call_func_with_variable_args(
|
child_config = patch_config(config, callbacks=run_manager.get_child())
|
||||||
func, input, config, run_manager, **kwargs
|
context = copy_context()
|
||||||
|
context.run(var_child_runnable_config.set, child_config)
|
||||||
|
output = cast(
|
||||||
|
Output,
|
||||||
|
context.run(
|
||||||
|
call_func_with_variable_args,
|
||||||
|
func, # type: ignore[arg-type]
|
||||||
|
input, # type: ignore[arg-type]
|
||||||
|
config,
|
||||||
|
run_manager,
|
||||||
|
**kwargs,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
run_manager.on_chain_error(e)
|
run_manager.on_chain_error(e)
|
||||||
@ -986,9 +1002,16 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
name=config.get("run_name") or self.get_name(),
|
name=config.get("run_name") or self.get_name(),
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
output = await acall_func_with_variable_args(
|
child_config = patch_config(config, callbacks=run_manager.get_child())
|
||||||
|
context = copy_context()
|
||||||
|
context.run(var_child_runnable_config.set, child_config)
|
||||||
|
coro = acall_func_with_variable_args(
|
||||||
func, input, config, run_manager, **kwargs
|
func, input, config, run_manager, **kwargs
|
||||||
)
|
)
|
||||||
|
if accepts_context(asyncio.create_task):
|
||||||
|
output: Output = await asyncio.create_task(coro, context=context) # type: ignore
|
||||||
|
else:
|
||||||
|
output = await coro
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
await run_manager.on_chain_error(e)
|
await run_manager.on_chain_error(e)
|
||||||
raise
|
raise
|
||||||
@ -1178,24 +1201,29 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
name=config.get("run_name") or self.get_name(),
|
name=config.get("run_name") or self.get_name(),
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
|
child_config = patch_config(config, callbacks=run_manager.get_child())
|
||||||
if accepts_config(transformer):
|
if accepts_config(transformer):
|
||||||
kwargs["config"] = patch_config(
|
kwargs["config"] = child_config
|
||||||
config, callbacks=run_manager.get_child()
|
|
||||||
)
|
|
||||||
if accepts_run_manager(transformer):
|
if accepts_run_manager(transformer):
|
||||||
kwargs["run_manager"] = run_manager
|
kwargs["run_manager"] = run_manager
|
||||||
iterator = transformer(input_for_transform, **kwargs) # type: ignore[call-arg]
|
context = copy_context()
|
||||||
for chunk in iterator:
|
context.run(var_child_runnable_config.set, child_config)
|
||||||
yield chunk
|
iterator = context.run(transformer, input_for_transform, **kwargs) # type: ignore[arg-type]
|
||||||
if final_output_supported:
|
try:
|
||||||
if final_output is None:
|
while True:
|
||||||
final_output = chunk
|
chunk: Output = context.run(next, iterator) # type: ignore
|
||||||
else:
|
yield chunk
|
||||||
try:
|
if final_output_supported:
|
||||||
final_output = final_output + chunk # type: ignore
|
if final_output is None:
|
||||||
except TypeError:
|
final_output = chunk
|
||||||
final_output = None
|
else:
|
||||||
final_output_supported = False
|
try:
|
||||||
|
final_output = final_output + chunk # type: ignore
|
||||||
|
except TypeError:
|
||||||
|
final_output = None
|
||||||
|
final_output_supported = False
|
||||||
|
except StopIteration:
|
||||||
|
pass
|
||||||
for ichunk in input_for_tracing:
|
for ichunk in input_for_tracing:
|
||||||
if final_input_supported:
|
if final_input_supported:
|
||||||
if final_input is None:
|
if final_input is None:
|
||||||
@ -1254,24 +1282,35 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
name=config.get("run_name") or self.get_name(),
|
name=config.get("run_name") or self.get_name(),
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
|
child_config = patch_config(config, callbacks=run_manager.get_child())
|
||||||
if accepts_config(transformer):
|
if accepts_config(transformer):
|
||||||
kwargs["config"] = patch_config(
|
kwargs["config"] = child_config
|
||||||
config, callbacks=run_manager.get_child()
|
|
||||||
)
|
|
||||||
if accepts_run_manager(transformer):
|
if accepts_run_manager(transformer):
|
||||||
kwargs["run_manager"] = run_manager
|
kwargs["run_manager"] = run_manager
|
||||||
iterator = transformer(input_for_transform, **kwargs) # type: ignore[call-arg]
|
context = copy_context()
|
||||||
async for chunk in iterator:
|
context.run(var_child_runnable_config.set, child_config)
|
||||||
yield chunk
|
iterator = context.run(transformer, input_for_transform, **kwargs) # type: ignore[arg-type]
|
||||||
if final_output_supported:
|
try:
|
||||||
if final_output is None:
|
while True:
|
||||||
final_output = chunk
|
if accepts_context(asyncio.create_task):
|
||||||
|
chunk: Output = await asyncio.create_task( # type: ignore[call-arg]
|
||||||
|
py_anext(iterator), # type: ignore[arg-type]
|
||||||
|
context=context,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
try:
|
chunk = cast(Output, await py_anext(iterator))
|
||||||
final_output = final_output + chunk # type: ignore
|
yield chunk
|
||||||
except TypeError:
|
if final_output_supported:
|
||||||
final_output = None
|
if final_output is None:
|
||||||
final_output_supported = False
|
final_output = chunk
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
final_output = final_output + chunk # type: ignore
|
||||||
|
except TypeError:
|
||||||
|
final_output = None
|
||||||
|
final_output_supported = False
|
||||||
|
except StopAsyncIteration:
|
||||||
|
pass
|
||||||
async for ichunk in input_for_tracing:
|
async for ichunk in input_for_tracing:
|
||||||
if final_input_supported:
|
if final_input_supported:
|
||||||
if final_input is None:
|
if final_input is None:
|
||||||
@ -1472,7 +1511,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
|||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
from langchain_core.output_parsers.json import SimpleJsonOutputParser
|
from langchain_core.output_parsers.json import SimpleJsonOutputParser
|
||||||
from langchain_core.chat_models.openai import ChatOpenAI
|
from langchain.chat_models.openai import ChatOpenAI
|
||||||
|
|
||||||
prompt = PromptTemplate.from_template(
|
prompt = PromptTemplate.from_template(
|
||||||
'In JSON format, give me a list of {topic} and their '
|
'In JSON format, give me a list of {topic} and their '
|
||||||
@ -2482,17 +2521,25 @@ class RunnableGenerator(Runnable[Input, Output]):
|
|||||||
) -> None:
|
) -> None:
|
||||||
if atransform is not None:
|
if atransform is not None:
|
||||||
self._atransform = atransform
|
self._atransform = atransform
|
||||||
|
func_for_name: Callable = atransform
|
||||||
|
|
||||||
if inspect.isasyncgenfunction(transform):
|
if inspect.isasyncgenfunction(transform):
|
||||||
self._atransform = transform
|
self._atransform = transform
|
||||||
|
func_for_name = transform
|
||||||
elif inspect.isgeneratorfunction(transform):
|
elif inspect.isgeneratorfunction(transform):
|
||||||
self._transform = transform
|
self._transform = transform
|
||||||
|
func_for_name = transform
|
||||||
else:
|
else:
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
"Expected a generator function type for `transform`."
|
"Expected a generator function type for `transform`."
|
||||||
f"Instead got an unsupported type: {type(transform)}"
|
f"Instead got an unsupported type: {type(transform)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.name = func_for_name.__name__
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def InputType(self) -> Any:
|
def InputType(self) -> Any:
|
||||||
func = getattr(self, "_transform", None) or getattr(self, "_atransform")
|
func = getattr(self, "_transform", None) or getattr(self, "_atransform")
|
||||||
@ -2646,12 +2693,14 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
func: Union[
|
func: Union[
|
||||||
Union[
|
Union[
|
||||||
Callable[[Input], Output],
|
Callable[[Input], Output],
|
||||||
|
Callable[[Input], Iterator[Output]],
|
||||||
Callable[[Input, RunnableConfig], Output],
|
Callable[[Input, RunnableConfig], Output],
|
||||||
Callable[[Input, CallbackManagerForChainRun], Output],
|
Callable[[Input, CallbackManagerForChainRun], Output],
|
||||||
Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output],
|
Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output],
|
||||||
],
|
],
|
||||||
Union[
|
Union[
|
||||||
Callable[[Input], Awaitable[Output]],
|
Callable[[Input], Awaitable[Output]],
|
||||||
|
Callable[[Input], AsyncIterator[Output]],
|
||||||
Callable[[Input, RunnableConfig], Awaitable[Output]],
|
Callable[[Input, RunnableConfig], Awaitable[Output]],
|
||||||
Callable[[Input, AsyncCallbackManagerForChainRun], Awaitable[Output]],
|
Callable[[Input, AsyncCallbackManagerForChainRun], Awaitable[Output]],
|
||||||
Callable[
|
Callable[
|
||||||
@ -2663,6 +2712,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
afunc: Optional[
|
afunc: Optional[
|
||||||
Union[
|
Union[
|
||||||
Callable[[Input], Awaitable[Output]],
|
Callable[[Input], Awaitable[Output]],
|
||||||
|
Callable[[Input], AsyncIterator[Output]],
|
||||||
Callable[[Input, RunnableConfig], Awaitable[Output]],
|
Callable[[Input, RunnableConfig], Awaitable[Output]],
|
||||||
Callable[[Input, AsyncCallbackManagerForChainRun], Awaitable[Output]],
|
Callable[[Input, AsyncCallbackManagerForChainRun], Awaitable[Output]],
|
||||||
Callable[
|
Callable[
|
||||||
@ -2685,7 +2735,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
self.afunc = afunc
|
self.afunc = afunc
|
||||||
func_for_name: Callable = afunc
|
func_for_name: Callable = afunc
|
||||||
|
|
||||||
if inspect.iscoroutinefunction(func):
|
if inspect.iscoroutinefunction(func) or inspect.isasyncgenfunction(func):
|
||||||
if afunc is not None:
|
if afunc is not None:
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
"Func was provided as a coroutine function, but afunc was "
|
"Func was provided as a coroutine function, but afunc was "
|
||||||
@ -2767,11 +2817,16 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
func = getattr(self, "func", None) or getattr(self, "afunc")
|
func = getattr(self, "func", None) or getattr(self, "afunc")
|
||||||
try:
|
try:
|
||||||
sig = inspect.signature(func)
|
sig = inspect.signature(func)
|
||||||
return (
|
if sig.return_annotation != inspect.Signature.empty:
|
||||||
sig.return_annotation
|
# unwrap iterator types
|
||||||
if sig.return_annotation != inspect.Signature.empty
|
if getattr(sig.return_annotation, "__origin__", None) in (
|
||||||
else Any
|
collections.abc.Iterator,
|
||||||
)
|
collections.abc.AsyncIterator,
|
||||||
|
):
|
||||||
|
return getattr(sig.return_annotation, "__args__", (Any,))[0]
|
||||||
|
return sig.return_annotation
|
||||||
|
else:
|
||||||
|
return Any
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return Any
|
return Any
|
||||||
|
|
||||||
@ -2848,9 +2903,26 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
config: RunnableConfig,
|
config: RunnableConfig,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Output:
|
) -> Output:
|
||||||
output = call_func_with_variable_args(
|
if inspect.isgeneratorfunction(self.func):
|
||||||
self.func, input, config, run_manager, **kwargs
|
output: Optional[Output] = None
|
||||||
)
|
for chunk in call_func_with_variable_args(
|
||||||
|
cast(Callable[[Input], Iterator[Output]], self.func),
|
||||||
|
input,
|
||||||
|
config,
|
||||||
|
run_manager,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if output is None:
|
||||||
|
output = chunk
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
output = output + chunk # type: ignore[operator]
|
||||||
|
except TypeError:
|
||||||
|
output = chunk
|
||||||
|
else:
|
||||||
|
output = call_func_with_variable_args(
|
||||||
|
self.func, input, config, run_manager, **kwargs
|
||||||
|
)
|
||||||
# If the output is a runnable, invoke it
|
# If the output is a runnable, invoke it
|
||||||
if isinstance(output, Runnable):
|
if isinstance(output, Runnable):
|
||||||
recursion_limit = config["recursion_limit"]
|
recursion_limit = config["recursion_limit"]
|
||||||
@ -2866,7 +2938,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
recursion_limit=recursion_limit - 1,
|
recursion_limit=recursion_limit - 1,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
return output
|
return cast(Output, output)
|
||||||
|
|
||||||
async def _ainvoke(
|
async def _ainvoke(
|
||||||
self,
|
self,
|
||||||
@ -2878,16 +2950,69 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
if hasattr(self, "afunc"):
|
if hasattr(self, "afunc"):
|
||||||
afunc = self.afunc
|
afunc = self.afunc
|
||||||
else:
|
else:
|
||||||
|
if inspect.isgeneratorfunction(self.func):
|
||||||
|
|
||||||
@wraps(self.func)
|
def func(
|
||||||
|
input: Input,
|
||||||
|
run_manager: AsyncCallbackManagerForChainRun,
|
||||||
|
config: RunnableConfig,
|
||||||
|
) -> Output:
|
||||||
|
output: Optional[Output] = None
|
||||||
|
for chunk in call_func_with_variable_args(
|
||||||
|
cast(Callable[[Input], Iterator[Output]], self.func),
|
||||||
|
input,
|
||||||
|
config,
|
||||||
|
run_manager.get_sync(),
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if output is None:
|
||||||
|
output = chunk
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
output = output + chunk # type: ignore[operator]
|
||||||
|
except TypeError:
|
||||||
|
output = chunk
|
||||||
|
return cast(Output, output)
|
||||||
|
else:
|
||||||
|
|
||||||
|
def func(
|
||||||
|
input: Input,
|
||||||
|
run_manager: AsyncCallbackManagerForChainRun,
|
||||||
|
config: RunnableConfig,
|
||||||
|
) -> Output:
|
||||||
|
return call_func_with_variable_args(
|
||||||
|
self.func, input, config, run_manager.get_sync(), **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
async def f(*args, **kwargs): # type: ignore[no-untyped-def]
|
async def f(*args, **kwargs): # type: ignore[no-untyped-def]
|
||||||
return await run_in_executor(config, self.func, *args, **kwargs)
|
return await run_in_executor(config, func, *args, **kwargs)
|
||||||
|
|
||||||
afunc = f
|
afunc = f
|
||||||
|
|
||||||
output = await acall_func_with_variable_args(
|
if inspect.isasyncgenfunction(afunc):
|
||||||
afunc, input, config, run_manager, **kwargs
|
output: Optional[Output] = None
|
||||||
)
|
async for chunk in cast(
|
||||||
|
AsyncIterator[Output],
|
||||||
|
acall_func_with_variable_args(
|
||||||
|
cast(Callable, afunc),
|
||||||
|
input,
|
||||||
|
config,
|
||||||
|
run_manager,
|
||||||
|
**kwargs,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
if output is None:
|
||||||
|
output = chunk
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
output = output + chunk # type: ignore[operator]
|
||||||
|
except TypeError:
|
||||||
|
output = chunk
|
||||||
|
else:
|
||||||
|
output = await acall_func_with_variable_args(
|
||||||
|
cast(Callable, afunc), input, config, run_manager, **kwargs
|
||||||
|
)
|
||||||
# If the output is a runnable, invoke it
|
# If the output is a runnable, invoke it
|
||||||
if isinstance(output, Runnable):
|
if isinstance(output, Runnable):
|
||||||
recursion_limit = config["recursion_limit"]
|
recursion_limit = config["recursion_limit"]
|
||||||
@ -2903,7 +3028,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
recursion_limit=recursion_limit - 1,
|
recursion_limit=recursion_limit - 1,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
return output
|
return cast(Output, output)
|
||||||
|
|
||||||
def _config(
|
def _config(
|
||||||
self, config: Optional[RunnableConfig], callable: Callable[..., Any]
|
self, config: Optional[RunnableConfig], callable: Callable[..., Any]
|
||||||
@ -2972,9 +3097,23 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
except TypeError:
|
except TypeError:
|
||||||
final = ichunk
|
final = ichunk
|
||||||
|
|
||||||
output = call_func_with_variable_args(
|
if inspect.isgeneratorfunction(self.func):
|
||||||
self.func, cast(Input, final), config, run_manager, **kwargs
|
output: Optional[Output] = None
|
||||||
)
|
for chunk in call_func_with_variable_args(
|
||||||
|
self.func, cast(Input, final), config, run_manager, **kwargs
|
||||||
|
):
|
||||||
|
yield chunk
|
||||||
|
if output is None:
|
||||||
|
output = chunk
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
output = output + chunk
|
||||||
|
except TypeError:
|
||||||
|
output = chunk
|
||||||
|
else:
|
||||||
|
output = call_func_with_variable_args(
|
||||||
|
self.func, cast(Input, final), config, run_manager, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
# If the output is a runnable, use its stream output
|
# If the output is a runnable, use its stream output
|
||||||
if isinstance(output, Runnable):
|
if isinstance(output, Runnable):
|
||||||
@ -2993,9 +3132,9 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
),
|
),
|
||||||
):
|
):
|
||||||
yield chunk
|
yield chunk
|
||||||
else:
|
elif not inspect.isgeneratorfunction(self.func):
|
||||||
# Otherwise, just yield it
|
# Otherwise, just yield it
|
||||||
yield output
|
yield cast(Output, output)
|
||||||
|
|
||||||
def transform(
|
def transform(
|
||||||
self,
|
self,
|
||||||
@ -3030,6 +3169,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
input: AsyncIterator[Input],
|
input: AsyncIterator[Input],
|
||||||
run_manager: AsyncCallbackManagerForChainRun,
|
run_manager: AsyncCallbackManagerForChainRun,
|
||||||
config: RunnableConfig,
|
config: RunnableConfig,
|
||||||
|
**kwargs: Any,
|
||||||
) -> AsyncIterator[Output]:
|
) -> AsyncIterator[Output]:
|
||||||
final: Optional[Input] = None
|
final: Optional[Input] = None
|
||||||
async for ichunk in input:
|
async for ichunk in input:
|
||||||
@ -3044,16 +3184,51 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
if hasattr(self, "afunc"):
|
if hasattr(self, "afunc"):
|
||||||
afunc = self.afunc
|
afunc = self.afunc
|
||||||
else:
|
else:
|
||||||
|
if inspect.isgeneratorfunction(self.func):
|
||||||
|
raise TypeError(
|
||||||
|
"Cannot stream from a generator function asynchronously."
|
||||||
|
"Use .stream() instead."
|
||||||
|
)
|
||||||
|
|
||||||
@wraps(self.func)
|
def func(
|
||||||
|
input: Input,
|
||||||
|
run_manager: AsyncCallbackManagerForChainRun,
|
||||||
|
config: RunnableConfig,
|
||||||
|
) -> Output:
|
||||||
|
return call_func_with_variable_args(
|
||||||
|
self.func, input, config, run_manager.get_sync(), **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
async def f(*args, **kwargs): # type: ignore[no-untyped-def]
|
async def f(*args, **kwargs): # type: ignore[no-untyped-def]
|
||||||
return await run_in_executor(config, self.func, *args, **kwargs)
|
return await run_in_executor(config, func, *args, **kwargs)
|
||||||
|
|
||||||
afunc = f
|
afunc = f
|
||||||
|
|
||||||
output = await acall_func_with_variable_args(
|
if inspect.isasyncgenfunction(afunc):
|
||||||
afunc, cast(Input, final), config, run_manager
|
output: Optional[Output] = None
|
||||||
)
|
async for chunk in cast(
|
||||||
|
AsyncIterator[Output],
|
||||||
|
acall_func_with_variable_args(
|
||||||
|
cast(Callable, afunc),
|
||||||
|
cast(Input, final),
|
||||||
|
config,
|
||||||
|
run_manager,
|
||||||
|
**kwargs,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
yield chunk
|
||||||
|
if output is None:
|
||||||
|
output = chunk
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
output = output + chunk # type: ignore[operator]
|
||||||
|
except TypeError:
|
||||||
|
output = chunk
|
||||||
|
else:
|
||||||
|
output = await acall_func_with_variable_args(
|
||||||
|
cast(Callable, afunc), cast(Input, final), config, run_manager, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
# If the output is a runnable, use its astream output
|
# If the output is a runnable, use its astream output
|
||||||
if isinstance(output, Runnable):
|
if isinstance(output, Runnable):
|
||||||
@ -3072,9 +3247,9 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
),
|
),
|
||||||
):
|
):
|
||||||
yield chunk
|
yield chunk
|
||||||
else:
|
elif not inspect.isasyncgenfunction(afunc):
|
||||||
# Otherwise, just yield it
|
# Otherwise, just yield it
|
||||||
yield output
|
yield cast(Output, output)
|
||||||
|
|
||||||
async def atransform(
|
async def atransform(
|
||||||
self,
|
self,
|
||||||
@ -3699,3 +3874,69 @@ def coerce_to_runnable(thing: RunnableLike) -> Runnable[Input, Output]:
|
|||||||
f"Expected a Runnable, callable or dict."
|
f"Expected a Runnable, callable or dict."
|
||||||
f"Instead got an unsupported type: {type(thing)}"
|
f"Instead got an unsupported type: {type(thing)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def chain(
|
||||||
|
func: Callable[[Input], Coroutine[Any, Any, Output]],
|
||||||
|
) -> Runnable[Input, Output]:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def chain(
|
||||||
|
func: Callable[[Input], Iterator[Output]],
|
||||||
|
) -> Runnable[Input, Output]:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def chain(
|
||||||
|
func: Callable[[Input], AsyncIterator[Output]],
|
||||||
|
) -> Runnable[Input, Output]:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def chain(
|
||||||
|
func: Callable[[Input], Output],
|
||||||
|
) -> Runnable[Input, Output]:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
def chain(
|
||||||
|
func: Union[
|
||||||
|
Callable[[Input], Output],
|
||||||
|
Callable[[Input], Iterator[Output]],
|
||||||
|
Callable[[Input], Coroutine[Any, Any, Output]],
|
||||||
|
Callable[[Input], AsyncIterator[Output]],
|
||||||
|
],
|
||||||
|
) -> Runnable[Input, Output]:
|
||||||
|
"""Decorate a function to make it a Runnable.
|
||||||
|
Sets the name of the runnable to the name of the function.
|
||||||
|
Any runnables called by the function will be traced as dependencies.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func: A callable.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A Runnable.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from langchain_core.runnables import chain
|
||||||
|
from langchain_core.prompts import PromptTemplate
|
||||||
|
from langchain.llms import OpenAI
|
||||||
|
|
||||||
|
@chain
|
||||||
|
def my_func(fields):
|
||||||
|
prompt = PromptTemplate("Hello, {name}!")
|
||||||
|
llm = OpenAI()
|
||||||
|
formatted = prompt.invoke(**fields)
|
||||||
|
|
||||||
|
for chunk in llm.stream(formatted):
|
||||||
|
yield chunk
|
||||||
|
"""
|
||||||
|
return RunnableLambda(func)
|
||||||
|
@ -323,7 +323,7 @@ def call_func_with_variable_args(
|
|||||||
return func(input, **kwargs) # type: ignore[call-arg]
|
return func(input, **kwargs) # type: ignore[call-arg]
|
||||||
|
|
||||||
|
|
||||||
async def acall_func_with_variable_args(
|
def acall_func_with_variable_args(
|
||||||
func: Union[
|
func: Union[
|
||||||
Callable[[Input], Awaitable[Output]],
|
Callable[[Input], Awaitable[Output]],
|
||||||
Callable[[Input, RunnableConfig], Awaitable[Output]],
|
Callable[[Input, RunnableConfig], Awaitable[Output]],
|
||||||
@ -337,7 +337,7 @@ async def acall_func_with_variable_args(
|
|||||||
config: RunnableConfig,
|
config: RunnableConfig,
|
||||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Output:
|
) -> Awaitable[Output]:
|
||||||
"""Call function that may optionally accept a run_manager and/or config.
|
"""Call function that may optionally accept a run_manager and/or config.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -361,7 +361,7 @@ async def acall_func_with_variable_args(
|
|||||||
kwargs["config"] = config
|
kwargs["config"] = config
|
||||||
if run_manager is not None and accepts_run_manager(func):
|
if run_manager is not None and accepts_run_manager(func):
|
||||||
kwargs["run_manager"] = run_manager
|
kwargs["run_manager"] = run_manager
|
||||||
return await func(input, **kwargs) # type: ignore[call-arg]
|
return func(input, **kwargs) # type: ignore[call-arg]
|
||||||
|
|
||||||
|
|
||||||
def get_callback_manager_for_config(config: RunnableConfig) -> CallbackManager:
|
def get_callback_manager_for_config(config: RunnableConfig) -> CallbackManager:
|
||||||
|
@ -68,6 +68,14 @@ def accepts_config(callable: Callable[..., Any]) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def accepts_context(callable: Callable[..., Any]) -> bool:
|
||||||
|
"""Check if a callable accepts a context argument."""
|
||||||
|
try:
|
||||||
|
return signature(callable).parameters.get("context") is not None
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class IsLocalDict(ast.NodeVisitor):
|
class IsLocalDict(ast.NodeVisitor):
|
||||||
"""Check if a name is a local dict."""
|
"""Check if a name is a local dict."""
|
||||||
|
|
||||||
|
@ -12,7 +12,7 @@ from langchain_core.runnables.utils import aadd, add
|
|||||||
from tests.unit_tests.fake.llm import FakeListLLM, FakeStreamingListLLM
|
from tests.unit_tests.fake.llm import FakeListLLM, FakeStreamingListLLM
|
||||||
|
|
||||||
|
|
||||||
class TestCase(NamedTuple):
|
class _TestCase(NamedTuple):
|
||||||
input: Any
|
input: Any
|
||||||
output: Any
|
output: Any
|
||||||
|
|
||||||
@ -102,22 +102,22 @@ test_cases = [
|
|||||||
(
|
(
|
||||||
Context.setter("foo") | Context.getter("foo"),
|
Context.setter("foo") | Context.getter("foo"),
|
||||||
(
|
(
|
||||||
TestCase("foo", "foo"),
|
_TestCase("foo", "foo"),
|
||||||
TestCase("bar", "bar"),
|
_TestCase("bar", "bar"),
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
Context.setter("input") | {"bar": Context.getter("input")},
|
Context.setter("input") | {"bar": Context.getter("input")},
|
||||||
(
|
(
|
||||||
TestCase("foo", {"bar": "foo"}),
|
_TestCase("foo", {"bar": "foo"}),
|
||||||
TestCase("bar", {"bar": "bar"}),
|
_TestCase("bar", {"bar": "bar"}),
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
{"bar": Context.setter("input")} | Context.getter("input"),
|
{"bar": Context.setter("input")} | Context.getter("input"),
|
||||||
(
|
(
|
||||||
TestCase("foo", "foo"),
|
_TestCase("foo", "foo"),
|
||||||
TestCase("bar", "bar"),
|
_TestCase("bar", "bar"),
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
@ -132,11 +132,11 @@ test_cases = [
|
|||||||
}
|
}
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
TestCase(
|
_TestCase(
|
||||||
{"foo": "foo", "bar": "bar"},
|
{"foo": "foo", "bar": "bar"},
|
||||||
{"response": "hello", "prompt": StringPromptValue(text="foo bar")},
|
{"response": "hello", "prompt": StringPromptValue(text="foo bar")},
|
||||||
),
|
),
|
||||||
TestCase(
|
_TestCase(
|
||||||
{"foo": "bar", "bar": "foo"},
|
{"foo": "bar", "bar": "foo"},
|
||||||
{"response": "hello", "prompt": StringPromptValue(text="bar foo")},
|
{"response": "hello", "prompt": StringPromptValue(text="bar foo")},
|
||||||
),
|
),
|
||||||
@ -155,7 +155,7 @@ test_cases = [
|
|||||||
}
|
}
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
TestCase(
|
_TestCase(
|
||||||
{"foo": "foo", "bar": "bar"},
|
{"foo": "foo", "bar": "bar"},
|
||||||
{
|
{
|
||||||
"response": "hello",
|
"response": "hello",
|
||||||
@ -163,7 +163,7 @@ test_cases = [
|
|||||||
"prompt_str": "foo bar",
|
"prompt_str": "foo bar",
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
TestCase(
|
_TestCase(
|
||||||
{"foo": "bar", "bar": "foo"},
|
{"foo": "bar", "bar": "foo"},
|
||||||
{
|
{
|
||||||
"response": "hello",
|
"response": "hello",
|
||||||
@ -185,11 +185,11 @@ test_cases = [
|
|||||||
}
|
}
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
TestCase(
|
_TestCase(
|
||||||
{"foo": "foo", "bar": "bar"},
|
{"foo": "foo", "bar": "bar"},
|
||||||
{"response": "hello", "prompt_str": "foo bar"},
|
{"response": "hello", "prompt_str": "foo bar"},
|
||||||
),
|
),
|
||||||
TestCase(
|
_TestCase(
|
||||||
{"foo": "bar", "bar": "foo"},
|
{"foo": "bar", "bar": "foo"},
|
||||||
{"response": "hello", "prompt_str": "bar foo"},
|
{"response": "hello", "prompt_str": "bar foo"},
|
||||||
),
|
),
|
||||||
@ -207,11 +207,11 @@ test_cases = [
|
|||||||
}
|
}
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
TestCase(
|
_TestCase(
|
||||||
{"foo": "foo", "bar": "bar"},
|
{"foo": "foo", "bar": "bar"},
|
||||||
{"response": "hello", "prompt_str": "foo bar"},
|
{"response": "hello", "prompt_str": "foo bar"},
|
||||||
),
|
),
|
||||||
TestCase(
|
_TestCase(
|
||||||
{"foo": "bar", "bar": "foo"},
|
{"foo": "bar", "bar": "foo"},
|
||||||
{"response": "hello", "prompt_str": "bar foo"},
|
{"response": "hello", "prompt_str": "bar foo"},
|
||||||
),
|
),
|
||||||
@ -229,11 +229,11 @@ test_cases = [
|
|||||||
}
|
}
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
TestCase(
|
_TestCase(
|
||||||
{"foo": "foo", "bar": "bar"},
|
{"foo": "foo", "bar": "bar"},
|
||||||
{"response": "hello", "prompt": StringPromptValue(text="foo bar")},
|
{"response": "hello", "prompt": StringPromptValue(text="foo bar")},
|
||||||
),
|
),
|
||||||
TestCase(
|
_TestCase(
|
||||||
{"foo": "bar", "bar": "foo"},
|
{"foo": "bar", "bar": "foo"},
|
||||||
{"response": "hello", "prompt": StringPromptValue(text="bar foo")},
|
{"response": "hello", "prompt": StringPromptValue(text="bar foo")},
|
||||||
),
|
),
|
||||||
@ -242,7 +242,7 @@ test_cases = [
|
|||||||
(
|
(
|
||||||
seq_naive_rag,
|
seq_naive_rag,
|
||||||
(
|
(
|
||||||
TestCase(
|
_TestCase(
|
||||||
"What up",
|
"What up",
|
||||||
{
|
{
|
||||||
"result": "hello",
|
"result": "hello",
|
||||||
@ -254,7 +254,7 @@ test_cases = [
|
|||||||
"input": "What up",
|
"input": "What up",
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
TestCase(
|
_TestCase(
|
||||||
"Howdy",
|
"Howdy",
|
||||||
{
|
{
|
||||||
"result": "hello",
|
"result": "hello",
|
||||||
@ -271,7 +271,7 @@ test_cases = [
|
|||||||
(
|
(
|
||||||
seq_naive_rag_alt,
|
seq_naive_rag_alt,
|
||||||
(
|
(
|
||||||
TestCase(
|
_TestCase(
|
||||||
"What up",
|
"What up",
|
||||||
{
|
{
|
||||||
"result": "hello",
|
"result": "hello",
|
||||||
@ -283,7 +283,7 @@ test_cases = [
|
|||||||
"input": "What up",
|
"input": "What up",
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
TestCase(
|
_TestCase(
|
||||||
"Howdy",
|
"Howdy",
|
||||||
{
|
{
|
||||||
"result": "hello",
|
"result": "hello",
|
||||||
@ -300,7 +300,7 @@ test_cases = [
|
|||||||
(
|
(
|
||||||
seq_naive_rag_scoped,
|
seq_naive_rag_scoped,
|
||||||
(
|
(
|
||||||
TestCase(
|
_TestCase(
|
||||||
"What up",
|
"What up",
|
||||||
{
|
{
|
||||||
"result": "hello",
|
"result": "hello",
|
||||||
@ -312,7 +312,7 @@ test_cases = [
|
|||||||
"input": "What up",
|
"input": "What up",
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
TestCase(
|
_TestCase(
|
||||||
"Howdy",
|
"Howdy",
|
||||||
{
|
{
|
||||||
"result": "hello",
|
"result": "hello",
|
||||||
@ -331,7 +331,7 @@ test_cases = [
|
|||||||
|
|
||||||
@pytest.mark.parametrize("runnable, cases", test_cases)
|
@pytest.mark.parametrize("runnable, cases", test_cases)
|
||||||
async def test_context_runnables(
|
async def test_context_runnables(
|
||||||
runnable: Union[Runnable, Callable[[], Runnable]], cases: List[TestCase]
|
runnable: Union[Runnable, Callable[[], Runnable]], cases: List[_TestCase]
|
||||||
) -> None:
|
) -> None:
|
||||||
runnable = runnable if isinstance(runnable, Runnable) else runnable()
|
runnable = runnable if isinstance(runnable, Runnable) else runnable()
|
||||||
assert runnable.invoke(cases[0].input) == cases[0].output
|
assert runnable.invoke(cases[0].input) == cases[0].output
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
from langchain_core.runnables import __all__
|
from langchain_core.runnables import __all__
|
||||||
|
|
||||||
EXPECTED_ALL = [
|
EXPECTED_ALL = [
|
||||||
|
"chain",
|
||||||
"AddableDict",
|
"AddableDict",
|
||||||
"ConfigurableField",
|
"ConfigurableField",
|
||||||
"ConfigurableFieldSingleOption",
|
"ConfigurableFieldSingleOption",
|
||||||
|
@ -68,6 +68,7 @@ from langchain_core.runnables import (
|
|||||||
RunnableSequence,
|
RunnableSequence,
|
||||||
RunnableWithFallbacks,
|
RunnableWithFallbacks,
|
||||||
add,
|
add,
|
||||||
|
chain,
|
||||||
)
|
)
|
||||||
from langchain_core.tools import BaseTool, tool
|
from langchain_core.tools import BaseTool, tool
|
||||||
from langchain_core.tracers import (
|
from langchain_core.tracers import (
|
||||||
@ -4388,9 +4389,9 @@ async def test_runnable_gen() -> None:
|
|||||||
|
|
||||||
runnable = RunnableGenerator(gen)
|
runnable = RunnableGenerator(gen)
|
||||||
|
|
||||||
assert runnable.input_schema.schema() == {"title": "RunnableGeneratorInput"}
|
assert runnable.input_schema.schema() == {"title": "gen_input"}
|
||||||
assert runnable.output_schema.schema() == {
|
assert runnable.output_schema.schema() == {
|
||||||
"title": "RunnableGeneratorOutput",
|
"title": "gen_output",
|
||||||
"type": "integer",
|
"type": "integer",
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -4410,6 +4411,315 @@ async def test_runnable_gen() -> None:
|
|||||||
assert await arunnable.abatch([None, None]) == [6, 6]
|
assert await arunnable.abatch([None, None]) == [6, 6]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_runnable_gen_context_config() -> None:
|
||||||
|
"""Test that a generator can call other runnables with config
|
||||||
|
propagated from the context."""
|
||||||
|
|
||||||
|
fake = RunnableLambda(len)
|
||||||
|
|
||||||
|
def gen(input: Iterator[Any]) -> Iterator[int]:
|
||||||
|
yield fake.invoke("a")
|
||||||
|
yield fake.invoke("aa")
|
||||||
|
yield fake.invoke("aaa")
|
||||||
|
|
||||||
|
runnable = RunnableGenerator(gen)
|
||||||
|
|
||||||
|
assert runnable.input_schema.schema() == {"title": "gen_input"}
|
||||||
|
assert runnable.output_schema.schema() == {
|
||||||
|
"title": "gen_output",
|
||||||
|
"type": "integer",
|
||||||
|
}
|
||||||
|
|
||||||
|
tracer = FakeTracer()
|
||||||
|
assert runnable.invoke(None, {"callbacks": [tracer]}) == 6
|
||||||
|
assert len(tracer.runs) == 1
|
||||||
|
assert tracer.runs[0].outputs == {"output": 6}
|
||||||
|
assert len(tracer.runs[0].child_runs) == 3
|
||||||
|
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
|
||||||
|
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
|
||||||
|
tracer.runs.clear()
|
||||||
|
|
||||||
|
assert list(runnable.stream(None)) == [1, 2, 3]
|
||||||
|
assert len(tracer.runs) == 0, "callbacks doesn't persist from previous call"
|
||||||
|
|
||||||
|
tracer = FakeTracer()
|
||||||
|
assert list(runnable.stream(None, {"callbacks": [tracer]})) == [1, 2, 3]
|
||||||
|
assert len(tracer.runs) == 1
|
||||||
|
assert tracer.runs[0].outputs == {"output": 6}
|
||||||
|
assert len(tracer.runs[0].child_runs) == 3
|
||||||
|
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
|
||||||
|
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
|
||||||
|
|
||||||
|
tracer = FakeTracer()
|
||||||
|
assert runnable.batch([None, None], {"callbacks": [tracer]}) == [6, 6]
|
||||||
|
assert len(tracer.runs) == 2
|
||||||
|
assert tracer.runs[0].outputs == {"output": 6}
|
||||||
|
assert tracer.runs[1].outputs == {"output": 6}
|
||||||
|
assert len(tracer.runs[0].child_runs) == 3
|
||||||
|
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
|
||||||
|
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
|
||||||
|
assert len(tracer.runs[1].child_runs) == 3
|
||||||
|
assert [r.inputs["input"] for r in tracer.runs[1].child_runs] == ["a", "aa", "aaa"]
|
||||||
|
assert [(r.outputs or {})["output"] for r in tracer.runs[1].child_runs] == [1, 2, 3]
|
||||||
|
|
||||||
|
if sys.version_info < (3, 11):
|
||||||
|
# Python 3.10 and below don't support running async tasks in a specific context
|
||||||
|
return
|
||||||
|
|
||||||
|
async def agen(input: AsyncIterator[Any]) -> AsyncIterator[int]:
|
||||||
|
yield await fake.ainvoke("a")
|
||||||
|
yield await fake.ainvoke("aa")
|
||||||
|
yield await fake.ainvoke("aaa")
|
||||||
|
|
||||||
|
arunnable = RunnableGenerator(agen)
|
||||||
|
|
||||||
|
tracer = FakeTracer()
|
||||||
|
assert await arunnable.ainvoke(None, {"callbacks": [tracer]}) == 6
|
||||||
|
assert len(tracer.runs) == 1
|
||||||
|
assert tracer.runs[0].outputs == {"output": 6}
|
||||||
|
assert len(tracer.runs[0].child_runs) == 3
|
||||||
|
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
|
||||||
|
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
|
||||||
|
tracer.runs.clear()
|
||||||
|
|
||||||
|
assert [p async for p in arunnable.astream(None)] == [1, 2, 3]
|
||||||
|
assert len(tracer.runs) == 0, "callbacks doesn't persist from previous call"
|
||||||
|
|
||||||
|
tracer = FakeTracer()
|
||||||
|
assert [p async for p in arunnable.astream(None, {"callbacks": [tracer]})] == [
|
||||||
|
1,
|
||||||
|
2,
|
||||||
|
3,
|
||||||
|
]
|
||||||
|
assert len(tracer.runs) == 1
|
||||||
|
assert tracer.runs[0].outputs == {"output": 6}
|
||||||
|
assert len(tracer.runs[0].child_runs) == 3
|
||||||
|
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
|
||||||
|
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
|
||||||
|
|
||||||
|
tracer = FakeTracer()
|
||||||
|
assert await arunnable.abatch([None, None], {"callbacks": [tracer]}) == [6, 6]
|
||||||
|
assert len(tracer.runs) == 2
|
||||||
|
assert tracer.runs[0].outputs == {"output": 6}
|
||||||
|
assert tracer.runs[1].outputs == {"output": 6}
|
||||||
|
assert len(tracer.runs[0].child_runs) == 3
|
||||||
|
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
|
||||||
|
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
|
||||||
|
assert len(tracer.runs[1].child_runs) == 3
|
||||||
|
assert [r.inputs["input"] for r in tracer.runs[1].child_runs] == ["a", "aa", "aaa"]
|
||||||
|
assert [(r.outputs or {})["output"] for r in tracer.runs[1].child_runs] == [1, 2, 3]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_runnable_iter_context_config() -> None:
|
||||||
|
"""Test that a generator can call other runnables with config
|
||||||
|
propagated from the context."""
|
||||||
|
|
||||||
|
fake = RunnableLambda(len)
|
||||||
|
|
||||||
|
@chain
|
||||||
|
def gen(input: str) -> Iterator[int]:
|
||||||
|
yield fake.invoke(input)
|
||||||
|
yield fake.invoke(input * 2)
|
||||||
|
yield fake.invoke(input * 3)
|
||||||
|
|
||||||
|
assert gen.input_schema.schema() == {
|
||||||
|
"title": "gen_input",
|
||||||
|
"type": "string",
|
||||||
|
}
|
||||||
|
assert gen.output_schema.schema() == {
|
||||||
|
"title": "gen_output",
|
||||||
|
"type": "integer",
|
||||||
|
}
|
||||||
|
|
||||||
|
tracer = FakeTracer()
|
||||||
|
assert gen.invoke("a", {"callbacks": [tracer]}) == 6
|
||||||
|
assert len(tracer.runs) == 1
|
||||||
|
assert tracer.runs[0].outputs == {"output": 6}
|
||||||
|
assert len(tracer.runs[0].child_runs) == 3
|
||||||
|
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
|
||||||
|
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
|
||||||
|
tracer.runs.clear()
|
||||||
|
|
||||||
|
assert list(gen.stream("a")) == [1, 2, 3]
|
||||||
|
assert len(tracer.runs) == 0, "callbacks doesn't persist from previous call"
|
||||||
|
|
||||||
|
tracer = FakeTracer()
|
||||||
|
assert list(gen.stream("a", {"callbacks": [tracer]})) == [1, 2, 3]
|
||||||
|
assert len(tracer.runs) == 1
|
||||||
|
assert tracer.runs[0].outputs == {"output": 6}
|
||||||
|
assert len(tracer.runs[0].child_runs) == 3
|
||||||
|
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
|
||||||
|
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
|
||||||
|
|
||||||
|
tracer = FakeTracer()
|
||||||
|
assert gen.batch(["a", "a"], {"callbacks": [tracer]}) == [6, 6]
|
||||||
|
assert len(tracer.runs) == 2
|
||||||
|
assert tracer.runs[0].outputs == {"output": 6}
|
||||||
|
assert tracer.runs[1].outputs == {"output": 6}
|
||||||
|
assert len(tracer.runs[0].child_runs) == 3
|
||||||
|
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
|
||||||
|
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
|
||||||
|
assert len(tracer.runs[1].child_runs) == 3
|
||||||
|
assert [r.inputs["input"] for r in tracer.runs[1].child_runs] == ["a", "aa", "aaa"]
|
||||||
|
assert [(r.outputs or {})["output"] for r in tracer.runs[1].child_runs] == [1, 2, 3]
|
||||||
|
|
||||||
|
if sys.version_info < (3, 11):
|
||||||
|
# Python 3.10 and below don't support running async tasks in a specific context
|
||||||
|
return
|
||||||
|
|
||||||
|
@chain
|
||||||
|
async def agen(input: str) -> AsyncIterator[int]:
|
||||||
|
yield await fake.ainvoke(input)
|
||||||
|
yield await fake.ainvoke(input * 2)
|
||||||
|
yield await fake.ainvoke(input * 3)
|
||||||
|
|
||||||
|
assert agen.input_schema.schema() == {
|
||||||
|
"title": "agen_input",
|
||||||
|
"type": "string",
|
||||||
|
}
|
||||||
|
assert agen.output_schema.schema() == {
|
||||||
|
"title": "agen_output",
|
||||||
|
"type": "integer",
|
||||||
|
}
|
||||||
|
|
||||||
|
tracer = FakeTracer()
|
||||||
|
assert await agen.ainvoke("a", {"callbacks": [tracer]}) == 6
|
||||||
|
assert len(tracer.runs) == 1
|
||||||
|
assert tracer.runs[0].outputs == {"output": 6}
|
||||||
|
assert len(tracer.runs[0].child_runs) == 3
|
||||||
|
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
|
||||||
|
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
|
||||||
|
tracer.runs.clear()
|
||||||
|
|
||||||
|
assert [p async for p in agen.astream("a")] == [1, 2, 3]
|
||||||
|
assert len(tracer.runs) == 0, "callbacks doesn't persist from previous call"
|
||||||
|
|
||||||
|
tracer = FakeTracer()
|
||||||
|
assert [p async for p in agen.astream("a", {"callbacks": [tracer]})] == [
|
||||||
|
1,
|
||||||
|
2,
|
||||||
|
3,
|
||||||
|
]
|
||||||
|
assert len(tracer.runs) == 1
|
||||||
|
assert tracer.runs[0].outputs == {"output": 6}
|
||||||
|
assert len(tracer.runs[0].child_runs) == 3
|
||||||
|
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
|
||||||
|
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
|
||||||
|
|
||||||
|
tracer = FakeTracer()
|
||||||
|
assert await agen.abatch(["a", "a"], {"callbacks": [tracer]}) == [6, 6]
|
||||||
|
assert len(tracer.runs) == 2
|
||||||
|
assert tracer.runs[0].outputs == {"output": 6}
|
||||||
|
assert tracer.runs[1].outputs == {"output": 6}
|
||||||
|
assert len(tracer.runs[0].child_runs) == 3
|
||||||
|
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
|
||||||
|
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
|
||||||
|
assert len(tracer.runs[1].child_runs) == 3
|
||||||
|
assert [r.inputs["input"] for r in tracer.runs[1].child_runs] == ["a", "aa", "aaa"]
|
||||||
|
assert [(r.outputs or {})["output"] for r in tracer.runs[1].child_runs] == [1, 2, 3]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_runnable_lambda_context_config() -> None:
|
||||||
|
"""Test that a function can call other runnables with config
|
||||||
|
propagated from the context."""
|
||||||
|
|
||||||
|
fake = RunnableLambda(len)
|
||||||
|
|
||||||
|
@chain
|
||||||
|
def fun(input: str) -> int:
|
||||||
|
output = fake.invoke(input)
|
||||||
|
output += fake.invoke(input * 2)
|
||||||
|
output += fake.invoke(input * 3)
|
||||||
|
return output
|
||||||
|
|
||||||
|
assert fun.input_schema.schema() == {"title": "fun_input", "type": "string"}
|
||||||
|
assert fun.output_schema.schema() == {
|
||||||
|
"title": "fun_output",
|
||||||
|
"type": "integer",
|
||||||
|
}
|
||||||
|
|
||||||
|
tracer = FakeTracer()
|
||||||
|
assert fun.invoke("a", {"callbacks": [tracer]}) == 6
|
||||||
|
assert len(tracer.runs) == 1
|
||||||
|
assert tracer.runs[0].outputs == {"output": 6}
|
||||||
|
assert len(tracer.runs[0].child_runs) == 3
|
||||||
|
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
|
||||||
|
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
|
||||||
|
tracer.runs.clear()
|
||||||
|
|
||||||
|
assert list(fun.stream("a")) == [6]
|
||||||
|
assert len(tracer.runs) == 0, "callbacks doesn't persist from previous call"
|
||||||
|
|
||||||
|
tracer = FakeTracer()
|
||||||
|
assert list(fun.stream("a", {"callbacks": [tracer]})) == [6]
|
||||||
|
assert len(tracer.runs) == 1
|
||||||
|
assert tracer.runs[0].outputs == {"output": 6}
|
||||||
|
assert len(tracer.runs[0].child_runs) == 3
|
||||||
|
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
|
||||||
|
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
|
||||||
|
|
||||||
|
tracer = FakeTracer()
|
||||||
|
assert fun.batch(["a", "a"], {"callbacks": [tracer]}) == [6, 6]
|
||||||
|
assert len(tracer.runs) == 2
|
||||||
|
assert tracer.runs[0].outputs == {"output": 6}
|
||||||
|
assert tracer.runs[1].outputs == {"output": 6}
|
||||||
|
assert len(tracer.runs[0].child_runs) == 3
|
||||||
|
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
|
||||||
|
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
|
||||||
|
assert len(tracer.runs[1].child_runs) == 3
|
||||||
|
assert [r.inputs["input"] for r in tracer.runs[1].child_runs] == ["a", "aa", "aaa"]
|
||||||
|
assert [(r.outputs or {})["output"] for r in tracer.runs[1].child_runs] == [1, 2, 3]
|
||||||
|
|
||||||
|
if sys.version_info < (3, 11):
|
||||||
|
# Python 3.10 and below don't support running async tasks in a specific context
|
||||||
|
return
|
||||||
|
|
||||||
|
@chain
|
||||||
|
async def afun(input: str) -> int:
|
||||||
|
output = await fake.ainvoke(input)
|
||||||
|
output += await fake.ainvoke(input * 2)
|
||||||
|
output += await fake.ainvoke(input * 3)
|
||||||
|
return output
|
||||||
|
|
||||||
|
assert afun.input_schema.schema() == {"title": "afun_input", "type": "string"}
|
||||||
|
assert afun.output_schema.schema() == {
|
||||||
|
"title": "afun_output",
|
||||||
|
"type": "integer",
|
||||||
|
}
|
||||||
|
|
||||||
|
tracer = FakeTracer()
|
||||||
|
assert await afun.ainvoke("a", {"callbacks": [tracer]}) == 6
|
||||||
|
assert len(tracer.runs) == 1
|
||||||
|
assert tracer.runs[0].outputs == {"output": 6}
|
||||||
|
assert len(tracer.runs[0].child_runs) == 3
|
||||||
|
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
|
||||||
|
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
|
||||||
|
tracer.runs.clear()
|
||||||
|
|
||||||
|
assert [p async for p in afun.astream("a")] == [6]
|
||||||
|
assert len(tracer.runs) == 0, "callbacks doesn't persist from previous call"
|
||||||
|
|
||||||
|
tracer = FakeTracer()
|
||||||
|
assert [p async for p in afun.astream("a", {"callbacks": [tracer]})] == [6]
|
||||||
|
assert len(tracer.runs) == 1
|
||||||
|
assert tracer.runs[0].outputs == {"output": 6}
|
||||||
|
assert len(tracer.runs[0].child_runs) == 3
|
||||||
|
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
|
||||||
|
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
|
||||||
|
|
||||||
|
tracer = FakeTracer()
|
||||||
|
assert await afun.abatch(["a", "a"], {"callbacks": [tracer]}) == [6, 6]
|
||||||
|
assert len(tracer.runs) == 2
|
||||||
|
assert tracer.runs[0].outputs == {"output": 6}
|
||||||
|
assert tracer.runs[1].outputs == {"output": 6}
|
||||||
|
assert len(tracer.runs[0].child_runs) == 3
|
||||||
|
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
|
||||||
|
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
|
||||||
|
assert len(tracer.runs[1].child_runs) == 3
|
||||||
|
assert [r.inputs["input"] for r in tracer.runs[1].child_runs] == ["a", "aa", "aaa"]
|
||||||
|
assert [(r.outputs or {})["output"] for r in tracer.runs[1].child_runs] == [1, 2, 3]
|
||||||
|
|
||||||
|
|
||||||
async def test_runnable_gen_transform() -> None:
|
async def test_runnable_gen_transform() -> None:
|
||||||
"""Test that a generator can be used as a runnable."""
|
"""Test that a generator can be used as a runnable."""
|
||||||
|
|
||||||
@ -4434,19 +4744,19 @@ async def test_runnable_gen_transform() -> None:
|
|||||||
achain = RunnableGenerator(gen_indexes, agen_indexes) | aplus_one
|
achain = RunnableGenerator(gen_indexes, agen_indexes) | aplus_one
|
||||||
|
|
||||||
assert chain.input_schema.schema() == {
|
assert chain.input_schema.schema() == {
|
||||||
"title": "RunnableGeneratorInput",
|
"title": "gen_indexes_input",
|
||||||
"type": "integer",
|
"type": "integer",
|
||||||
}
|
}
|
||||||
assert chain.output_schema.schema() == {
|
assert chain.output_schema.schema() == {
|
||||||
"title": "RunnableGeneratorOutput",
|
"title": "plus_one_output",
|
||||||
"type": "integer",
|
"type": "integer",
|
||||||
}
|
}
|
||||||
assert achain.input_schema.schema() == {
|
assert achain.input_schema.schema() == {
|
||||||
"title": "RunnableGeneratorInput",
|
"title": "gen_indexes_input",
|
||||||
"type": "integer",
|
"type": "integer",
|
||||||
}
|
}
|
||||||
assert achain.output_schema.schema() == {
|
assert achain.output_schema.schema() == {
|
||||||
"title": "RunnableGeneratorOutput",
|
"title": "aplus_one_output",
|
||||||
"type": "integer",
|
"type": "integer",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user