[core/minor] Runnables: Implement a context api (#14046)

<!-- Thank you for contributing to LangChain!

Replace this entire comment with:
  - **Description:** a description of the change, 
  - **Issue:** the issue # it fixes (if applicable),
  - **Dependencies:** any dependencies required for this change,
- **Tag maintainer:** for a quicker response, tag the relevant
maintainer (see below),
- **Twitter handle:** we announce bigger features on Twitter. If your PR
gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc:

https://github.com/langchain-ai/langchain/blob/master/.github/CONTRIBUTING.md

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in `docs/extras`
directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->

---------

Co-authored-by: Brace Sproul <braceasproul@gmail.com>
This commit is contained in:
Nuno Campos 2023-12-06 15:02:29 -08:00 committed by GitHub
parent 8f95a8206b
commit 77c38df36c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 811 additions and 16 deletions

View File

@ -30,6 +30,7 @@ from langchain_core.runnables.config import (
get_config_list, get_config_list,
patch_config, patch_config,
) )
from langchain_core.runnables.context import Context
from langchain_core.runnables.fallbacks import RunnableWithFallbacks from langchain_core.runnables.fallbacks import RunnableWithFallbacks
from langchain_core.runnables.passthrough import RunnablePassthrough from langchain_core.runnables.passthrough import RunnablePassthrough
from langchain_core.runnables.router import RouterInput, RouterRunnable from langchain_core.runnables.router import RouterInput, RouterRunnable
@ -47,6 +48,7 @@ __all__ = [
"ConfigurableField", "ConfigurableField",
"ConfigurableFieldSingleOption", "ConfigurableFieldSingleOption",
"ConfigurableFieldMultiOption", "ConfigurableFieldMultiOption",
"Context",
"patch_config", "patch_config",
"RouterInput", "RouterInput",
"RouterRunnable", "RouterRunnable",

View File

@ -7,7 +7,7 @@ from abc import ABC, abstractmethod
from concurrent.futures import FIRST_COMPLETED, wait from concurrent.futures import FIRST_COMPLETED, wait
from copy import deepcopy from copy import deepcopy
from functools import partial, wraps from functools import partial, wraps
from itertools import tee from itertools import groupby, tee
from operator import itemgetter from operator import itemgetter
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
@ -22,6 +22,7 @@ from typing import (
Mapping, Mapping,
Optional, Optional,
Sequence, Sequence,
Set,
Tuple, Tuple,
Type, Type,
TypeVar, TypeVar,
@ -1401,9 +1402,46 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
@property @property
def config_specs(self) -> List[ConfigurableFieldSpec]: def config_specs(self) -> List[ConfigurableFieldSpec]:
return get_unique_config_specs( from langchain_core.runnables.context import CONTEXT_CONFIG_PREFIX, _key_from_id
spec for step in self.steps for spec in step.config_specs
# get all specs
all_specs = [
(spec, idx)
for idx, step in enumerate(self.steps)
for spec in step.config_specs
]
# calculate context dependencies
specs_by_pos = groupby(
[tup for tup in all_specs if tup[0].id.startswith(CONTEXT_CONFIG_PREFIX)],
lambda x: x[1],
) )
next_deps: Set[str] = set()
deps_by_pos: Dict[int, Set[str]] = {}
for pos, specs in specs_by_pos:
deps_by_pos[pos] = next_deps
next_deps = next_deps | {spec[0].id for spec in specs}
# assign context dependencies
for pos, (spec, idx) in enumerate(all_specs):
if spec.id.startswith(CONTEXT_CONFIG_PREFIX):
all_specs[pos] = (
ConfigurableFieldSpec(
id=spec.id,
annotation=spec.annotation,
name=spec.name,
default=spec.default,
description=spec.description,
is_shared=spec.is_shared,
dependencies=[
d
for d in deps_by_pos[idx]
if _key_from_id(d) != _key_from_id(spec.id)
]
+ (spec.dependencies or []),
),
idx,
)
return get_unique_config_specs(spec for spec, _ in all_specs)
def __repr__(self) -> str: def __repr__(self) -> str:
return "\n| ".join( return "\n| ".join(
@ -1456,8 +1494,10 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
) )
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output: def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
# setup callbacks from langchain_core.runnables.context import config_with_context
config = ensure_config(config)
# setup callbacks and context
config = config_with_context(ensure_config(config), self.steps)
callback_manager = get_callback_manager_for_config(config) callback_manager = get_callback_manager_for_config(config)
# start the root run # start the root run
run_manager = callback_manager.on_chain_start( run_manager = callback_manager.on_chain_start(
@ -1488,8 +1528,10 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> Output: ) -> Output:
# setup callbacks from langchain_core.runnables.context import aconfig_with_context
config = ensure_config(config)
# setup callbacks and context
config = aconfig_with_context(ensure_config(config), self.steps)
callback_manager = get_async_callback_manager_for_config(config) callback_manager = get_async_callback_manager_for_config(config)
# start the root run # start the root run
run_manager = await callback_manager.on_chain_start( run_manager = await callback_manager.on_chain_start(
@ -1523,12 +1565,16 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> List[Output]: ) -> List[Output]:
from langchain_core.callbacks.manager import CallbackManager from langchain_core.callbacks.manager import CallbackManager
from langchain_core.runnables.context import config_with_context
if not inputs: if not inputs:
return [] return []
# setup callbacks # setup callbacks and context
configs = get_config_list(config, len(inputs)) configs = [
config_with_context(c, self.steps)
for c in get_config_list(config, len(inputs))
]
callback_managers = [ callback_managers = [
CallbackManager.configure( CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"), inheritable_callbacks=config.get("callbacks"),
@ -1641,15 +1687,17 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
return_exceptions: bool = False, return_exceptions: bool = False,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> List[Output]: ) -> List[Output]:
from langchain_core.callbacks.manager import ( from langchain_core.callbacks.manager import AsyncCallbackManager
AsyncCallbackManager, from langchain_core.runnables.context import aconfig_with_context
)
if not inputs: if not inputs:
return [] return []
# setup callbacks # setup callbacks and context
configs = get_config_list(config, len(inputs)) configs = [
aconfig_with_context(c, self.steps)
for c in get_config_list(config, len(inputs))
]
callback_managers = [ callback_managers = [
AsyncCallbackManager.configure( AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"), inheritable_callbacks=config.get("callbacks"),
@ -1763,7 +1811,10 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
run_manager: CallbackManagerForChainRun, run_manager: CallbackManagerForChainRun,
config: RunnableConfig, config: RunnableConfig,
) -> Iterator[Output]: ) -> Iterator[Output]:
from langchain_core.runnables.context import config_with_context
steps = [self.first] + self.middle + [self.last] steps = [self.first] + self.middle + [self.last]
config = config_with_context(config, self.steps)
# transform the input stream of each step with the next # transform the input stream of each step with the next
# steps that don't natively support transforming an input stream will # steps that don't natively support transforming an input stream will
@ -1787,7 +1838,10 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
run_manager: AsyncCallbackManagerForChainRun, run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig, config: RunnableConfig,
) -> AsyncIterator[Output]: ) -> AsyncIterator[Output]:
from langchain_core.runnables.context import aconfig_with_context
steps = [self.first] + self.middle + [self.last] steps = [self.first] + self.middle + [self.last]
config = aconfig_with_context(config, self.steps)
# stream the last steps # stream the last steps
# transform the input stream of each step with the next # transform the input stream of each step with the next

View File

@ -26,6 +26,10 @@ from langchain_core.runnables.config import (
get_callback_manager_for_config, get_callback_manager_for_config,
patch_config, patch_config,
) )
from langchain_core.runnables.context import (
CONTEXT_CONFIG_PREFIX,
CONTEXT_CONFIG_SUFFIX_SET,
)
from langchain_core.runnables.utils import ( from langchain_core.runnables.utils import (
ConfigurableFieldSpec, ConfigurableFieldSpec,
Input, Input,
@ -148,7 +152,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
@property @property
def config_specs(self) -> List[ConfigurableFieldSpec]: def config_specs(self) -> List[ConfigurableFieldSpec]:
return get_unique_config_specs( specs = get_unique_config_specs(
spec spec
for step in ( for step in (
[self.default] [self.default]
@ -157,6 +161,13 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
) )
for spec in step.config_specs for spec in step.config_specs
) )
if any(
s.id.startswith(CONTEXT_CONFIG_PREFIX)
and s.id.endswith(CONTEXT_CONFIG_SUFFIX_SET)
for s in specs
):
raise ValueError("RunnableBranch cannot contain context setters.")
return specs
def invoke( def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any

View File

@ -0,0 +1,313 @@
import asyncio
import threading
from collections import defaultdict
from functools import partial
from itertools import groupby
from typing import (
Any,
Awaitable,
Callable,
DefaultDict,
Dict,
List,
Mapping,
Optional,
Type,
TypeVar,
Union,
)
from langchain_core.runnables.base import (
Runnable,
RunnableSerializable,
coerce_to_runnable,
)
from langchain_core.runnables.config import RunnableConfig, patch_config
from langchain_core.runnables.utils import ConfigurableFieldSpec, Input, Output
T = TypeVar("T")
Values = Dict[Union[asyncio.Event, threading.Event], Any]
CONTEXT_CONFIG_PREFIX = "__context__/"
CONTEXT_CONFIG_SUFFIX_GET = "/get"
CONTEXT_CONFIG_SUFFIX_SET = "/set"
async def _asetter(done: asyncio.Event, values: Values, value: T) -> T:
values[done] = value
done.set()
return value
async def _agetter(done: asyncio.Event, values: Values) -> Any:
await done.wait()
return values[done]
def _setter(done: threading.Event, values: Values, value: T) -> T:
values[done] = value
done.set()
return value
def _getter(done: threading.Event, values: Values) -> Any:
done.wait()
return values[done]
def _key_from_id(id_: str) -> str:
wout_prefix = id_.split(CONTEXT_CONFIG_PREFIX, maxsplit=1)[1]
if wout_prefix.endswith(CONTEXT_CONFIG_SUFFIX_GET):
return wout_prefix[: -len(CONTEXT_CONFIG_SUFFIX_GET)]
elif wout_prefix.endswith(CONTEXT_CONFIG_SUFFIX_SET):
return wout_prefix[: -len(CONTEXT_CONFIG_SUFFIX_SET)]
else:
raise ValueError(f"Invalid context config id {id_}")
def _config_with_context(
config: RunnableConfig,
steps: List[Runnable],
setter: Callable,
getter: Callable,
event_cls: Union[Type[threading.Event], Type[asyncio.Event]],
) -> RunnableConfig:
if any(k.startswith(CONTEXT_CONFIG_PREFIX) for k in config.get("configurable", {})):
return config
context_specs = [
(spec, i)
for i, step in enumerate(steps)
for spec in step.config_specs
if spec.id.startswith(CONTEXT_CONFIG_PREFIX)
]
grouped_by_key = {
key: list(group)
for key, group in groupby(
sorted(context_specs, key=lambda s: s[0].id),
key=lambda s: _key_from_id(s[0].id),
)
}
deps_by_key = {
key: set(
_key_from_id(dep) for spec in group for dep in (spec[0].dependencies or [])
)
for key, group in grouped_by_key.items()
}
values: Values = {}
events: DefaultDict[str, Union[asyncio.Event, threading.Event]] = defaultdict(
event_cls
)
context_funcs: Dict[str, Callable[[], Any]] = {}
for key, group in grouped_by_key.items():
getters = [s for s in group if s[0].id.endswith(CONTEXT_CONFIG_SUFFIX_GET)]
setters = [s for s in group if s[0].id.endswith(CONTEXT_CONFIG_SUFFIX_SET)]
for dep in deps_by_key[key]:
if key in deps_by_key[dep]:
raise ValueError(
f"Deadlock detected between context keys {key} and {dep}"
)
if len(getters) < 1:
raise ValueError(f"Expected at least one getter for context key {key}")
if len(setters) != 1:
raise ValueError(f"Expected exactly one setter for context key {key}")
setter_idx = setters[0][1]
if any(getter_idx < setter_idx for _, getter_idx in getters):
raise ValueError(
f"Context setter for key {key} must be defined after all getters."
)
context_funcs[getters[0][0].id] = partial(getter, events[key], values)
context_funcs[setters[0][0].id] = partial(setter, events[key], values)
return patch_config(config, configurable=context_funcs)
def aconfig_with_context(
config: RunnableConfig,
steps: List[Runnable],
) -> RunnableConfig:
return _config_with_context(config, steps, _asetter, _agetter, asyncio.Event)
def config_with_context(
config: RunnableConfig,
steps: List[Runnable],
) -> RunnableConfig:
return _config_with_context(config, steps, _setter, _getter, threading.Event)
class ContextGet(RunnableSerializable):
prefix: str = ""
key: Union[str, List[str]]
@property
def ids(self) -> List[str]:
prefix = self.prefix + "/" if self.prefix else ""
keys = self.key if isinstance(self.key, list) else [self.key]
return [
f"{CONTEXT_CONFIG_PREFIX}{prefix}{k}{CONTEXT_CONFIG_SUFFIX_GET}"
for k in keys
]
@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
return super().config_specs + [
ConfigurableFieldSpec(
id=id_,
annotation=Callable[[], Any],
)
for id_ in self.ids
]
def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> Any:
config = config or {}
configurable = config.get("configurable", {})
if isinstance(self.key, list):
return {key: configurable[id_]() for key, id_ in zip(self.key, self.ids)}
else:
return configurable[self.ids[0]]()
async def ainvoke(
self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Any:
config = config or {}
configurable = config.get("configurable", {})
if isinstance(self.key, list):
values = await asyncio.gather(*(configurable[id_]() for id_ in self.ids))
return {key: value for key, value in zip(self.key, values)}
else:
return await configurable[self.ids[0]]()
SetValue = Union[
Runnable[Input, Output],
Callable[[Input], Output],
Callable[[Input], Awaitable[Output]],
Any,
]
def _coerce_set_value(value: SetValue) -> Runnable[Input, Output]:
if not isinstance(value, Runnable) and not callable(value):
return coerce_to_runnable(lambda _: value)
return coerce_to_runnable(value)
class ContextSet(RunnableSerializable):
prefix: str = ""
keys: Mapping[str, Optional[Runnable]]
class Config:
arbitrary_types_allowed = True
def __init__(
self,
key: Optional[str] = None,
value: Optional[SetValue] = None,
prefix: str = "",
**kwargs: SetValue,
):
if key is not None:
kwargs[key] = value
super().__init__(
keys={
k: _coerce_set_value(v) if v is not None else None
for k, v in kwargs.items()
},
prefix=prefix,
)
@property
def ids(self) -> List[str]:
prefix = self.prefix + "/" if self.prefix else ""
return [
f"{CONTEXT_CONFIG_PREFIX}{prefix}{key}{CONTEXT_CONFIG_SUFFIX_SET}"
for key in self.keys
]
@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
mapper_config_specs = [
s
for mapper in self.keys.values()
if mapper is not None
for s in mapper.config_specs
]
for spec in mapper_config_specs:
if spec.id.endswith(CONTEXT_CONFIG_SUFFIX_GET):
getter_key = spec.id.split("/")[1]
if getter_key in self.keys:
raise ValueError(
f"Circular reference in context setter for key {getter_key}"
)
return super().config_specs + [
ConfigurableFieldSpec(
id=id_,
annotation=Callable[[], Any],
)
for id_ in self.ids
]
def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> Any:
config = config or {}
configurable = config.get("configurable", {})
for id_, mapper in zip(self.ids, self.keys.values()):
if mapper is not None:
configurable[id_](mapper.invoke(input, config))
else:
configurable[id_](input)
return input
async def ainvoke(
self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Any:
config = config or {}
configurable = config.get("configurable", {})
for id_, mapper in zip(self.ids, self.keys.values()):
if mapper is not None:
await configurable[id_](await mapper.ainvoke(input, config))
else:
await configurable[id_](input)
return input
class Context:
@staticmethod
def create_scope(scope: str, /) -> "PrefixContext":
return PrefixContext(prefix=scope)
@staticmethod
def getter(key: Union[str, List[str]], /) -> ContextGet:
return ContextGet(key=key)
@staticmethod
def setter(
_key: Optional[str] = None,
_value: Optional[SetValue] = None,
/,
**kwargs: SetValue,
) -> ContextSet:
return ContextSet(_key, _value, prefix="", **kwargs)
class PrefixContext:
prefix: str = ""
def __init__(self, prefix: str = ""):
self.prefix = prefix
def getter(self, key: Union[str, List[str]], /) -> ContextGet:
return ContextGet(key=key, prefix=self.prefix)
def setter(
self,
_key: Optional[str] = None,
_value: Optional[SetValue] = None,
/,
**kwargs: SetValue,
) -> ContextSet:
return ContextSet(_key, _value, prefix=self.prefix, **kwargs)

View File

@ -308,13 +308,16 @@ class ConfigurableFieldSpec(NamedTuple):
description: Optional[str] = None description: Optional[str] = None
default: Any = None default: Any = None
is_shared: bool = False is_shared: bool = False
dependencies: Optional[List[str]] = None
def get_unique_config_specs( def get_unique_config_specs(
specs: Iterable[ConfigurableFieldSpec], specs: Iterable[ConfigurableFieldSpec],
) -> List[ConfigurableFieldSpec]: ) -> List[ConfigurableFieldSpec]:
"""Get the unique config specs from a sequence of config specs.""" """Get the unique config specs from a sequence of config specs."""
grouped = groupby(sorted(specs, key=lambda s: s.id), lambda s: s.id) grouped = groupby(
sorted(specs, key=lambda s: (s.id, *(s.dependencies or []))), lambda s: s.id
)
unique: List[ConfigurableFieldSpec] = [] unique: List[ConfigurableFieldSpec] = []
for id, dupes in grouped: for id, dupes in grouped:
first = next(dupes) first = next(dupes)

View File

@ -0,0 +1,411 @@
from typing import Any, Callable, List, NamedTuple, Union
import pytest
from langchain_core.output_parsers.string import StrOutputParser
from langchain_core.prompt_values import StringPromptValue
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.runnables.base import Runnable, RunnableLambda
from langchain_core.runnables.context import Context
from langchain_core.runnables.passthrough import RunnablePassthrough
from langchain_core.runnables.utils import aadd, add
from tests.unit_tests.fake.llm import FakeListLLM, FakeStreamingListLLM
class TestCase(NamedTuple):
input: Any
output: Any
def seq_naive_rag() -> Runnable:
context = [
"Hi there!",
"How are you?",
"What's your name?",
]
retriever = RunnableLambda(lambda x: context)
prompt = PromptTemplate.from_template("{context} {question}")
llm = FakeListLLM(responses=["hello"])
return (
Context.setter("input")
| {
"context": retriever | Context.setter("context"),
"question": RunnablePassthrough(),
}
| prompt
| llm
| StrOutputParser()
| {
"result": RunnablePassthrough(),
"context": Context.getter("context"),
"input": Context.getter("input"),
}
)
def seq_naive_rag_alt() -> Runnable:
context = [
"Hi there!",
"How are you?",
"What's your name?",
]
retriever = RunnableLambda(lambda x: context)
prompt = PromptTemplate.from_template("{context} {question}")
llm = FakeListLLM(responses=["hello"])
return (
Context.setter("input")
| {
"context": retriever | Context.setter("context"),
"question": RunnablePassthrough(),
}
| prompt
| llm
| StrOutputParser()
| Context.setter("result")
| Context.getter(["context", "input", "result"])
)
def seq_naive_rag_scoped() -> Runnable:
context = [
"Hi there!",
"How are you?",
"What's your name?",
]
retriever = RunnableLambda(lambda x: context)
prompt = PromptTemplate.from_template("{context} {question}")
llm = FakeListLLM(responses=["hello"])
scoped = Context.create_scope("a_scope")
return (
Context.setter("input")
| {
"context": retriever | Context.setter("context"),
"question": RunnablePassthrough(),
"scoped": scoped.setter("context") | scoped.getter("context"),
}
| prompt
| llm
| StrOutputParser()
| Context.setter("result")
| Context.getter(["context", "input", "result"])
)
test_cases = [
(
Context.setter("foo") | Context.getter("foo"),
(
TestCase("foo", "foo"),
TestCase("bar", "bar"),
),
),
(
Context.setter("input") | {"bar": Context.getter("input")},
(
TestCase("foo", {"bar": "foo"}),
TestCase("bar", {"bar": "bar"}),
),
),
(
{"bar": Context.setter("input")} | Context.getter("input"),
(
TestCase("foo", "foo"),
TestCase("bar", "bar"),
),
),
(
(
PromptTemplate.from_template("{foo} {bar}")
| Context.setter("prompt")
| FakeListLLM(responses=["hello"])
| StrOutputParser()
| {
"response": RunnablePassthrough(),
"prompt": Context.getter("prompt"),
}
),
(
TestCase(
{"foo": "foo", "bar": "bar"},
{"response": "hello", "prompt": StringPromptValue(text="foo bar")},
),
TestCase(
{"foo": "bar", "bar": "foo"},
{"response": "hello", "prompt": StringPromptValue(text="bar foo")},
),
),
),
(
(
PromptTemplate.from_template("{foo} {bar}")
| Context.setter("prompt", prompt_str=lambda x: x.to_string())
| FakeListLLM(responses=["hello"])
| StrOutputParser()
| {
"response": RunnablePassthrough(),
"prompt": Context.getter("prompt"),
"prompt_str": Context.getter("prompt_str"),
}
),
(
TestCase(
{"foo": "foo", "bar": "bar"},
{
"response": "hello",
"prompt": StringPromptValue(text="foo bar"),
"prompt_str": "foo bar",
},
),
TestCase(
{"foo": "bar", "bar": "foo"},
{
"response": "hello",
"prompt": StringPromptValue(text="bar foo"),
"prompt_str": "bar foo",
},
),
),
),
(
(
PromptTemplate.from_template("{foo} {bar}")
| Context.setter(prompt_str=lambda x: x.to_string())
| FakeListLLM(responses=["hello"])
| StrOutputParser()
| {
"response": RunnablePassthrough(),
"prompt_str": Context.getter("prompt_str"),
}
),
(
TestCase(
{"foo": "foo", "bar": "bar"},
{"response": "hello", "prompt_str": "foo bar"},
),
TestCase(
{"foo": "bar", "bar": "foo"},
{"response": "hello", "prompt_str": "bar foo"},
),
),
),
(
(
PromptTemplate.from_template("{foo} {bar}")
| Context.setter("prompt_str", lambda x: x.to_string())
| FakeListLLM(responses=["hello"])
| StrOutputParser()
| {
"response": RunnablePassthrough(),
"prompt_str": Context.getter("prompt_str"),
}
),
(
TestCase(
{"foo": "foo", "bar": "bar"},
{"response": "hello", "prompt_str": "foo bar"},
),
TestCase(
{"foo": "bar", "bar": "foo"},
{"response": "hello", "prompt_str": "bar foo"},
),
),
),
(
(
PromptTemplate.from_template("{foo} {bar}")
| Context.setter("prompt")
| FakeStreamingListLLM(responses=["hello"])
| StrOutputParser()
| {
"response": RunnablePassthrough(),
"prompt": Context.getter("prompt"),
}
),
(
TestCase(
{"foo": "foo", "bar": "bar"},
{"response": "hello", "prompt": StringPromptValue(text="foo bar")},
),
TestCase(
{"foo": "bar", "bar": "foo"},
{"response": "hello", "prompt": StringPromptValue(text="bar foo")},
),
),
),
(
seq_naive_rag,
(
TestCase(
"What up",
{
"result": "hello",
"context": [
"Hi there!",
"How are you?",
"What's your name?",
],
"input": "What up",
},
),
TestCase(
"Howdy",
{
"result": "hello",
"context": [
"Hi there!",
"How are you?",
"What's your name?",
],
"input": "Howdy",
},
),
),
),
(
seq_naive_rag_alt,
(
TestCase(
"What up",
{
"result": "hello",
"context": [
"Hi there!",
"How are you?",
"What's your name?",
],
"input": "What up",
},
),
TestCase(
"Howdy",
{
"result": "hello",
"context": [
"Hi there!",
"How are you?",
"What's your name?",
],
"input": "Howdy",
},
),
),
),
(
seq_naive_rag_scoped,
(
TestCase(
"What up",
{
"result": "hello",
"context": [
"Hi there!",
"How are you?",
"What's your name?",
],
"input": "What up",
},
),
TestCase(
"Howdy",
{
"result": "hello",
"context": [
"Hi there!",
"How are you?",
"What's your name?",
],
"input": "Howdy",
},
),
),
),
]
@pytest.mark.parametrize("runnable, cases", test_cases)
async def test_context_runnables(
runnable: Union[Runnable, Callable[[], Runnable]], cases: List[TestCase]
) -> None:
runnable = runnable if isinstance(runnable, Runnable) else runnable()
assert runnable.invoke(cases[0].input) == cases[0].output
assert await runnable.ainvoke(cases[1].input) == cases[1].output
assert runnable.batch([case.input for case in cases]) == [
case.output for case in cases
]
assert await runnable.abatch([case.input for case in cases]) == [
case.output for case in cases
]
assert add(runnable.stream(cases[0].input)) == cases[0].output
assert await aadd(runnable.astream(cases[1].input)) == cases[1].output
def test_runnable_context_seq_key_not_found() -> None:
seq: Runnable = {"bar": Context.setter("input")} | Context.getter("foo")
with pytest.raises(ValueError):
seq.invoke("foo")
def test_runnable_context_seq_key_order() -> None:
seq: Runnable = {"bar": Context.getter("foo")} | Context.setter("foo")
with pytest.raises(ValueError):
seq.invoke("foo")
def test_runnable_context_deadlock() -> None:
seq: Runnable = {
"bar": Context.setter("input") | Context.getter("foo"),
"foo": Context.setter("foo") | Context.getter("input"),
} | RunnablePassthrough()
with pytest.raises(ValueError):
seq.invoke("foo")
def test_runnable_context_seq_key_circular_ref() -> None:
seq: Runnable = {
"bar": Context.setter(input=Context.getter("input"))
} | Context.getter("foo")
with pytest.raises(ValueError):
seq.invoke("foo")
async def test_runnable_seq_streaming_chunks() -> None:
chain: Runnable = (
PromptTemplate.from_template("{foo} {bar}")
| Context.setter("prompt")
| FakeStreamingListLLM(responses=["hello"])
| StrOutputParser()
| {
"response": RunnablePassthrough(),
"prompt": Context.getter("prompt"),
}
)
chunks = [c for c in chain.stream({"foo": "foo", "bar": "bar"})]
achunks = [c async for c in chain.astream({"foo": "foo", "bar": "bar"})]
for c in chunks:
assert c in achunks
for c in achunks:
assert c in chunks
assert len(chunks) == 6
assert [c for c in chunks if c.get("response")] == [
{"response": "h"},
{"response": "e"},
{"response": "l"},
{"response": "l"},
{"response": "o"},
]
assert [c for c in chunks if c.get("prompt")] == [
{"prompt": StringPromptValue(text="foo bar")},
]

View File

@ -2,6 +2,7 @@ from langchain_core.runnables import __all__
EXPECTED_ALL = [ EXPECTED_ALL = [
"AddableDict", "AddableDict",
"Context",
"ConfigurableField", "ConfigurableField",
"ConfigurableFieldSingleOption", "ConfigurableFieldSingleOption",
"ConfigurableFieldMultiOption", "ConfigurableFieldMultiOption",