diff --git a/libs/core/langchain_core/beta/runnables/context.py b/libs/core/langchain_core/beta/runnables/context.py deleted file mode 100644 index 2be721387cb..00000000000 --- a/libs/core/langchain_core/beta/runnables/context.py +++ /dev/null @@ -1,401 +0,0 @@ -import asyncio -import threading -from collections import defaultdict -from collections.abc import Awaitable, Mapping, Sequence -from functools import partial -from itertools import groupby -from typing import ( - Any, - Callable, - Optional, - TypeVar, - Union, -) - -from pydantic import ConfigDict - -from langchain_core._api.beta_decorator import beta -from langchain_core.runnables.base import ( - Runnable, - RunnableSerializable, - coerce_to_runnable, -) -from langchain_core.runnables.config import RunnableConfig, ensure_config, patch_config -from langchain_core.runnables.utils import ConfigurableFieldSpec, Input, Output - -T = TypeVar("T") -Values = dict[Union[asyncio.Event, threading.Event], Any] -CONTEXT_CONFIG_PREFIX = "__context__/" -CONTEXT_CONFIG_SUFFIX_GET = "/get" -CONTEXT_CONFIG_SUFFIX_SET = "/set" - - -async def _asetter(done: asyncio.Event, values: Values, value: T) -> T: - values[done] = value - done.set() - return value - - -async def _agetter(done: asyncio.Event, values: Values) -> Any: - await done.wait() - return values[done] - - -def _setter(done: threading.Event, values: Values, value: T) -> T: - values[done] = value - done.set() - return value - - -def _getter(done: threading.Event, values: Values) -> Any: - done.wait() - return values[done] - - -def _key_from_id(id_: str) -> str: - wout_prefix = id_.split(CONTEXT_CONFIG_PREFIX, maxsplit=1)[1] - if wout_prefix.endswith(CONTEXT_CONFIG_SUFFIX_GET): - return wout_prefix[: -len(CONTEXT_CONFIG_SUFFIX_GET)] - elif wout_prefix.endswith(CONTEXT_CONFIG_SUFFIX_SET): - return wout_prefix[: -len(CONTEXT_CONFIG_SUFFIX_SET)] - else: - msg = f"Invalid context config id {id_}" - raise ValueError(msg) - - -def _config_with_context( - config: RunnableConfig, - steps: list[Runnable], - setter: Callable, - getter: Callable, - event_cls: Union[type[threading.Event], type[asyncio.Event]], -) -> RunnableConfig: - if any(k.startswith(CONTEXT_CONFIG_PREFIX) for k in config.get("configurable", {})): - return config - - context_specs = [ - (spec, i) - for i, step in enumerate(steps) - for spec in step.config_specs - if spec.id.startswith(CONTEXT_CONFIG_PREFIX) - ] - grouped_by_key = { - key: list(group) - for key, group in groupby( - sorted(context_specs, key=lambda s: s[0].id), - key=lambda s: _key_from_id(s[0].id), - ) - } - deps_by_key = { - key: { - _key_from_id(dep) for spec in group for dep in (spec[0].dependencies or []) - } - for key, group in grouped_by_key.items() - } - - values: Values = {} - events: defaultdict[str, Union[asyncio.Event, threading.Event]] = defaultdict( - event_cls - ) - context_funcs: dict[str, Callable[[], Any]] = {} - for key, group in grouped_by_key.items(): - getters = [s for s in group if s[0].id.endswith(CONTEXT_CONFIG_SUFFIX_GET)] - setters = [s for s in group if s[0].id.endswith(CONTEXT_CONFIG_SUFFIX_SET)] - - for dep in deps_by_key[key]: - if key in deps_by_key[dep]: - msg = f"Deadlock detected between context keys {key} and {dep}" - raise ValueError(msg) - if len(setters) != 1: - msg = f"Expected exactly one setter for context key {key}" - raise ValueError(msg) - setter_idx = setters[0][1] - if any(getter_idx < setter_idx for _, getter_idx in getters): - msg = f"Context setter for key {key} must be defined after all getters." - raise ValueError(msg) - - if getters: - context_funcs[getters[0][0].id] = partial(getter, events[key], values) - context_funcs[setters[0][0].id] = partial(setter, events[key], values) - - return patch_config(config, configurable=context_funcs) - - -def aconfig_with_context( - config: RunnableConfig, - steps: list[Runnable], -) -> RunnableConfig: - """Asynchronously patch a runnable config with context getters and setters. - - Args: - config: The runnable config. - steps: The runnable steps. - - Returns: - The patched runnable config. - """ - return _config_with_context(config, steps, _asetter, _agetter, asyncio.Event) - - -def config_with_context( - config: RunnableConfig, - steps: list[Runnable], -) -> RunnableConfig: - """Patch a runnable config with context getters and setters. - - Args: - config: The runnable config. - steps: The runnable steps. - - Returns: - The patched runnable config. - """ - return _config_with_context(config, steps, _setter, _getter, threading.Event) - - -@beta() -class ContextGet(RunnableSerializable): - """Get a context value.""" - - prefix: str = "" - - key: Union[str, list[str]] - - def __str__(self) -> str: - return f"ContextGet({_print_keys(self.key)})" - - @property - def ids(self) -> list[str]: - prefix = self.prefix + "/" if self.prefix else "" - keys = self.key if isinstance(self.key, list) else [self.key] - return [ - f"{CONTEXT_CONFIG_PREFIX}{prefix}{k}{CONTEXT_CONFIG_SUFFIX_GET}" - for k in keys - ] - - @property - def config_specs(self) -> list[ConfigurableFieldSpec]: - return super().config_specs + [ - ConfigurableFieldSpec( - id=id_, - annotation=Callable[[], Any], - ) - for id_ in self.ids - ] - - def invoke( - self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any - ) -> Any: - config = ensure_config(config) - configurable = config.get("configurable", {}) - if isinstance(self.key, list): - return {key: configurable[id_]() for key, id_ in zip(self.key, self.ids)} - else: - return configurable[self.ids[0]]() - - async def ainvoke( - self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any - ) -> Any: - config = ensure_config(config) - configurable = config.get("configurable", {}) - if isinstance(self.key, list): - values = await asyncio.gather(*(configurable[id_]() for id_ in self.ids)) - return dict(zip(self.key, values)) - else: - return await configurable[self.ids[0]]() - - -SetValue = Union[ - Runnable[Input, Output], - Callable[[Input], Output], - Callable[[Input], Awaitable[Output]], - Any, -] - - -def _coerce_set_value(value: SetValue) -> Runnable[Input, Output]: - if not isinstance(value, Runnable) and not callable(value): - return coerce_to_runnable(lambda _: value) - return coerce_to_runnable(value) - - -@beta() -class ContextSet(RunnableSerializable): - """Set a context value.""" - - prefix: str = "" - - keys: Mapping[str, Optional[Runnable]] - - model_config = ConfigDict( - arbitrary_types_allowed=True, - ) - - def __init__( - self, - key: Optional[str] = None, - value: Optional[SetValue] = None, - prefix: str = "", - **kwargs: SetValue, - ): - if key is not None: - kwargs[key] = value - super().__init__( # type: ignore[call-arg] - keys={ - k: _coerce_set_value(v) if v is not None else None - for k, v in kwargs.items() - }, - prefix=prefix, - ) - - def __str__(self) -> str: - return f"ContextSet({_print_keys(list(self.keys.keys()))})" - - @property - def ids(self) -> list[str]: - prefix = self.prefix + "/" if self.prefix else "" - return [ - f"{CONTEXT_CONFIG_PREFIX}{prefix}{key}{CONTEXT_CONFIG_SUFFIX_SET}" - for key in self.keys - ] - - @property - def config_specs(self) -> list[ConfigurableFieldSpec]: - mapper_config_specs = [ - s - for mapper in self.keys.values() - if mapper is not None - for s in mapper.config_specs - ] - for spec in mapper_config_specs: - if spec.id.endswith(CONTEXT_CONFIG_SUFFIX_GET): - getter_key = spec.id.split("/")[1] - if getter_key in self.keys: - msg = f"Circular reference in context setter for key {getter_key}" - raise ValueError(msg) - return super().config_specs + [ - ConfigurableFieldSpec( - id=id_, - annotation=Callable[[], Any], - ) - for id_ in self.ids - ] - - def invoke( - self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any - ) -> Any: - config = ensure_config(config) - configurable = config.get("configurable", {}) - for id_, mapper in zip(self.ids, self.keys.values()): - if mapper is not None: - configurable[id_](mapper.invoke(input, config)) - else: - configurable[id_](input) - return input - - async def ainvoke( - self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any - ) -> Any: - config = ensure_config(config) - configurable = config.get("configurable", {}) - for id_, mapper in zip(self.ids, self.keys.values()): - if mapper is not None: - await configurable[id_](await mapper.ainvoke(input, config)) - else: - await configurable[id_](input) - return input - - -class Context: - """Context for a runnable. - - The `Context` class provides methods for creating context scopes, - getters, and setters within a runnable. It allows for managing - and accessing contextual information throughout the execution - of a program. - - Example: - .. code-block:: python - - from langchain_core.beta.runnables.context import Context - from langchain_core.runnables.passthrough import RunnablePassthrough - from langchain_core.prompts.prompt import PromptTemplate - from langchain_core.output_parsers.string import StrOutputParser - from tests.unit_tests.fake.llm import FakeListLLM - - chain = ( - Context.setter("input") - | { - "context": RunnablePassthrough() - | Context.setter("context"), - "question": RunnablePassthrough(), - } - | PromptTemplate.from_template("{context} {question}") - | FakeListLLM(responses=["hello"]) - | StrOutputParser() - | { - "result": RunnablePassthrough(), - "context": Context.getter("context"), - "input": Context.getter("input"), - } - ) - - # Use the chain - output = chain.invoke("What's your name?") - print(output["result"]) # Output: "hello" - print(output["context"]) # Output: "What's your name?" - print(output["input"]) # Output: "What's your name? - """ - - @staticmethod - def create_scope(scope: str, /) -> "PrefixContext": - """Create a context scope. - - Args: - scope: The scope. - - Returns: - The context scope. - """ - return PrefixContext(prefix=scope) - - @staticmethod - def getter(key: Union[str, list[str]], /) -> ContextGet: - return ContextGet(key=key) - - @staticmethod - def setter( - _key: Optional[str] = None, - _value: Optional[SetValue] = None, - /, - **kwargs: SetValue, - ) -> ContextSet: - return ContextSet(_key, _value, prefix="", **kwargs) - - -class PrefixContext: - """Context for a runnable with a prefix.""" - - prefix: str = "" - - def __init__(self, prefix: str = ""): - self.prefix = prefix - - def getter(self, key: Union[str, list[str]], /) -> ContextGet: - return ContextGet(key=key, prefix=self.prefix) - - def setter( - self, - _key: Optional[str] = None, - _value: Optional[SetValue] = None, - /, - **kwargs: SetValue, - ) -> ContextSet: - return ContextSet(_key, _value, prefix=self.prefix, **kwargs) - - -def _print_keys(keys: Union[str, Sequence[str]]) -> str: - if isinstance(keys, str): - return f"'{keys}'" - else: - return ", ".join(f"'{k}'" for k in keys) diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index 4fbb1f96201..f9527df277d 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -18,7 +18,7 @@ from collections.abc import ( ) from concurrent.futures import FIRST_COMPLETED, wait from functools import wraps -from itertools import groupby, tee +from itertools import tee from operator import itemgetter from types import GenericAlias from typing import ( @@ -2858,50 +2858,10 @@ class RunnableSequence(RunnableSerializable[Input, Output]): Returns: The config specs of the Runnable. """ - from langchain_core.beta.runnables.context import ( - CONTEXT_CONFIG_PREFIX, - _key_from_id, + return get_unique_config_specs( + [spec for step in self.steps for spec in step.config_specs] ) - # get all specs - all_specs = [ - (spec, idx) - for idx, step in enumerate(self.steps) - for spec in step.config_specs - ] - # calculate context dependencies - specs_by_pos = groupby( - [tup for tup in all_specs if tup[0].id.startswith(CONTEXT_CONFIG_PREFIX)], - itemgetter(1), - ) - next_deps: set[str] = set() - deps_by_pos: dict[int, set[str]] = {} - for pos, specs in specs_by_pos: - deps_by_pos[pos] = next_deps - next_deps = next_deps | {spec[0].id for spec in specs} - # assign context dependencies - for pos, (spec, idx) in enumerate(all_specs): - if spec.id.startswith(CONTEXT_CONFIG_PREFIX): - all_specs[pos] = ( - ConfigurableFieldSpec( - id=spec.id, - annotation=spec.annotation, - name=spec.name, - default=spec.default, - description=spec.description, - is_shared=spec.is_shared, - dependencies=[ - d - for d in deps_by_pos[idx] - if _key_from_id(d) != _key_from_id(spec.id) - ] - + (spec.dependencies or []), - ), - idx, - ) - - return get_unique_config_specs(spec for spec, _ in all_specs) - def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph: """Get the graph representation of the Runnable. @@ -2998,10 +2958,8 @@ class RunnableSequence(RunnableSerializable[Input, Output]): def invoke( self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Output: - from langchain_core.beta.runnables.context import config_with_context - # setup callbacks and context - config = config_with_context(ensure_config(config), self.steps) + config = ensure_config(config) callback_manager = get_callback_manager_for_config(config) # start the root run run_manager = callback_manager.on_chain_start( @@ -3037,10 +2995,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]): config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], ) -> Output: - from langchain_core.beta.runnables.context import aconfig_with_context - - # setup callbacks and context - config = aconfig_with_context(ensure_config(config), self.steps) + config = ensure_config(config) callback_manager = get_async_callback_manager_for_config(config) # start the root run run_manager = await callback_manager.on_chain_start( @@ -3082,17 +3037,13 @@ class RunnableSequence(RunnableSerializable[Input, Output]): return_exceptions: bool = False, **kwargs: Optional[Any], ) -> list[Output]: - from langchain_core.beta.runnables.context import config_with_context from langchain_core.callbacks.manager import CallbackManager if not inputs: return [] # setup callbacks and context - configs = [ - config_with_context(c, self.steps) - for c in get_config_list(config, len(inputs)) - ] + configs = get_config_list(config, len(inputs)) callback_managers = [ CallbackManager.configure( inheritable_callbacks=config.get("callbacks"), @@ -3209,17 +3160,13 @@ class RunnableSequence(RunnableSerializable[Input, Output]): return_exceptions: bool = False, **kwargs: Optional[Any], ) -> list[Output]: - from langchain_core.beta.runnables.context import aconfig_with_context from langchain_core.callbacks.manager import AsyncCallbackManager if not inputs: return [] # setup callbacks and context - configs = [ - aconfig_with_context(c, self.steps) - for c in get_config_list(config, len(inputs)) - ] + configs = get_config_list(config, len(inputs)) callback_managers = [ AsyncCallbackManager.configure( inheritable_callbacks=config.get("callbacks"), @@ -3338,10 +3285,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]): config: RunnableConfig, **kwargs: Any, ) -> Iterator[Output]: - from langchain_core.beta.runnables.context import config_with_context - steps = [self.first] + self.middle + [self.last] - config = config_with_context(config, self.steps) # transform the input stream of each step with the next # steps that don't natively support transforming an input stream will @@ -3365,10 +3309,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]): config: RunnableConfig, **kwargs: Any, ) -> AsyncIterator[Output]: - from langchain_core.beta.runnables.context import aconfig_with_context - steps = [self.first] + self.middle + [self.last] - config = aconfig_with_context(config, self.steps) # stream the last steps # transform the input stream of each step with the next diff --git a/libs/core/langchain_core/runnables/branch.py b/libs/core/langchain_core/runnables/branch.py index 56c43886189..24903490aff 100644 --- a/libs/core/langchain_core/runnables/branch.py +++ b/libs/core/langchain_core/runnables/branch.py @@ -168,11 +168,6 @@ class RunnableBranch(RunnableSerializable[Input, Output]): @property def config_specs(self) -> list[ConfigurableFieldSpec]: - from langchain_core.beta.runnables.context import ( - CONTEXT_CONFIG_PREFIX, - CONTEXT_CONFIG_SUFFIX_SET, - ) - specs = get_unique_config_specs( spec for step in ( @@ -182,13 +177,6 @@ class RunnableBranch(RunnableSerializable[Input, Output]): ) for spec in step.config_specs ) - if any( - s.id.startswith(CONTEXT_CONFIG_PREFIX) - and s.id.endswith(CONTEXT_CONFIG_SUFFIX_SET) - for s in specs - ): - msg = "RunnableBranch cannot contain context setters." - raise ValueError(msg) return specs def invoke( diff --git a/libs/core/tests/unit_tests/runnables/test_context.py b/libs/core/tests/unit_tests/runnables/test_context.py deleted file mode 100644 index cb8de2dd808..00000000000 --- a/libs/core/tests/unit_tests/runnables/test_context.py +++ /dev/null @@ -1,418 +0,0 @@ -import asyncio -from typing import Any, Callable, NamedTuple, Union - -import pytest - -from langchain_core.beta.runnables.context import Context -from langchain_core.language_models import FakeListLLM, FakeStreamingListLLM -from langchain_core.output_parsers.string import StrOutputParser -from langchain_core.prompt_values import StringPromptValue -from langchain_core.prompts.prompt import PromptTemplate -from langchain_core.runnables.base import Runnable, RunnableLambda -from langchain_core.runnables.passthrough import RunnablePassthrough -from langchain_core.runnables.utils import aadd, add - - -class _TestCase(NamedTuple): - input: Any - output: Any - - -def seq_naive_rag() -> Runnable: - context = [ - "Hi there!", - "How are you?", - "What's your name?", - ] - - retriever = RunnableLambda(lambda x: context) - prompt = PromptTemplate.from_template("{context} {question}") - llm = FakeListLLM(responses=["hello"]) - - return ( - Context.setter("input") - | { - "context": retriever | Context.setter("context"), - "question": RunnablePassthrough(), - } - | prompt - | llm - | StrOutputParser() - | { - "result": RunnablePassthrough(), - "context": Context.getter("context"), - "input": Context.getter("input"), - } - ) - - -def seq_naive_rag_alt() -> Runnable: - context = [ - "Hi there!", - "How are you?", - "What's your name?", - ] - - retriever = RunnableLambda(lambda x: context) - prompt = PromptTemplate.from_template("{context} {question}") - llm = FakeListLLM(responses=["hello"]) - - return ( - Context.setter("input") - | { - "context": retriever | Context.setter("context"), - "question": RunnablePassthrough(), - } - | prompt - | llm - | StrOutputParser() - | Context.setter("result") - | Context.getter(["context", "input", "result"]) - ) - - -def seq_naive_rag_scoped() -> Runnable: - context = [ - "Hi there!", - "How are you?", - "What's your name?", - ] - - retriever = RunnableLambda(lambda x: context) - prompt = PromptTemplate.from_template("{context} {question}") - llm = FakeListLLM(responses=["hello"]) - - scoped = Context.create_scope("a_scope") - - return ( - Context.setter("input") - | { - "context": retriever | Context.setter("context"), - "question": RunnablePassthrough(), - "scoped": scoped.setter("context") | scoped.getter("context"), - } - | prompt - | llm - | StrOutputParser() - | Context.setter("result") - | Context.getter(["context", "input", "result"]) - ) - - -test_cases = [ - ( - Context.setter("foo") | Context.getter("foo"), - ( - _TestCase("foo", "foo"), - _TestCase("bar", "bar"), - ), - ), - ( - Context.setter("input") | {"bar": Context.getter("input")}, - ( - _TestCase("foo", {"bar": "foo"}), - _TestCase("bar", {"bar": "bar"}), - ), - ), - ( - {"bar": Context.setter("input")} | Context.getter("input"), - ( - _TestCase("foo", "foo"), - _TestCase("bar", "bar"), - ), - ), - ( - ( - PromptTemplate.from_template("{foo} {bar}") - | Context.setter("prompt") - | FakeListLLM(responses=["hello"]) - | StrOutputParser() - | { - "response": RunnablePassthrough(), - "prompt": Context.getter("prompt"), - } - ), - ( - _TestCase( - {"foo": "foo", "bar": "bar"}, - {"response": "hello", "prompt": StringPromptValue(text="foo bar")}, - ), - _TestCase( - {"foo": "bar", "bar": "foo"}, - {"response": "hello", "prompt": StringPromptValue(text="bar foo")}, - ), - ), - ), - ( - ( - PromptTemplate.from_template("{foo} {bar}") - | Context.setter("prompt", prompt_str=lambda x: x.to_string()) - | FakeListLLM(responses=["hello"]) - | StrOutputParser() - | { - "response": RunnablePassthrough(), - "prompt": Context.getter("prompt"), - "prompt_str": Context.getter("prompt_str"), - } - ), - ( - _TestCase( - {"foo": "foo", "bar": "bar"}, - { - "response": "hello", - "prompt": StringPromptValue(text="foo bar"), - "prompt_str": "foo bar", - }, - ), - _TestCase( - {"foo": "bar", "bar": "foo"}, - { - "response": "hello", - "prompt": StringPromptValue(text="bar foo"), - "prompt_str": "bar foo", - }, - ), - ), - ), - ( - ( - PromptTemplate.from_template("{foo} {bar}") - | Context.setter(prompt_str=lambda x: x.to_string()) - | FakeListLLM(responses=["hello"]) - | StrOutputParser() - | { - "response": RunnablePassthrough(), - "prompt_str": Context.getter("prompt_str"), - } - ), - ( - _TestCase( - {"foo": "foo", "bar": "bar"}, - {"response": "hello", "prompt_str": "foo bar"}, - ), - _TestCase( - {"foo": "bar", "bar": "foo"}, - {"response": "hello", "prompt_str": "bar foo"}, - ), - ), - ), - ( - ( - PromptTemplate.from_template("{foo} {bar}") - | Context.setter("prompt_str", lambda x: x.to_string()) - | FakeListLLM(responses=["hello"]) - | StrOutputParser() - | { - "response": RunnablePassthrough(), - "prompt_str": Context.getter("prompt_str"), - } - ), - ( - _TestCase( - {"foo": "foo", "bar": "bar"}, - {"response": "hello", "prompt_str": "foo bar"}, - ), - _TestCase( - {"foo": "bar", "bar": "foo"}, - {"response": "hello", "prompt_str": "bar foo"}, - ), - ), - ), - ( - ( - PromptTemplate.from_template("{foo} {bar}") - | Context.setter("prompt") - | FakeStreamingListLLM(responses=["hello"]) - | StrOutputParser() - | { - "response": RunnablePassthrough(), - "prompt": Context.getter("prompt"), - } - ), - ( - _TestCase( - {"foo": "foo", "bar": "bar"}, - {"response": "hello", "prompt": StringPromptValue(text="foo bar")}, - ), - _TestCase( - {"foo": "bar", "bar": "foo"}, - {"response": "hello", "prompt": StringPromptValue(text="bar foo")}, - ), - ), - ), - ( - seq_naive_rag, - ( - _TestCase( - "What up", - { - "result": "hello", - "context": [ - "Hi there!", - "How are you?", - "What's your name?", - ], - "input": "What up", - }, - ), - _TestCase( - "Howdy", - { - "result": "hello", - "context": [ - "Hi there!", - "How are you?", - "What's your name?", - ], - "input": "Howdy", - }, - ), - ), - ), - ( - seq_naive_rag_alt, - ( - _TestCase( - "What up", - { - "result": "hello", - "context": [ - "Hi there!", - "How are you?", - "What's your name?", - ], - "input": "What up", - }, - ), - _TestCase( - "Howdy", - { - "result": "hello", - "context": [ - "Hi there!", - "How are you?", - "What's your name?", - ], - "input": "Howdy", - }, - ), - ), - ), - ( - seq_naive_rag_scoped, - ( - _TestCase( - "What up", - { - "result": "hello", - "context": [ - "Hi there!", - "How are you?", - "What's your name?", - ], - "input": "What up", - }, - ), - _TestCase( - "Howdy", - { - "result": "hello", - "context": [ - "Hi there!", - "How are you?", - "What's your name?", - ], - "input": "Howdy", - }, - ), - ), - ), -] - - -@pytest.mark.parametrize("runnable, cases", test_cases) -def test_context_runnables( - runnable: Union[Runnable, Callable[[], Runnable]], cases: list[_TestCase] -) -> None: - runnable = runnable if isinstance(runnable, Runnable) else runnable() - assert runnable.invoke(cases[0].input) == cases[0].output - assert runnable.batch([case.input for case in cases]) == [ - case.output for case in cases - ] - assert add(runnable.stream(cases[0].input)) == cases[0].output - - -@pytest.mark.parametrize("runnable, cases", test_cases) -async def test_context_runnables_async( - runnable: Union[Runnable, Callable[[], Runnable]], cases: list[_TestCase] -) -> None: - runnable = runnable if isinstance(runnable, Runnable) else runnable() - assert await runnable.ainvoke(cases[1].input) == cases[1].output - assert await runnable.abatch([case.input for case in cases]) == [ - case.output for case in cases - ] - assert await aadd(runnable.astream(cases[1].input)) == cases[1].output - - -def test_runnable_context_seq_key_not_found() -> None: - seq: Runnable = {"bar": Context.setter("input")} | Context.getter("foo") - - with pytest.raises(ValueError): - seq.invoke("foo") - - -def test_runnable_context_seq_key_order() -> None: - seq: Runnable = {"bar": Context.getter("foo")} | Context.setter("foo") - - with pytest.raises(ValueError): - seq.invoke("foo") - - -def test_runnable_context_deadlock() -> None: - seq: Runnable = { - "bar": Context.setter("input") | Context.getter("foo"), - "foo": Context.setter("foo") | Context.getter("input"), - } | RunnablePassthrough() - - with pytest.raises(ValueError): - seq.invoke("foo") - - -def test_runnable_context_seq_key_circular_ref() -> None: - seq: Runnable = { - "bar": Context.setter(input=Context.getter("input")) - } | Context.getter("foo") - - with pytest.raises(ValueError): - seq.invoke("foo") - - -async def test_runnable_seq_streaming_chunks() -> None: - chain: Runnable = ( - PromptTemplate.from_template("{foo} {bar}") - | Context.setter("prompt") - | FakeStreamingListLLM(responses=["hello"]) - | StrOutputParser() - | { - "response": RunnablePassthrough(), - "prompt": Context.getter("prompt"), - } - ) - chunks = await asyncio.to_thread(list, chain.stream({"foo": "foo", "bar": "bar"})) - achunks = [c async for c in chain.astream({"foo": "foo", "bar": "bar"})] - for c in chunks: - assert c in achunks - for c in achunks: - assert c in chunks - - assert len(chunks) == 6 - assert [c for c in chunks if c.get("response")] == [ - {"response": "h"}, - {"response": "e"}, - {"response": "l"}, - {"response": "l"}, - {"response": "o"}, - ] - assert [c for c in chunks if c.get("prompt")] == [ - {"prompt": StringPromptValue(text="foo bar")}, - ]