Remove unused Context API

William Fu-Hinthorn 2025-03-19 12:07:20 -07:00
parent 8265be4d3e
commit a3e8a7fd17
4 changed files with 7 additions and 897 deletions

View File

@@ -1,401 +0,0 @@
import asyncio
import threading
from collections import defaultdict
from collections.abc import Awaitable, Mapping, Sequence
from functools import partial
from itertools import groupby
from typing import (
Any,
Callable,
Optional,
TypeVar,
Union,
)
from pydantic import ConfigDict
from langchain_core._api.beta_decorator import beta
from langchain_core.runnables.base import (
Runnable,
RunnableSerializable,
coerce_to_runnable,
)
from langchain_core.runnables.config import RunnableConfig, ensure_config, patch_config
from langchain_core.runnables.utils import ConfigurableFieldSpec, Input, Output
T = TypeVar("T")
Values = dict[Union[asyncio.Event, threading.Event], Any]
CONTEXT_CONFIG_PREFIX = "__context__/"
CONTEXT_CONFIG_SUFFIX_GET = "/get"
CONTEXT_CONFIG_SUFFIX_SET = "/set"
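# A context value travels through a per-key Event: the setter stores the value
# and fires the event; getters block on the event before reading, so getters
# can safely run in steps that execute after the setter.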
async def _asetter(done: asyncio.Event, values: Values, value: T) -> T:
values[done] = value
done.set()
return value
async def _agetter(done: asyncio.Event, values: Values) -> Any:
await done.wait()
return values[done]
def _setter(done: threading.Event, values: Values, value: T) -> T:
values[done] = value
done.set()
return value
def _getter(done: threading.Event, values: Values) -> Any:
done.wait()
return values[done]
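# Map a spec id such as "__context__/<key>/get" or "__context__/<key>/set"
# back to the bare "<key>".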
def _key_from_id(id_: str) -> str:
wout_prefix = id_.split(CONTEXT_CONFIG_PREFIX, maxsplit=1)[1]
if wout_prefix.endswith(CONTEXT_CONFIG_SUFFIX_GET):
return wout_prefix[: -len(CONTEXT_CONFIG_SUFFIX_GET)]
elif wout_prefix.endswith(CONTEXT_CONFIG_SUFFIX_SET):
return wout_prefix[: -len(CONTEXT_CONFIG_SUFFIX_SET)]
else:
msg = f"Invalid context config id {id_}"
raise ValueError(msg)
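# Collect the context specs declared by `steps`, validate them (one setter per
# key, setter ahead of its getters, no cyclic dependencies), and patch the
# config with getter/setter callables sharing one events map and values dict.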
def _config_with_context(
config: RunnableConfig,
steps: list[Runnable],
setter: Callable,
getter: Callable,
event_cls: Union[type[threading.Event], type[asyncio.Event]],
) -> RunnableConfig:
if any(k.startswith(CONTEXT_CONFIG_PREFIX) for k in config.get("configurable", {})):
return config
context_specs = [
(spec, i)
for i, step in enumerate(steps)
for spec in step.config_specs
if spec.id.startswith(CONTEXT_CONFIG_PREFIX)
]
grouped_by_key = {
key: list(group)
for key, group in groupby(
sorted(context_specs, key=lambda s: s[0].id),
key=lambda s: _key_from_id(s[0].id),
)
}
deps_by_key = {
key: {
_key_from_id(dep) for spec in group for dep in (spec[0].dependencies or [])
}
for key, group in grouped_by_key.items()
}
values: Values = {}
events: defaultdict[str, Union[asyncio.Event, threading.Event]] = defaultdict(
event_cls
)
context_funcs: dict[str, Callable[[], Any]] = {}
for key, group in grouped_by_key.items():
getters = [s for s in group if s[0].id.endswith(CONTEXT_CONFIG_SUFFIX_GET)]
setters = [s for s in group if s[0].id.endswith(CONTEXT_CONFIG_SUFFIX_SET)]
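# two keys that depend on each other would leave both getters waiting forever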
for dep in deps_by_key[key]:
if key in deps_by_key[dep]:
msg = f"Deadlock detected between context keys {key} and {dep}"
raise ValueError(msg)
if len(setters) != 1:
msg = f"Expected exactly one setter for context key {key}"
raise ValueError(msg)
setter_idx = setters[0][1]
if any(getter_idx < setter_idx for _, getter_idx in getters):
msg = f"Context setter for key {key} must be defined after all getters."
raise ValueError(msg)
if getters:
context_funcs[getters[0][0].id] = partial(getter, events[key], values)
context_funcs[setters[0][0].id] = partial(setter, events[key], values)
return patch_config(config, configurable=context_funcs)
def aconfig_with_context(
config: RunnableConfig,
steps: list[Runnable],
) -> RunnableConfig:
"""Asynchronously patch a runnable config with context getters and setters.
Args:
config: The runnable config.
steps: The runnable steps.
Returns:
The patched runnable config.
"""
return _config_with_context(config, steps, _asetter, _agetter, asyncio.Event)
def config_with_context(
config: RunnableConfig,
steps: list[Runnable],
) -> RunnableConfig:
"""Patch a runnable config with context getters and setters.
Args:
config: The runnable config.
steps: The runnable steps.
Returns:
The patched runnable config.
"""
return _config_with_context(config, steps, _setter, _getter, threading.Event)
@beta()
class ContextGet(RunnableSerializable):
"""Get a context value."""
prefix: str = ""
key: Union[str, list[str]]
def __str__(self) -> str:
return f"ContextGet({_print_keys(self.key)})"
@property
def ids(self) -> list[str]:
prefix = self.prefix + "/" if self.prefix else ""
keys = self.key if isinstance(self.key, list) else [self.key]
return [
f"{CONTEXT_CONFIG_PREFIX}{prefix}{k}{CONTEXT_CONFIG_SUFFIX_GET}"
for k in keys
]
@property
def config_specs(self) -> list[ConfigurableFieldSpec]:
return super().config_specs + [
ConfigurableFieldSpec(
id=id_,
annotation=Callable[[], Any],
)
for id_ in self.ids
]
def invoke(
self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Any:
config = ensure_config(config)
configurable = config.get("configurable", {})
if isinstance(self.key, list):
return {key: configurable[id_]() for key, id_ in zip(self.key, self.ids)}
else:
return configurable[self.ids[0]]()
async def ainvoke(
self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Any:
config = ensure_config(config)
configurable = config.get("configurable", {})
if isinstance(self.key, list):
values = await asyncio.gather(*(configurable[id_]() for id_ in self.ids))
return dict(zip(self.key, values))
else:
return await configurable[self.ids[0]]()
SetValue = Union[
Runnable[Input, Output],
Callable[[Input], Output],
Callable[[Input], Awaitable[Output]],
Any,
]
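# Bare values (neither Runnable nor callable) are wrapped in a constant lambda
# so the literal itself ends up stored in the context.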
def _coerce_set_value(value: SetValue) -> Runnable[Input, Output]:
if not isinstance(value, Runnable) and not callable(value):
return coerce_to_runnable(lambda _: value)
return coerce_to_runnable(value)
@beta()
class ContextSet(RunnableSerializable):
"""Set a context value."""
prefix: str = ""
keys: Mapping[str, Optional[Runnable]]
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
def __init__(
self,
key: Optional[str] = None,
value: Optional[SetValue] = None,
prefix: str = "",
**kwargs: SetValue,
):
if key is not None:
kwargs[key] = value
super().__init__( # type: ignore[call-arg]
keys={
k: _coerce_set_value(v) if v is not None else None
for k, v in kwargs.items()
},
prefix=prefix,
)
def __str__(self) -> str:
return f"ContextSet({_print_keys(list(self.keys.keys()))})"
@property
def ids(self) -> list[str]:
prefix = self.prefix + "/" if self.prefix else ""
return [
f"{CONTEXT_CONFIG_PREFIX}{prefix}{key}{CONTEXT_CONFIG_SUFFIX_SET}"
for key in self.keys
]
@property
def config_specs(self) -> list[ConfigurableFieldSpec]:
mapper_config_specs = [
s
for mapper in self.keys.values()
if mapper is not None
for s in mapper.config_specs
]
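# a setter whose mapper reads the very key it sets would wait on itself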
for spec in mapper_config_specs:
if spec.id.endswith(CONTEXT_CONFIG_SUFFIX_GET):
getter_key = spec.id.split("/")[1]
if getter_key in self.keys:
msg = f"Circular reference in context setter for key {getter_key}"
raise ValueError(msg)
return super().config_specs + [
ConfigurableFieldSpec(
id=id_,
annotation=Callable[[], Any],
)
for id_ in self.ids
]
def invoke(
self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Any:
config = ensure_config(config)
configurable = config.get("configurable", {})
for id_, mapper in zip(self.ids, self.keys.values()):
if mapper is not None:
configurable[id_](mapper.invoke(input, config))
else:
configurable[id_](input)
return input
async def ainvoke(
self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Any:
config = ensure_config(config)
configurable = config.get("configurable", {})
for id_, mapper in zip(self.ids, self.keys.values()):
if mapper is not None:
await configurable[id_](await mapper.ainvoke(input, config))
else:
await configurable[id_](input)
return input
class Context:
"""Context for a runnable.
The `Context` class provides methods for creating context scopes,
getters, and setters within a runnable. It allows for managing
and accessing contextual information throughout the execution
of a program.
Example:
.. code-block:: python
from langchain_core.beta.runnables.context import Context
from langchain_core.runnables.passthrough import RunnablePassthrough
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.output_parsers.string import StrOutputParser
from tests.unit_tests.fake.llm import FakeListLLM
chain = (
Context.setter("input")
| {
"context": RunnablePassthrough()
| Context.setter("context"),
"question": RunnablePassthrough(),
}
| PromptTemplate.from_template("{context} {question}")
| FakeListLLM(responses=["hello"])
| StrOutputParser()
| {
"result": RunnablePassthrough(),
"context": Context.getter("context"),
"input": Context.getter("input"),
}
)
# Use the chain
output = chain.invoke("What's your name?")
print(output["result"]) # Output: "hello"
print(output["context"]) # Output: "What's your name?"
print(output["input"]) # Output: "What's your name?
"""
@staticmethod
def create_scope(scope: str, /) -> "PrefixContext":
"""Create a context scope.
Args:
scope: The scope.
Returns:
The context scope.
"""
return PrefixContext(prefix=scope)
@staticmethod
def getter(key: Union[str, list[str]], /) -> ContextGet:
return ContextGet(key=key)
@staticmethod
def setter(
_key: Optional[str] = None,
_value: Optional[SetValue] = None,
/,
**kwargs: SetValue,
) -> ContextSet:
return ContextSet(_key, _value, prefix="", **kwargs)
class PrefixContext:
"""Context for a runnable with a prefix."""
prefix: str = ""
def __init__(self, prefix: str = ""):
self.prefix = prefix
def getter(self, key: Union[str, list[str]], /) -> ContextGet:
return ContextGet(key=key, prefix=self.prefix)
def setter(
self,
_key: Optional[str] = None,
_value: Optional[SetValue] = None,
/,
**kwargs: SetValue,
) -> ContextSet:
return ContextSet(_key, _value, prefix=self.prefix, **kwargs)
def _print_keys(keys: Union[str, Sequence[str]]) -> str:
if isinstance(keys, str):
return f"'{keys}'"
else:
return ", ".join(f"'{k}'" for k in keys)

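For reference, the value-threading that the deleted module provided is commonly expressed by keeping intermediate values in the chain's dict state instead. Below is a minimal sketch of that replacement idiom, reusing the fake models from the deleted tests and assuming RunnablePassthrough.assign; nothing in it is part of this commit, and the retriever is a hypothetical stand-in.

from langchain_core.language_models import FakeListLLM
from langchain_core.output_parsers.string import StrOutputParser
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.runnables.base import RunnableLambda
from langchain_core.runnables.passthrough import RunnablePassthrough

# hypothetical stand-ins for a real retriever and model
retriever = RunnableLambda(lambda x: ["Hi there!", "How are you?"])
answer = (
    PromptTemplate.from_template("{context} {question}")
    | FakeListLLM(responses=["hello"])
    | StrOutputParser()
)
# the dict keeps "context" and "question" in the chain's state; assign()
# merges the LLM result alongside them, replacing Context.setter/getter
chain = {
    "context": retriever,
    "question": RunnablePassthrough(),
} | RunnablePassthrough.assign(result=answer)

out = chain.invoke("What's your name?")
# out == {"context": [...], "question": "What's your name?", "result": "hello"}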
View File

@@ -18,7 +18,7 @@ from collections.abc import (
)
from concurrent.futures import FIRST_COMPLETED, wait
from functools import wraps
from itertools import groupby, tee
from itertools import tee
from operator import itemgetter
from types import GenericAlias
from typing import (
@@ -2858,50 +2858,10 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
Returns:
The config specs of the Runnable.
"""
from langchain_core.beta.runnables.context import (
    CONTEXT_CONFIG_PREFIX,
    _key_from_id,
)
return get_unique_config_specs(
    [spec for step in self.steps for spec in step.config_specs]
)
# get all specs
all_specs = [
(spec, idx)
for idx, step in enumerate(self.steps)
for spec in step.config_specs
]
# calculate context dependencies
specs_by_pos = groupby(
[tup for tup in all_specs if tup[0].id.startswith(CONTEXT_CONFIG_PREFIX)],
itemgetter(1),
)
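# each context spec inherits the context specs of all earlier steps (other
# keys only) as dependencies; _config_with_context uses these to detect
# deadlocks between keys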
next_deps: set[str] = set()
deps_by_pos: dict[int, set[str]] = {}
for pos, specs in specs_by_pos:
deps_by_pos[pos] = next_deps
next_deps = next_deps | {spec[0].id for spec in specs}
# assign context dependencies
for pos, (spec, idx) in enumerate(all_specs):
if spec.id.startswith(CONTEXT_CONFIG_PREFIX):
all_specs[pos] = (
ConfigurableFieldSpec(
id=spec.id,
annotation=spec.annotation,
name=spec.name,
default=spec.default,
description=spec.description,
is_shared=spec.is_shared,
dependencies=[
d
for d in deps_by_pos[idx]
if _key_from_id(d) != _key_from_id(spec.id)
]
+ (spec.dependencies or []),
),
idx,
)
return get_unique_config_specs(spec for spec, _ in all_specs)
def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph:
"""Get the graph representation of the Runnable.
@@ -2998,10 +2958,8 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output:
from langchain_core.beta.runnables.context import config_with_context
# setup callbacks and context
config = config_with_context(ensure_config(config), self.steps)
config = ensure_config(config)
callback_manager = get_callback_manager_for_config(config)
# start the root run
run_manager = callback_manager.on_chain_start(
@@ -3037,10 +2995,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Output:
from langchain_core.beta.runnables.context import aconfig_with_context
# setup callbacks and context
config = aconfig_with_context(ensure_config(config), self.steps)
config = ensure_config(config)
callback_manager = get_async_callback_manager_for_config(config)
# start the root run
run_manager = await callback_manager.on_chain_start(
@@ -3082,17 +3037,13 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> list[Output]:
from langchain_core.beta.runnables.context import config_with_context
from langchain_core.callbacks.manager import CallbackManager
if not inputs:
return []
# setup callbacks and context
configs = [
config_with_context(c, self.steps)
for c in get_config_list(config, len(inputs))
]
configs = get_config_list(config, len(inputs))
callback_managers = [
CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
@@ -3209,17 +3160,13 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> list[Output]:
from langchain_core.beta.runnables.context import aconfig_with_context
from langchain_core.callbacks.manager import AsyncCallbackManager
if not inputs:
return []
# setup callbacks and context
configs = [
aconfig_with_context(c, self.steps)
for c in get_config_list(config, len(inputs))
]
configs = get_config_list(config, len(inputs))
callback_managers = [
AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
@@ -3338,10 +3285,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
config: RunnableConfig,
**kwargs: Any,
) -> Iterator[Output]:
from langchain_core.beta.runnables.context import config_with_context
steps = [self.first] + self.middle + [self.last]
config = config_with_context(config, self.steps)
# transform the input stream of each step with the next
# steps that don't natively support transforming an input stream will
@@ -3365,10 +3309,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
config: RunnableConfig,
**kwargs: Any,
) -> AsyncIterator[Output]:
from langchain_core.beta.runnables.context import aconfig_with_context
steps = [self.first] + self.middle + [self.last]
config = aconfig_with_context(config, self.steps)
# stream the last steps
# transform the input stream of each step with the next

View File

@@ -168,11 +168,6 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
@property
def config_specs(self) -> list[ConfigurableFieldSpec]:
from langchain_core.beta.runnables.context import (
CONTEXT_CONFIG_PREFIX,
CONTEXT_CONFIG_SUFFIX_SET,
)
specs = get_unique_config_specs(
spec
for step in (
@@ -182,13 +177,6 @@
)
for spec in step.config_specs
)
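# a setter inside a branch might never execute, which would leave any getter
# for that key blocked forever, hence setters were disallowed here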
if any(
s.id.startswith(CONTEXT_CONFIG_PREFIX)
and s.id.endswith(CONTEXT_CONFIG_SUFFIX_SET)
for s in specs
):
msg = "RunnableBranch cannot contain context setters."
raise ValueError(msg)
return specs
def invoke(

View File

@@ -1,418 +0,0 @@
import asyncio
from typing import Any, Callable, NamedTuple, Union
import pytest
from langchain_core.beta.runnables.context import Context
from langchain_core.language_models import FakeListLLM, FakeStreamingListLLM
from langchain_core.output_parsers.string import StrOutputParser
from langchain_core.prompt_values import StringPromptValue
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.runnables.base import Runnable, RunnableLambda
from langchain_core.runnables.passthrough import RunnablePassthrough
from langchain_core.runnables.utils import aadd, add
class _TestCase(NamedTuple):
input: Any
output: Any
def seq_naive_rag() -> Runnable:
context = [
"Hi there!",
"How are you?",
"What's your name?",
]
retriever = RunnableLambda(lambda x: context)
prompt = PromptTemplate.from_template("{context} {question}")
llm = FakeListLLM(responses=["hello"])
return (
Context.setter("input")
| {
"context": retriever | Context.setter("context"),
"question": RunnablePassthrough(),
}
| prompt
| llm
| StrOutputParser()
| {
"result": RunnablePassthrough(),
"context": Context.getter("context"),
"input": Context.getter("input"),
}
)
def seq_naive_rag_alt() -> Runnable:
context = [
"Hi there!",
"How are you?",
"What's your name?",
]
retriever = RunnableLambda(lambda x: context)
prompt = PromptTemplate.from_template("{context} {question}")
llm = FakeListLLM(responses=["hello"])
return (
Context.setter("input")
| {
"context": retriever | Context.setter("context"),
"question": RunnablePassthrough(),
}
| prompt
| llm
| StrOutputParser()
| Context.setter("result")
| Context.getter(["context", "input", "result"])
)
def seq_naive_rag_scoped() -> Runnable:
context = [
"Hi there!",
"How are you?",
"What's your name?",
]
retriever = RunnableLambda(lambda x: context)
prompt = PromptTemplate.from_template("{context} {question}")
llm = FakeListLLM(responses=["hello"])
scoped = Context.create_scope("a_scope")
return (
Context.setter("input")
| {
"context": retriever | Context.setter("context"),
"question": RunnablePassthrough(),
"scoped": scoped.setter("context") | scoped.getter("context"),
}
| prompt
| llm
| StrOutputParser()
| Context.setter("result")
| Context.getter(["context", "input", "result"])
)
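# pairs of (runnable or zero-arg factory, (sync test case, async test case))
# consumed by the parametrized sync and async tests below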
test_cases = [
(
Context.setter("foo") | Context.getter("foo"),
(
_TestCase("foo", "foo"),
_TestCase("bar", "bar"),
),
),
(
Context.setter("input") | {"bar": Context.getter("input")},
(
_TestCase("foo", {"bar": "foo"}),
_TestCase("bar", {"bar": "bar"}),
),
),
(
{"bar": Context.setter("input")} | Context.getter("input"),
(
_TestCase("foo", "foo"),
_TestCase("bar", "bar"),
),
),
(
(
PromptTemplate.from_template("{foo} {bar}")
| Context.setter("prompt")
| FakeListLLM(responses=["hello"])
| StrOutputParser()
| {
"response": RunnablePassthrough(),
"prompt": Context.getter("prompt"),
}
),
(
_TestCase(
{"foo": "foo", "bar": "bar"},
{"response": "hello", "prompt": StringPromptValue(text="foo bar")},
),
_TestCase(
{"foo": "bar", "bar": "foo"},
{"response": "hello", "prompt": StringPromptValue(text="bar foo")},
),
),
),
(
(
PromptTemplate.from_template("{foo} {bar}")
| Context.setter("prompt", prompt_str=lambda x: x.to_string())
| FakeListLLM(responses=["hello"])
| StrOutputParser()
| {
"response": RunnablePassthrough(),
"prompt": Context.getter("prompt"),
"prompt_str": Context.getter("prompt_str"),
}
),
(
_TestCase(
{"foo": "foo", "bar": "bar"},
{
"response": "hello",
"prompt": StringPromptValue(text="foo bar"),
"prompt_str": "foo bar",
},
),
_TestCase(
{"foo": "bar", "bar": "foo"},
{
"response": "hello",
"prompt": StringPromptValue(text="bar foo"),
"prompt_str": "bar foo",
},
),
),
),
(
(
PromptTemplate.from_template("{foo} {bar}")
| Context.setter(prompt_str=lambda x: x.to_string())
| FakeListLLM(responses=["hello"])
| StrOutputParser()
| {
"response": RunnablePassthrough(),
"prompt_str": Context.getter("prompt_str"),
}
),
(
_TestCase(
{"foo": "foo", "bar": "bar"},
{"response": "hello", "prompt_str": "foo bar"},
),
_TestCase(
{"foo": "bar", "bar": "foo"},
{"response": "hello", "prompt_str": "bar foo"},
),
),
),
(
(
PromptTemplate.from_template("{foo} {bar}")
| Context.setter("prompt_str", lambda x: x.to_string())
| FakeListLLM(responses=["hello"])
| StrOutputParser()
| {
"response": RunnablePassthrough(),
"prompt_str": Context.getter("prompt_str"),
}
),
(
_TestCase(
{"foo": "foo", "bar": "bar"},
{"response": "hello", "prompt_str": "foo bar"},
),
_TestCase(
{"foo": "bar", "bar": "foo"},
{"response": "hello", "prompt_str": "bar foo"},
),
),
),
(
(
PromptTemplate.from_template("{foo} {bar}")
| Context.setter("prompt")
| FakeStreamingListLLM(responses=["hello"])
| StrOutputParser()
| {
"response": RunnablePassthrough(),
"prompt": Context.getter("prompt"),
}
),
(
_TestCase(
{"foo": "foo", "bar": "bar"},
{"response": "hello", "prompt": StringPromptValue(text="foo bar")},
),
_TestCase(
{"foo": "bar", "bar": "foo"},
{"response": "hello", "prompt": StringPromptValue(text="bar foo")},
),
),
),
(
seq_naive_rag,
(
_TestCase(
"What up",
{
"result": "hello",
"context": [
"Hi there!",
"How are you?",
"What's your name?",
],
"input": "What up",
},
),
_TestCase(
"Howdy",
{
"result": "hello",
"context": [
"Hi there!",
"How are you?",
"What's your name?",
],
"input": "Howdy",
},
),
),
),
(
seq_naive_rag_alt,
(
_TestCase(
"What up",
{
"result": "hello",
"context": [
"Hi there!",
"How are you?",
"What's your name?",
],
"input": "What up",
},
),
_TestCase(
"Howdy",
{
"result": "hello",
"context": [
"Hi there!",
"How are you?",
"What's your name?",
],
"input": "Howdy",
},
),
),
),
(
seq_naive_rag_scoped,
(
_TestCase(
"What up",
{
"result": "hello",
"context": [
"Hi there!",
"How are you?",
"What's your name?",
],
"input": "What up",
},
),
_TestCase(
"Howdy",
{
"result": "hello",
"context": [
"Hi there!",
"How are you?",
"What's your name?",
],
"input": "Howdy",
},
),
),
),
]
@pytest.mark.parametrize("runnable, cases", test_cases)
def test_context_runnables(
runnable: Union[Runnable, Callable[[], Runnable]], cases: list[_TestCase]
) -> None:
runnable = runnable if isinstance(runnable, Runnable) else runnable()
assert runnable.invoke(cases[0].input) == cases[0].output
assert runnable.batch([case.input for case in cases]) == [
case.output for case in cases
]
assert add(runnable.stream(cases[0].input)) == cases[0].output
@pytest.mark.parametrize("runnable, cases", test_cases)
async def test_context_runnables_async(
runnable: Union[Runnable, Callable[[], Runnable]], cases: list[_TestCase]
) -> None:
runnable = runnable if isinstance(runnable, Runnable) else runnable()
assert await runnable.ainvoke(cases[1].input) == cases[1].output
assert await runnable.abatch([case.input for case in cases]) == [
case.output for case in cases
]
assert await aadd(runnable.astream(cases[1].input)) == cases[1].output
def test_runnable_context_seq_key_not_found() -> None:
seq: Runnable = {"bar": Context.setter("input")} | Context.getter("foo")
with pytest.raises(ValueError):
seq.invoke("foo")
def test_runnable_context_seq_key_order() -> None:
seq: Runnable = {"bar": Context.getter("foo")} | Context.setter("foo")
with pytest.raises(ValueError):
seq.invoke("foo")
def test_runnable_context_deadlock() -> None:
seq: Runnable = {
"bar": Context.setter("input") | Context.getter("foo"),
"foo": Context.setter("foo") | Context.getter("input"),
} | RunnablePassthrough()
with pytest.raises(ValueError):
seq.invoke("foo")
def test_runnable_context_seq_key_circular_ref() -> None:
seq: Runnable = {
"bar": Context.setter(input=Context.getter("input"))
} | Context.getter("foo")
with pytest.raises(ValueError):
seq.invoke("foo")
async def test_runnable_seq_streaming_chunks() -> None:
chain: Runnable = (
PromptTemplate.from_template("{foo} {bar}")
| Context.setter("prompt")
| FakeStreamingListLLM(responses=["hello"])
| StrOutputParser()
| {
"response": RunnablePassthrough(),
"prompt": Context.getter("prompt"),
}
)
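# sync and async streaming should yield the same chunks (ordering aside)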
chunks = await asyncio.to_thread(list, chain.stream({"foo": "foo", "bar": "bar"}))
achunks = [c async for c in chain.astream({"foo": "foo", "bar": "bar"})]
for c in chunks:
assert c in achunks
for c in achunks:
assert c in chunks
assert len(chunks) == 6
assert [c for c in chunks if c.get("response")] == [
{"response": "h"},
{"response": "e"},
{"response": "l"},
{"response": "l"},
{"response": "o"},
]
assert [c for c in chunks if c.get("prompt")] == [
{"prompt": StringPromptValue(text="foo bar")},
]