Remove unused Context API

This commit is contained in:
William Fu-Hinthorn 2025-03-19 12:07:20 -07:00
parent 8265be4d3e
commit a3e8a7fd17
4 changed files with 7 additions and 897 deletions

View File

@ -1,401 +0,0 @@
import asyncio
import threading
from collections import defaultdict
from collections.abc import Awaitable, Mapping, Sequence
from functools import partial
from itertools import groupby
from typing import (
Any,
Callable,
Optional,
TypeVar,
Union,
)
from pydantic import ConfigDict
from langchain_core._api.beta_decorator import beta
from langchain_core.runnables.base import (
Runnable,
RunnableSerializable,
coerce_to_runnable,
)
from langchain_core.runnables.config import RunnableConfig, ensure_config, patch_config
from langchain_core.runnables.utils import ConfigurableFieldSpec, Input, Output
T = TypeVar("T")
Values = dict[Union[asyncio.Event, threading.Event], Any]
CONTEXT_CONFIG_PREFIX = "__context__/"
CONTEXT_CONFIG_SUFFIX_GET = "/get"
CONTEXT_CONFIG_SUFFIX_SET = "/set"
async def _asetter(done: asyncio.Event, values: Values, value: T) -> T:
values[done] = value
done.set()
return value
async def _agetter(done: asyncio.Event, values: Values) -> Any:
await done.wait()
return values[done]
def _setter(done: threading.Event, values: Values, value: T) -> T:
values[done] = value
done.set()
return value
def _getter(done: threading.Event, values: Values) -> Any:
done.wait()
return values[done]
def _key_from_id(id_: str) -> str:
wout_prefix = id_.split(CONTEXT_CONFIG_PREFIX, maxsplit=1)[1]
if wout_prefix.endswith(CONTEXT_CONFIG_SUFFIX_GET):
return wout_prefix[: -len(CONTEXT_CONFIG_SUFFIX_GET)]
elif wout_prefix.endswith(CONTEXT_CONFIG_SUFFIX_SET):
return wout_prefix[: -len(CONTEXT_CONFIG_SUFFIX_SET)]
else:
msg = f"Invalid context config id {id_}"
raise ValueError(msg)
def _config_with_context(
config: RunnableConfig,
steps: list[Runnable],
setter: Callable,
getter: Callable,
event_cls: Union[type[threading.Event], type[asyncio.Event]],
) -> RunnableConfig:
if any(k.startswith(CONTEXT_CONFIG_PREFIX) for k in config.get("configurable", {})):
return config
context_specs = [
(spec, i)
for i, step in enumerate(steps)
for spec in step.config_specs
if spec.id.startswith(CONTEXT_CONFIG_PREFIX)
]
grouped_by_key = {
key: list(group)
for key, group in groupby(
sorted(context_specs, key=lambda s: s[0].id),
key=lambda s: _key_from_id(s[0].id),
)
}
deps_by_key = {
key: {
_key_from_id(dep) for spec in group for dep in (spec[0].dependencies or [])
}
for key, group in grouped_by_key.items()
}
values: Values = {}
events: defaultdict[str, Union[asyncio.Event, threading.Event]] = defaultdict(
event_cls
)
context_funcs: dict[str, Callable[[], Any]] = {}
for key, group in grouped_by_key.items():
getters = [s for s in group if s[0].id.endswith(CONTEXT_CONFIG_SUFFIX_GET)]
setters = [s for s in group if s[0].id.endswith(CONTEXT_CONFIG_SUFFIX_SET)]
for dep in deps_by_key[key]:
if key in deps_by_key[dep]:
msg = f"Deadlock detected between context keys {key} and {dep}"
raise ValueError(msg)
if len(setters) != 1:
msg = f"Expected exactly one setter for context key {key}"
raise ValueError(msg)
setter_idx = setters[0][1]
if any(getter_idx < setter_idx for _, getter_idx in getters):
msg = f"Context setter for key {key} must be defined after all getters."
raise ValueError(msg)
if getters:
context_funcs[getters[0][0].id] = partial(getter, events[key], values)
context_funcs[setters[0][0].id] = partial(setter, events[key], values)
return patch_config(config, configurable=context_funcs)
def aconfig_with_context(
config: RunnableConfig,
steps: list[Runnable],
) -> RunnableConfig:
"""Asynchronously patch a runnable config with context getters and setters.
Args:
config: The runnable config.
steps: The runnable steps.
Returns:
The patched runnable config.
"""
return _config_with_context(config, steps, _asetter, _agetter, asyncio.Event)
def config_with_context(
config: RunnableConfig,
steps: list[Runnable],
) -> RunnableConfig:
"""Patch a runnable config with context getters and setters.
Args:
config: The runnable config.
steps: The runnable steps.
Returns:
The patched runnable config.
"""
return _config_with_context(config, steps, _setter, _getter, threading.Event)
@beta()
class ContextGet(RunnableSerializable):
"""Get a context value."""
prefix: str = ""
key: Union[str, list[str]]
def __str__(self) -> str:
return f"ContextGet({_print_keys(self.key)})"
@property
def ids(self) -> list[str]:
prefix = self.prefix + "/" if self.prefix else ""
keys = self.key if isinstance(self.key, list) else [self.key]
return [
f"{CONTEXT_CONFIG_PREFIX}{prefix}{k}{CONTEXT_CONFIG_SUFFIX_GET}"
for k in keys
]
@property
def config_specs(self) -> list[ConfigurableFieldSpec]:
return super().config_specs + [
ConfigurableFieldSpec(
id=id_,
annotation=Callable[[], Any],
)
for id_ in self.ids
]
def invoke(
self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Any:
config = ensure_config(config)
configurable = config.get("configurable", {})
if isinstance(self.key, list):
return {key: configurable[id_]() for key, id_ in zip(self.key, self.ids)}
else:
return configurable[self.ids[0]]()
async def ainvoke(
self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Any:
config = ensure_config(config)
configurable = config.get("configurable", {})
if isinstance(self.key, list):
values = await asyncio.gather(*(configurable[id_]() for id_ in self.ids))
return dict(zip(self.key, values))
else:
return await configurable[self.ids[0]]()
SetValue = Union[
Runnable[Input, Output],
Callable[[Input], Output],
Callable[[Input], Awaitable[Output]],
Any,
]
def _coerce_set_value(value: SetValue) -> Runnable[Input, Output]:
if not isinstance(value, Runnable) and not callable(value):
return coerce_to_runnable(lambda _: value)
return coerce_to_runnable(value)
@beta()
class ContextSet(RunnableSerializable):
"""Set a context value."""
prefix: str = ""
keys: Mapping[str, Optional[Runnable]]
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
def __init__(
self,
key: Optional[str] = None,
value: Optional[SetValue] = None,
prefix: str = "",
**kwargs: SetValue,
):
if key is not None:
kwargs[key] = value
super().__init__( # type: ignore[call-arg]
keys={
k: _coerce_set_value(v) if v is not None else None
for k, v in kwargs.items()
},
prefix=prefix,
)
def __str__(self) -> str:
return f"ContextSet({_print_keys(list(self.keys.keys()))})"
@property
def ids(self) -> list[str]:
prefix = self.prefix + "/" if self.prefix else ""
return [
f"{CONTEXT_CONFIG_PREFIX}{prefix}{key}{CONTEXT_CONFIG_SUFFIX_SET}"
for key in self.keys
]
@property
def config_specs(self) -> list[ConfigurableFieldSpec]:
mapper_config_specs = [
s
for mapper in self.keys.values()
if mapper is not None
for s in mapper.config_specs
]
for spec in mapper_config_specs:
if spec.id.endswith(CONTEXT_CONFIG_SUFFIX_GET):
getter_key = spec.id.split("/")[1]
if getter_key in self.keys:
msg = f"Circular reference in context setter for key {getter_key}"
raise ValueError(msg)
return super().config_specs + [
ConfigurableFieldSpec(
id=id_,
annotation=Callable[[], Any],
)
for id_ in self.ids
]
def invoke(
self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Any:
config = ensure_config(config)
configurable = config.get("configurable", {})
for id_, mapper in zip(self.ids, self.keys.values()):
if mapper is not None:
configurable[id_](mapper.invoke(input, config))
else:
configurable[id_](input)
return input
async def ainvoke(
self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Any:
config = ensure_config(config)
configurable = config.get("configurable", {})
for id_, mapper in zip(self.ids, self.keys.values()):
if mapper is not None:
await configurable[id_](await mapper.ainvoke(input, config))
else:
await configurable[id_](input)
return input
class Context:
"""Context for a runnable.
The `Context` class provides methods for creating context scopes,
getters, and setters within a runnable. It allows for managing
and accessing contextual information throughout the execution
of a program.
Example:
.. code-block:: python
from langchain_core.beta.runnables.context import Context
from langchain_core.runnables.passthrough import RunnablePassthrough
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.output_parsers.string import StrOutputParser
from tests.unit_tests.fake.llm import FakeListLLM
chain = (
Context.setter("input")
| {
"context": RunnablePassthrough()
| Context.setter("context"),
"question": RunnablePassthrough(),
}
| PromptTemplate.from_template("{context} {question}")
| FakeListLLM(responses=["hello"])
| StrOutputParser()
| {
"result": RunnablePassthrough(),
"context": Context.getter("context"),
"input": Context.getter("input"),
}
)
# Use the chain
output = chain.invoke("What's your name?")
print(output["result"]) # Output: "hello"
print(output["context"]) # Output: "What's your name?"
print(output["input"]) # Output: "What's your name?
"""
@staticmethod
def create_scope(scope: str, /) -> "PrefixContext":
"""Create a context scope.
Args:
scope: The scope.
Returns:
The context scope.
"""
return PrefixContext(prefix=scope)
@staticmethod
def getter(key: Union[str, list[str]], /) -> ContextGet:
return ContextGet(key=key)
@staticmethod
def setter(
_key: Optional[str] = None,
_value: Optional[SetValue] = None,
/,
**kwargs: SetValue,
) -> ContextSet:
return ContextSet(_key, _value, prefix="", **kwargs)
class PrefixContext:
"""Context for a runnable with a prefix."""
prefix: str = ""
def __init__(self, prefix: str = ""):
self.prefix = prefix
def getter(self, key: Union[str, list[str]], /) -> ContextGet:
return ContextGet(key=key, prefix=self.prefix)
def setter(
self,
_key: Optional[str] = None,
_value: Optional[SetValue] = None,
/,
**kwargs: SetValue,
) -> ContextSet:
return ContextSet(_key, _value, prefix=self.prefix, **kwargs)
def _print_keys(keys: Union[str, Sequence[str]]) -> str:
if isinstance(keys, str):
return f"'{keys}'"
else:
return ", ".join(f"'{k}'" for k in keys)

View File

@ -18,7 +18,7 @@ from collections.abc import (
) )
from concurrent.futures import FIRST_COMPLETED, wait from concurrent.futures import FIRST_COMPLETED, wait
from functools import wraps from functools import wraps
from itertools import groupby, tee from itertools import tee
from operator import itemgetter from operator import itemgetter
from types import GenericAlias from types import GenericAlias
from typing import ( from typing import (
@ -2858,50 +2858,10 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
Returns: Returns:
The config specs of the Runnable. The config specs of the Runnable.
""" """
from langchain_core.beta.runnables.context import ( return get_unique_config_specs(
CONTEXT_CONFIG_PREFIX, [spec for step in self.steps for spec in step.config_specs]
_key_from_id,
) )
# get all specs
all_specs = [
(spec, idx)
for idx, step in enumerate(self.steps)
for spec in step.config_specs
]
# calculate context dependencies
specs_by_pos = groupby(
[tup for tup in all_specs if tup[0].id.startswith(CONTEXT_CONFIG_PREFIX)],
itemgetter(1),
)
next_deps: set[str] = set()
deps_by_pos: dict[int, set[str]] = {}
for pos, specs in specs_by_pos:
deps_by_pos[pos] = next_deps
next_deps = next_deps | {spec[0].id for spec in specs}
# assign context dependencies
for pos, (spec, idx) in enumerate(all_specs):
if spec.id.startswith(CONTEXT_CONFIG_PREFIX):
all_specs[pos] = (
ConfigurableFieldSpec(
id=spec.id,
annotation=spec.annotation,
name=spec.name,
default=spec.default,
description=spec.description,
is_shared=spec.is_shared,
dependencies=[
d
for d in deps_by_pos[idx]
if _key_from_id(d) != _key_from_id(spec.id)
]
+ (spec.dependencies or []),
),
idx,
)
return get_unique_config_specs(spec for spec, _ in all_specs)
def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph: def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph:
"""Get the graph representation of the Runnable. """Get the graph representation of the Runnable.
@ -2998,10 +2958,8 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
def invoke( def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output: ) -> Output:
from langchain_core.beta.runnables.context import config_with_context
# setup callbacks and context # setup callbacks and context
config = config_with_context(ensure_config(config), self.steps) config = ensure_config(config)
callback_manager = get_callback_manager_for_config(config) callback_manager = get_callback_manager_for_config(config)
# start the root run # start the root run
run_manager = callback_manager.on_chain_start( run_manager = callback_manager.on_chain_start(
@ -3037,10 +2995,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> Output: ) -> Output:
from langchain_core.beta.runnables.context import aconfig_with_context config = ensure_config(config)
# setup callbacks and context
config = aconfig_with_context(ensure_config(config), self.steps)
callback_manager = get_async_callback_manager_for_config(config) callback_manager = get_async_callback_manager_for_config(config)
# start the root run # start the root run
run_manager = await callback_manager.on_chain_start( run_manager = await callback_manager.on_chain_start(
@ -3082,17 +3037,13 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
return_exceptions: bool = False, return_exceptions: bool = False,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> list[Output]: ) -> list[Output]:
from langchain_core.beta.runnables.context import config_with_context
from langchain_core.callbacks.manager import CallbackManager from langchain_core.callbacks.manager import CallbackManager
if not inputs: if not inputs:
return [] return []
# setup callbacks and context # setup callbacks and context
configs = [ configs = get_config_list(config, len(inputs))
config_with_context(c, self.steps)
for c in get_config_list(config, len(inputs))
]
callback_managers = [ callback_managers = [
CallbackManager.configure( CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"), inheritable_callbacks=config.get("callbacks"),
@ -3209,17 +3160,13 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
return_exceptions: bool = False, return_exceptions: bool = False,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> list[Output]: ) -> list[Output]:
from langchain_core.beta.runnables.context import aconfig_with_context
from langchain_core.callbacks.manager import AsyncCallbackManager from langchain_core.callbacks.manager import AsyncCallbackManager
if not inputs: if not inputs:
return [] return []
# setup callbacks and context # setup callbacks and context
configs = [ configs = get_config_list(config, len(inputs))
aconfig_with_context(c, self.steps)
for c in get_config_list(config, len(inputs))
]
callback_managers = [ callback_managers = [
AsyncCallbackManager.configure( AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"), inheritable_callbacks=config.get("callbacks"),
@ -3338,10 +3285,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
config: RunnableConfig, config: RunnableConfig,
**kwargs: Any, **kwargs: Any,
) -> Iterator[Output]: ) -> Iterator[Output]:
from langchain_core.beta.runnables.context import config_with_context
steps = [self.first] + self.middle + [self.last] steps = [self.first] + self.middle + [self.last]
config = config_with_context(config, self.steps)
# transform the input stream of each step with the next # transform the input stream of each step with the next
# steps that don't natively support transforming an input stream will # steps that don't natively support transforming an input stream will
@ -3365,10 +3309,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
config: RunnableConfig, config: RunnableConfig,
**kwargs: Any, **kwargs: Any,
) -> AsyncIterator[Output]: ) -> AsyncIterator[Output]:
from langchain_core.beta.runnables.context import aconfig_with_context
steps = [self.first] + self.middle + [self.last] steps = [self.first] + self.middle + [self.last]
config = aconfig_with_context(config, self.steps)
# stream the last steps # stream the last steps
# transform the input stream of each step with the next # transform the input stream of each step with the next

View File

@ -168,11 +168,6 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
@property @property
def config_specs(self) -> list[ConfigurableFieldSpec]: def config_specs(self) -> list[ConfigurableFieldSpec]:
from langchain_core.beta.runnables.context import (
CONTEXT_CONFIG_PREFIX,
CONTEXT_CONFIG_SUFFIX_SET,
)
specs = get_unique_config_specs( specs = get_unique_config_specs(
spec spec
for step in ( for step in (
@ -182,13 +177,6 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
) )
for spec in step.config_specs for spec in step.config_specs
) )
if any(
s.id.startswith(CONTEXT_CONFIG_PREFIX)
and s.id.endswith(CONTEXT_CONFIG_SUFFIX_SET)
for s in specs
):
msg = "RunnableBranch cannot contain context setters."
raise ValueError(msg)
return specs return specs
def invoke( def invoke(

View File

@ -1,418 +0,0 @@
import asyncio
from typing import Any, Callable, NamedTuple, Union
import pytest
from langchain_core.beta.runnables.context import Context
from langchain_core.language_models import FakeListLLM, FakeStreamingListLLM
from langchain_core.output_parsers.string import StrOutputParser
from langchain_core.prompt_values import StringPromptValue
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.runnables.base import Runnable, RunnableLambda
from langchain_core.runnables.passthrough import RunnablePassthrough
from langchain_core.runnables.utils import aadd, add
class _TestCase(NamedTuple):
input: Any
output: Any
def seq_naive_rag() -> Runnable:
context = [
"Hi there!",
"How are you?",
"What's your name?",
]
retriever = RunnableLambda(lambda x: context)
prompt = PromptTemplate.from_template("{context} {question}")
llm = FakeListLLM(responses=["hello"])
return (
Context.setter("input")
| {
"context": retriever | Context.setter("context"),
"question": RunnablePassthrough(),
}
| prompt
| llm
| StrOutputParser()
| {
"result": RunnablePassthrough(),
"context": Context.getter("context"),
"input": Context.getter("input"),
}
)
def seq_naive_rag_alt() -> Runnable:
context = [
"Hi there!",
"How are you?",
"What's your name?",
]
retriever = RunnableLambda(lambda x: context)
prompt = PromptTemplate.from_template("{context} {question}")
llm = FakeListLLM(responses=["hello"])
return (
Context.setter("input")
| {
"context": retriever | Context.setter("context"),
"question": RunnablePassthrough(),
}
| prompt
| llm
| StrOutputParser()
| Context.setter("result")
| Context.getter(["context", "input", "result"])
)
def seq_naive_rag_scoped() -> Runnable:
context = [
"Hi there!",
"How are you?",
"What's your name?",
]
retriever = RunnableLambda(lambda x: context)
prompt = PromptTemplate.from_template("{context} {question}")
llm = FakeListLLM(responses=["hello"])
scoped = Context.create_scope("a_scope")
return (
Context.setter("input")
| {
"context": retriever | Context.setter("context"),
"question": RunnablePassthrough(),
"scoped": scoped.setter("context") | scoped.getter("context"),
}
| prompt
| llm
| StrOutputParser()
| Context.setter("result")
| Context.getter(["context", "input", "result"])
)
test_cases = [
(
Context.setter("foo") | Context.getter("foo"),
(
_TestCase("foo", "foo"),
_TestCase("bar", "bar"),
),
),
(
Context.setter("input") | {"bar": Context.getter("input")},
(
_TestCase("foo", {"bar": "foo"}),
_TestCase("bar", {"bar": "bar"}),
),
),
(
{"bar": Context.setter("input")} | Context.getter("input"),
(
_TestCase("foo", "foo"),
_TestCase("bar", "bar"),
),
),
(
(
PromptTemplate.from_template("{foo} {bar}")
| Context.setter("prompt")
| FakeListLLM(responses=["hello"])
| StrOutputParser()
| {
"response": RunnablePassthrough(),
"prompt": Context.getter("prompt"),
}
),
(
_TestCase(
{"foo": "foo", "bar": "bar"},
{"response": "hello", "prompt": StringPromptValue(text="foo bar")},
),
_TestCase(
{"foo": "bar", "bar": "foo"},
{"response": "hello", "prompt": StringPromptValue(text="bar foo")},
),
),
),
(
(
PromptTemplate.from_template("{foo} {bar}")
| Context.setter("prompt", prompt_str=lambda x: x.to_string())
| FakeListLLM(responses=["hello"])
| StrOutputParser()
| {
"response": RunnablePassthrough(),
"prompt": Context.getter("prompt"),
"prompt_str": Context.getter("prompt_str"),
}
),
(
_TestCase(
{"foo": "foo", "bar": "bar"},
{
"response": "hello",
"prompt": StringPromptValue(text="foo bar"),
"prompt_str": "foo bar",
},
),
_TestCase(
{"foo": "bar", "bar": "foo"},
{
"response": "hello",
"prompt": StringPromptValue(text="bar foo"),
"prompt_str": "bar foo",
},
),
),
),
(
(
PromptTemplate.from_template("{foo} {bar}")
| Context.setter(prompt_str=lambda x: x.to_string())
| FakeListLLM(responses=["hello"])
| StrOutputParser()
| {
"response": RunnablePassthrough(),
"prompt_str": Context.getter("prompt_str"),
}
),
(
_TestCase(
{"foo": "foo", "bar": "bar"},
{"response": "hello", "prompt_str": "foo bar"},
),
_TestCase(
{"foo": "bar", "bar": "foo"},
{"response": "hello", "prompt_str": "bar foo"},
),
),
),
(
(
PromptTemplate.from_template("{foo} {bar}")
| Context.setter("prompt_str", lambda x: x.to_string())
| FakeListLLM(responses=["hello"])
| StrOutputParser()
| {
"response": RunnablePassthrough(),
"prompt_str": Context.getter("prompt_str"),
}
),
(
_TestCase(
{"foo": "foo", "bar": "bar"},
{"response": "hello", "prompt_str": "foo bar"},
),
_TestCase(
{"foo": "bar", "bar": "foo"},
{"response": "hello", "prompt_str": "bar foo"},
),
),
),
(
(
PromptTemplate.from_template("{foo} {bar}")
| Context.setter("prompt")
| FakeStreamingListLLM(responses=["hello"])
| StrOutputParser()
| {
"response": RunnablePassthrough(),
"prompt": Context.getter("prompt"),
}
),
(
_TestCase(
{"foo": "foo", "bar": "bar"},
{"response": "hello", "prompt": StringPromptValue(text="foo bar")},
),
_TestCase(
{"foo": "bar", "bar": "foo"},
{"response": "hello", "prompt": StringPromptValue(text="bar foo")},
),
),
),
(
seq_naive_rag,
(
_TestCase(
"What up",
{
"result": "hello",
"context": [
"Hi there!",
"How are you?",
"What's your name?",
],
"input": "What up",
},
),
_TestCase(
"Howdy",
{
"result": "hello",
"context": [
"Hi there!",
"How are you?",
"What's your name?",
],
"input": "Howdy",
},
),
),
),
(
seq_naive_rag_alt,
(
_TestCase(
"What up",
{
"result": "hello",
"context": [
"Hi there!",
"How are you?",
"What's your name?",
],
"input": "What up",
},
),
_TestCase(
"Howdy",
{
"result": "hello",
"context": [
"Hi there!",
"How are you?",
"What's your name?",
],
"input": "Howdy",
},
),
),
),
(
seq_naive_rag_scoped,
(
_TestCase(
"What up",
{
"result": "hello",
"context": [
"Hi there!",
"How are you?",
"What's your name?",
],
"input": "What up",
},
),
_TestCase(
"Howdy",
{
"result": "hello",
"context": [
"Hi there!",
"How are you?",
"What's your name?",
],
"input": "Howdy",
},
),
),
),
]
@pytest.mark.parametrize("runnable, cases", test_cases)
def test_context_runnables(
runnable: Union[Runnable, Callable[[], Runnable]], cases: list[_TestCase]
) -> None:
runnable = runnable if isinstance(runnable, Runnable) else runnable()
assert runnable.invoke(cases[0].input) == cases[0].output
assert runnable.batch([case.input for case in cases]) == [
case.output for case in cases
]
assert add(runnable.stream(cases[0].input)) == cases[0].output
@pytest.mark.parametrize("runnable, cases", test_cases)
async def test_context_runnables_async(
runnable: Union[Runnable, Callable[[], Runnable]], cases: list[_TestCase]
) -> None:
runnable = runnable if isinstance(runnable, Runnable) else runnable()
assert await runnable.ainvoke(cases[1].input) == cases[1].output
assert await runnable.abatch([case.input for case in cases]) == [
case.output for case in cases
]
assert await aadd(runnable.astream(cases[1].input)) == cases[1].output
def test_runnable_context_seq_key_not_found() -> None:
seq: Runnable = {"bar": Context.setter("input")} | Context.getter("foo")
with pytest.raises(ValueError):
seq.invoke("foo")
def test_runnable_context_seq_key_order() -> None:
seq: Runnable = {"bar": Context.getter("foo")} | Context.setter("foo")
with pytest.raises(ValueError):
seq.invoke("foo")
def test_runnable_context_deadlock() -> None:
seq: Runnable = {
"bar": Context.setter("input") | Context.getter("foo"),
"foo": Context.setter("foo") | Context.getter("input"),
} | RunnablePassthrough()
with pytest.raises(ValueError):
seq.invoke("foo")
def test_runnable_context_seq_key_circular_ref() -> None:
seq: Runnable = {
"bar": Context.setter(input=Context.getter("input"))
} | Context.getter("foo")
with pytest.raises(ValueError):
seq.invoke("foo")
async def test_runnable_seq_streaming_chunks() -> None:
chain: Runnable = (
PromptTemplate.from_template("{foo} {bar}")
| Context.setter("prompt")
| FakeStreamingListLLM(responses=["hello"])
| StrOutputParser()
| {
"response": RunnablePassthrough(),
"prompt": Context.getter("prompt"),
}
)
chunks = await asyncio.to_thread(list, chain.stream({"foo": "foo", "bar": "bar"}))
achunks = [c async for c in chain.astream({"foo": "foo", "bar": "bar"})]
for c in chunks:
assert c in achunks
for c in achunks:
assert c in chunks
assert len(chunks) == 6
assert [c for c in chunks if c.get("response")] == [
{"response": "h"},
{"response": "e"},
{"response": "l"},
{"response": "l"},
{"response": "o"},
]
assert [c for c in chunks if c.get("prompt")] == [
{"prompt": StringPromptValue(text="foo bar")},
]