mirror of
https://github.com/hwchase17/langchain.git
synced 2025-04-28 20:05:58 +00:00
Remove unused Context API
This commit is contained in:
parent
8265be4d3e
commit
a3e8a7fd17
@ -1,401 +0,0 @@
|
||||
import asyncio
|
||||
import threading
|
||||
from collections import defaultdict
|
||||
from collections.abc import Awaitable, Mapping, Sequence
|
||||
from functools import partial
|
||||
from itertools import groupby
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Optional,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from langchain_core._api.beta_decorator import beta
|
||||
from langchain_core.runnables.base import (
|
||||
Runnable,
|
||||
RunnableSerializable,
|
||||
coerce_to_runnable,
|
||||
)
|
||||
from langchain_core.runnables.config import RunnableConfig, ensure_config, patch_config
|
||||
from langchain_core.runnables.utils import ConfigurableFieldSpec, Input, Output
|
||||
|
||||
T = TypeVar("T")
|
||||
Values = dict[Union[asyncio.Event, threading.Event], Any]
|
||||
CONTEXT_CONFIG_PREFIX = "__context__/"
|
||||
CONTEXT_CONFIG_SUFFIX_GET = "/get"
|
||||
CONTEXT_CONFIG_SUFFIX_SET = "/set"
|
||||
|
||||
|
||||
async def _asetter(done: asyncio.Event, values: Values, value: T) -> T:
|
||||
values[done] = value
|
||||
done.set()
|
||||
return value
|
||||
|
||||
|
||||
async def _agetter(done: asyncio.Event, values: Values) -> Any:
|
||||
await done.wait()
|
||||
return values[done]
|
||||
|
||||
|
||||
def _setter(done: threading.Event, values: Values, value: T) -> T:
|
||||
values[done] = value
|
||||
done.set()
|
||||
return value
|
||||
|
||||
|
||||
def _getter(done: threading.Event, values: Values) -> Any:
|
||||
done.wait()
|
||||
return values[done]
|
||||
|
||||
|
||||
def _key_from_id(id_: str) -> str:
|
||||
wout_prefix = id_.split(CONTEXT_CONFIG_PREFIX, maxsplit=1)[1]
|
||||
if wout_prefix.endswith(CONTEXT_CONFIG_SUFFIX_GET):
|
||||
return wout_prefix[: -len(CONTEXT_CONFIG_SUFFIX_GET)]
|
||||
elif wout_prefix.endswith(CONTEXT_CONFIG_SUFFIX_SET):
|
||||
return wout_prefix[: -len(CONTEXT_CONFIG_SUFFIX_SET)]
|
||||
else:
|
||||
msg = f"Invalid context config id {id_}"
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
def _config_with_context(
|
||||
config: RunnableConfig,
|
||||
steps: list[Runnable],
|
||||
setter: Callable,
|
||||
getter: Callable,
|
||||
event_cls: Union[type[threading.Event], type[asyncio.Event]],
|
||||
) -> RunnableConfig:
|
||||
if any(k.startswith(CONTEXT_CONFIG_PREFIX) for k in config.get("configurable", {})):
|
||||
return config
|
||||
|
||||
context_specs = [
|
||||
(spec, i)
|
||||
for i, step in enumerate(steps)
|
||||
for spec in step.config_specs
|
||||
if spec.id.startswith(CONTEXT_CONFIG_PREFIX)
|
||||
]
|
||||
grouped_by_key = {
|
||||
key: list(group)
|
||||
for key, group in groupby(
|
||||
sorted(context_specs, key=lambda s: s[0].id),
|
||||
key=lambda s: _key_from_id(s[0].id),
|
||||
)
|
||||
}
|
||||
deps_by_key = {
|
||||
key: {
|
||||
_key_from_id(dep) for spec in group for dep in (spec[0].dependencies or [])
|
||||
}
|
||||
for key, group in grouped_by_key.items()
|
||||
}
|
||||
|
||||
values: Values = {}
|
||||
events: defaultdict[str, Union[asyncio.Event, threading.Event]] = defaultdict(
|
||||
event_cls
|
||||
)
|
||||
context_funcs: dict[str, Callable[[], Any]] = {}
|
||||
for key, group in grouped_by_key.items():
|
||||
getters = [s for s in group if s[0].id.endswith(CONTEXT_CONFIG_SUFFIX_GET)]
|
||||
setters = [s for s in group if s[0].id.endswith(CONTEXT_CONFIG_SUFFIX_SET)]
|
||||
|
||||
for dep in deps_by_key[key]:
|
||||
if key in deps_by_key[dep]:
|
||||
msg = f"Deadlock detected between context keys {key} and {dep}"
|
||||
raise ValueError(msg)
|
||||
if len(setters) != 1:
|
||||
msg = f"Expected exactly one setter for context key {key}"
|
||||
raise ValueError(msg)
|
||||
setter_idx = setters[0][1]
|
||||
if any(getter_idx < setter_idx for _, getter_idx in getters):
|
||||
msg = f"Context setter for key {key} must be defined after all getters."
|
||||
raise ValueError(msg)
|
||||
|
||||
if getters:
|
||||
context_funcs[getters[0][0].id] = partial(getter, events[key], values)
|
||||
context_funcs[setters[0][0].id] = partial(setter, events[key], values)
|
||||
|
||||
return patch_config(config, configurable=context_funcs)
|
||||
|
||||
|
||||
def aconfig_with_context(
|
||||
config: RunnableConfig,
|
||||
steps: list[Runnable],
|
||||
) -> RunnableConfig:
|
||||
"""Asynchronously patch a runnable config with context getters and setters.
|
||||
|
||||
Args:
|
||||
config: The runnable config.
|
||||
steps: The runnable steps.
|
||||
|
||||
Returns:
|
||||
The patched runnable config.
|
||||
"""
|
||||
return _config_with_context(config, steps, _asetter, _agetter, asyncio.Event)
|
||||
|
||||
|
||||
def config_with_context(
|
||||
config: RunnableConfig,
|
||||
steps: list[Runnable],
|
||||
) -> RunnableConfig:
|
||||
"""Patch a runnable config with context getters and setters.
|
||||
|
||||
Args:
|
||||
config: The runnable config.
|
||||
steps: The runnable steps.
|
||||
|
||||
Returns:
|
||||
The patched runnable config.
|
||||
"""
|
||||
return _config_with_context(config, steps, _setter, _getter, threading.Event)
|
||||
|
||||
|
||||
@beta()
|
||||
class ContextGet(RunnableSerializable):
|
||||
"""Get a context value."""
|
||||
|
||||
prefix: str = ""
|
||||
|
||||
key: Union[str, list[str]]
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"ContextGet({_print_keys(self.key)})"
|
||||
|
||||
@property
|
||||
def ids(self) -> list[str]:
|
||||
prefix = self.prefix + "/" if self.prefix else ""
|
||||
keys = self.key if isinstance(self.key, list) else [self.key]
|
||||
return [
|
||||
f"{CONTEXT_CONFIG_PREFIX}{prefix}{k}{CONTEXT_CONFIG_SUFFIX_GET}"
|
||||
for k in keys
|
||||
]
|
||||
|
||||
@property
|
||||
def config_specs(self) -> list[ConfigurableFieldSpec]:
|
||||
return super().config_specs + [
|
||||
ConfigurableFieldSpec(
|
||||
id=id_,
|
||||
annotation=Callable[[], Any],
|
||||
)
|
||||
for id_ in self.ids
|
||||
]
|
||||
|
||||
def invoke(
|
||||
self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> Any:
|
||||
config = ensure_config(config)
|
||||
configurable = config.get("configurable", {})
|
||||
if isinstance(self.key, list):
|
||||
return {key: configurable[id_]() for key, id_ in zip(self.key, self.ids)}
|
||||
else:
|
||||
return configurable[self.ids[0]]()
|
||||
|
||||
async def ainvoke(
|
||||
self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> Any:
|
||||
config = ensure_config(config)
|
||||
configurable = config.get("configurable", {})
|
||||
if isinstance(self.key, list):
|
||||
values = await asyncio.gather(*(configurable[id_]() for id_ in self.ids))
|
||||
return dict(zip(self.key, values))
|
||||
else:
|
||||
return await configurable[self.ids[0]]()
|
||||
|
||||
|
||||
SetValue = Union[
|
||||
Runnable[Input, Output],
|
||||
Callable[[Input], Output],
|
||||
Callable[[Input], Awaitable[Output]],
|
||||
Any,
|
||||
]
|
||||
|
||||
|
||||
def _coerce_set_value(value: SetValue) -> Runnable[Input, Output]:
|
||||
if not isinstance(value, Runnable) and not callable(value):
|
||||
return coerce_to_runnable(lambda _: value)
|
||||
return coerce_to_runnable(value)
|
||||
|
||||
|
||||
@beta()
|
||||
class ContextSet(RunnableSerializable):
|
||||
"""Set a context value."""
|
||||
|
||||
prefix: str = ""
|
||||
|
||||
keys: Mapping[str, Optional[Runnable]]
|
||||
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
key: Optional[str] = None,
|
||||
value: Optional[SetValue] = None,
|
||||
prefix: str = "",
|
||||
**kwargs: SetValue,
|
||||
):
|
||||
if key is not None:
|
||||
kwargs[key] = value
|
||||
super().__init__( # type: ignore[call-arg]
|
||||
keys={
|
||||
k: _coerce_set_value(v) if v is not None else None
|
||||
for k, v in kwargs.items()
|
||||
},
|
||||
prefix=prefix,
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"ContextSet({_print_keys(list(self.keys.keys()))})"
|
||||
|
||||
@property
|
||||
def ids(self) -> list[str]:
|
||||
prefix = self.prefix + "/" if self.prefix else ""
|
||||
return [
|
||||
f"{CONTEXT_CONFIG_PREFIX}{prefix}{key}{CONTEXT_CONFIG_SUFFIX_SET}"
|
||||
for key in self.keys
|
||||
]
|
||||
|
||||
@property
|
||||
def config_specs(self) -> list[ConfigurableFieldSpec]:
|
||||
mapper_config_specs = [
|
||||
s
|
||||
for mapper in self.keys.values()
|
||||
if mapper is not None
|
||||
for s in mapper.config_specs
|
||||
]
|
||||
for spec in mapper_config_specs:
|
||||
if spec.id.endswith(CONTEXT_CONFIG_SUFFIX_GET):
|
||||
getter_key = spec.id.split("/")[1]
|
||||
if getter_key in self.keys:
|
||||
msg = f"Circular reference in context setter for key {getter_key}"
|
||||
raise ValueError(msg)
|
||||
return super().config_specs + [
|
||||
ConfigurableFieldSpec(
|
||||
id=id_,
|
||||
annotation=Callable[[], Any],
|
||||
)
|
||||
for id_ in self.ids
|
||||
]
|
||||
|
||||
def invoke(
|
||||
self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> Any:
|
||||
config = ensure_config(config)
|
||||
configurable = config.get("configurable", {})
|
||||
for id_, mapper in zip(self.ids, self.keys.values()):
|
||||
if mapper is not None:
|
||||
configurable[id_](mapper.invoke(input, config))
|
||||
else:
|
||||
configurable[id_](input)
|
||||
return input
|
||||
|
||||
async def ainvoke(
|
||||
self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> Any:
|
||||
config = ensure_config(config)
|
||||
configurable = config.get("configurable", {})
|
||||
for id_, mapper in zip(self.ids, self.keys.values()):
|
||||
if mapper is not None:
|
||||
await configurable[id_](await mapper.ainvoke(input, config))
|
||||
else:
|
||||
await configurable[id_](input)
|
||||
return input
|
||||
|
||||
|
||||
class Context:
|
||||
"""Context for a runnable.
|
||||
|
||||
The `Context` class provides methods for creating context scopes,
|
||||
getters, and setters within a runnable. It allows for managing
|
||||
and accessing contextual information throughout the execution
|
||||
of a program.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_core.beta.runnables.context import Context
|
||||
from langchain_core.runnables.passthrough import RunnablePassthrough
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.output_parsers.string import StrOutputParser
|
||||
from tests.unit_tests.fake.llm import FakeListLLM
|
||||
|
||||
chain = (
|
||||
Context.setter("input")
|
||||
| {
|
||||
"context": RunnablePassthrough()
|
||||
| Context.setter("context"),
|
||||
"question": RunnablePassthrough(),
|
||||
}
|
||||
| PromptTemplate.from_template("{context} {question}")
|
||||
| FakeListLLM(responses=["hello"])
|
||||
| StrOutputParser()
|
||||
| {
|
||||
"result": RunnablePassthrough(),
|
||||
"context": Context.getter("context"),
|
||||
"input": Context.getter("input"),
|
||||
}
|
||||
)
|
||||
|
||||
# Use the chain
|
||||
output = chain.invoke("What's your name?")
|
||||
print(output["result"]) # Output: "hello"
|
||||
print(output["context"]) # Output: "What's your name?"
|
||||
print(output["input"]) # Output: "What's your name?
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def create_scope(scope: str, /) -> "PrefixContext":
|
||||
"""Create a context scope.
|
||||
|
||||
Args:
|
||||
scope: The scope.
|
||||
|
||||
Returns:
|
||||
The context scope.
|
||||
"""
|
||||
return PrefixContext(prefix=scope)
|
||||
|
||||
@staticmethod
|
||||
def getter(key: Union[str, list[str]], /) -> ContextGet:
|
||||
return ContextGet(key=key)
|
||||
|
||||
@staticmethod
|
||||
def setter(
|
||||
_key: Optional[str] = None,
|
||||
_value: Optional[SetValue] = None,
|
||||
/,
|
||||
**kwargs: SetValue,
|
||||
) -> ContextSet:
|
||||
return ContextSet(_key, _value, prefix="", **kwargs)
|
||||
|
||||
|
||||
class PrefixContext:
|
||||
"""Context for a runnable with a prefix."""
|
||||
|
||||
prefix: str = ""
|
||||
|
||||
def __init__(self, prefix: str = ""):
|
||||
self.prefix = prefix
|
||||
|
||||
def getter(self, key: Union[str, list[str]], /) -> ContextGet:
|
||||
return ContextGet(key=key, prefix=self.prefix)
|
||||
|
||||
def setter(
|
||||
self,
|
||||
_key: Optional[str] = None,
|
||||
_value: Optional[SetValue] = None,
|
||||
/,
|
||||
**kwargs: SetValue,
|
||||
) -> ContextSet:
|
||||
return ContextSet(_key, _value, prefix=self.prefix, **kwargs)
|
||||
|
||||
|
||||
def _print_keys(keys: Union[str, Sequence[str]]) -> str:
|
||||
if isinstance(keys, str):
|
||||
return f"'{keys}'"
|
||||
else:
|
||||
return ", ".join(f"'{k}'" for k in keys)
|
@ -18,7 +18,7 @@ from collections.abc import (
|
||||
)
|
||||
from concurrent.futures import FIRST_COMPLETED, wait
|
||||
from functools import wraps
|
||||
from itertools import groupby, tee
|
||||
from itertools import tee
|
||||
from operator import itemgetter
|
||||
from types import GenericAlias
|
||||
from typing import (
|
||||
@ -2858,50 +2858,10 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
Returns:
|
||||
The config specs of the Runnable.
|
||||
"""
|
||||
from langchain_core.beta.runnables.context import (
|
||||
CONTEXT_CONFIG_PREFIX,
|
||||
_key_from_id,
|
||||
return get_unique_config_specs(
|
||||
[spec for step in self.steps for spec in step.config_specs]
|
||||
)
|
||||
|
||||
# get all specs
|
||||
all_specs = [
|
||||
(spec, idx)
|
||||
for idx, step in enumerate(self.steps)
|
||||
for spec in step.config_specs
|
||||
]
|
||||
# calculate context dependencies
|
||||
specs_by_pos = groupby(
|
||||
[tup for tup in all_specs if tup[0].id.startswith(CONTEXT_CONFIG_PREFIX)],
|
||||
itemgetter(1),
|
||||
)
|
||||
next_deps: set[str] = set()
|
||||
deps_by_pos: dict[int, set[str]] = {}
|
||||
for pos, specs in specs_by_pos:
|
||||
deps_by_pos[pos] = next_deps
|
||||
next_deps = next_deps | {spec[0].id for spec in specs}
|
||||
# assign context dependencies
|
||||
for pos, (spec, idx) in enumerate(all_specs):
|
||||
if spec.id.startswith(CONTEXT_CONFIG_PREFIX):
|
||||
all_specs[pos] = (
|
||||
ConfigurableFieldSpec(
|
||||
id=spec.id,
|
||||
annotation=spec.annotation,
|
||||
name=spec.name,
|
||||
default=spec.default,
|
||||
description=spec.description,
|
||||
is_shared=spec.is_shared,
|
||||
dependencies=[
|
||||
d
|
||||
for d in deps_by_pos[idx]
|
||||
if _key_from_id(d) != _key_from_id(spec.id)
|
||||
]
|
||||
+ (spec.dependencies or []),
|
||||
),
|
||||
idx,
|
||||
)
|
||||
|
||||
return get_unique_config_specs(spec for spec, _ in all_specs)
|
||||
|
||||
def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph:
|
||||
"""Get the graph representation of the Runnable.
|
||||
|
||||
@ -2998,10 +2958,8 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
def invoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> Output:
|
||||
from langchain_core.beta.runnables.context import config_with_context
|
||||
|
||||
# setup callbacks and context
|
||||
config = config_with_context(ensure_config(config), self.steps)
|
||||
config = ensure_config(config)
|
||||
callback_manager = get_callback_manager_for_config(config)
|
||||
# start the root run
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
@ -3037,10 +2995,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> Output:
|
||||
from langchain_core.beta.runnables.context import aconfig_with_context
|
||||
|
||||
# setup callbacks and context
|
||||
config = aconfig_with_context(ensure_config(config), self.steps)
|
||||
config = ensure_config(config)
|
||||
callback_manager = get_async_callback_manager_for_config(config)
|
||||
# start the root run
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
@ -3082,17 +3037,13 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Optional[Any],
|
||||
) -> list[Output]:
|
||||
from langchain_core.beta.runnables.context import config_with_context
|
||||
from langchain_core.callbacks.manager import CallbackManager
|
||||
|
||||
if not inputs:
|
||||
return []
|
||||
|
||||
# setup callbacks and context
|
||||
configs = [
|
||||
config_with_context(c, self.steps)
|
||||
for c in get_config_list(config, len(inputs))
|
||||
]
|
||||
configs = get_config_list(config, len(inputs))
|
||||
callback_managers = [
|
||||
CallbackManager.configure(
|
||||
inheritable_callbacks=config.get("callbacks"),
|
||||
@ -3209,17 +3160,13 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Optional[Any],
|
||||
) -> list[Output]:
|
||||
from langchain_core.beta.runnables.context import aconfig_with_context
|
||||
from langchain_core.callbacks.manager import AsyncCallbackManager
|
||||
|
||||
if not inputs:
|
||||
return []
|
||||
|
||||
# setup callbacks and context
|
||||
configs = [
|
||||
aconfig_with_context(c, self.steps)
|
||||
for c in get_config_list(config, len(inputs))
|
||||
]
|
||||
configs = get_config_list(config, len(inputs))
|
||||
callback_managers = [
|
||||
AsyncCallbackManager.configure(
|
||||
inheritable_callbacks=config.get("callbacks"),
|
||||
@ -3338,10 +3285,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
config: RunnableConfig,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[Output]:
|
||||
from langchain_core.beta.runnables.context import config_with_context
|
||||
|
||||
steps = [self.first] + self.middle + [self.last]
|
||||
config = config_with_context(config, self.steps)
|
||||
|
||||
# transform the input stream of each step with the next
|
||||
# steps that don't natively support transforming an input stream will
|
||||
@ -3365,10 +3309,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
config: RunnableConfig,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[Output]:
|
||||
from langchain_core.beta.runnables.context import aconfig_with_context
|
||||
|
||||
steps = [self.first] + self.middle + [self.last]
|
||||
config = aconfig_with_context(config, self.steps)
|
||||
|
||||
# stream the last steps
|
||||
# transform the input stream of each step with the next
|
||||
|
@ -168,11 +168,6 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
|
||||
|
||||
@property
|
||||
def config_specs(self) -> list[ConfigurableFieldSpec]:
|
||||
from langchain_core.beta.runnables.context import (
|
||||
CONTEXT_CONFIG_PREFIX,
|
||||
CONTEXT_CONFIG_SUFFIX_SET,
|
||||
)
|
||||
|
||||
specs = get_unique_config_specs(
|
||||
spec
|
||||
for step in (
|
||||
@ -182,13 +177,6 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
|
||||
)
|
||||
for spec in step.config_specs
|
||||
)
|
||||
if any(
|
||||
s.id.startswith(CONTEXT_CONFIG_PREFIX)
|
||||
and s.id.endswith(CONTEXT_CONFIG_SUFFIX_SET)
|
||||
for s in specs
|
||||
):
|
||||
msg = "RunnableBranch cannot contain context setters."
|
||||
raise ValueError(msg)
|
||||
return specs
|
||||
|
||||
def invoke(
|
||||
|
@ -1,418 +0,0 @@
|
||||
import asyncio
|
||||
from typing import Any, Callable, NamedTuple, Union
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_core.beta.runnables.context import Context
|
||||
from langchain_core.language_models import FakeListLLM, FakeStreamingListLLM
|
||||
from langchain_core.output_parsers.string import StrOutputParser
|
||||
from langchain_core.prompt_values import StringPromptValue
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.runnables.base import Runnable, RunnableLambda
|
||||
from langchain_core.runnables.passthrough import RunnablePassthrough
|
||||
from langchain_core.runnables.utils import aadd, add
|
||||
|
||||
|
||||
class _TestCase(NamedTuple):
|
||||
input: Any
|
||||
output: Any
|
||||
|
||||
|
||||
def seq_naive_rag() -> Runnable:
|
||||
context = [
|
||||
"Hi there!",
|
||||
"How are you?",
|
||||
"What's your name?",
|
||||
]
|
||||
|
||||
retriever = RunnableLambda(lambda x: context)
|
||||
prompt = PromptTemplate.from_template("{context} {question}")
|
||||
llm = FakeListLLM(responses=["hello"])
|
||||
|
||||
return (
|
||||
Context.setter("input")
|
||||
| {
|
||||
"context": retriever | Context.setter("context"),
|
||||
"question": RunnablePassthrough(),
|
||||
}
|
||||
| prompt
|
||||
| llm
|
||||
| StrOutputParser()
|
||||
| {
|
||||
"result": RunnablePassthrough(),
|
||||
"context": Context.getter("context"),
|
||||
"input": Context.getter("input"),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def seq_naive_rag_alt() -> Runnable:
|
||||
context = [
|
||||
"Hi there!",
|
||||
"How are you?",
|
||||
"What's your name?",
|
||||
]
|
||||
|
||||
retriever = RunnableLambda(lambda x: context)
|
||||
prompt = PromptTemplate.from_template("{context} {question}")
|
||||
llm = FakeListLLM(responses=["hello"])
|
||||
|
||||
return (
|
||||
Context.setter("input")
|
||||
| {
|
||||
"context": retriever | Context.setter("context"),
|
||||
"question": RunnablePassthrough(),
|
||||
}
|
||||
| prompt
|
||||
| llm
|
||||
| StrOutputParser()
|
||||
| Context.setter("result")
|
||||
| Context.getter(["context", "input", "result"])
|
||||
)
|
||||
|
||||
|
||||
def seq_naive_rag_scoped() -> Runnable:
|
||||
context = [
|
||||
"Hi there!",
|
||||
"How are you?",
|
||||
"What's your name?",
|
||||
]
|
||||
|
||||
retriever = RunnableLambda(lambda x: context)
|
||||
prompt = PromptTemplate.from_template("{context} {question}")
|
||||
llm = FakeListLLM(responses=["hello"])
|
||||
|
||||
scoped = Context.create_scope("a_scope")
|
||||
|
||||
return (
|
||||
Context.setter("input")
|
||||
| {
|
||||
"context": retriever | Context.setter("context"),
|
||||
"question": RunnablePassthrough(),
|
||||
"scoped": scoped.setter("context") | scoped.getter("context"),
|
||||
}
|
||||
| prompt
|
||||
| llm
|
||||
| StrOutputParser()
|
||||
| Context.setter("result")
|
||||
| Context.getter(["context", "input", "result"])
|
||||
)
|
||||
|
||||
|
||||
test_cases = [
|
||||
(
|
||||
Context.setter("foo") | Context.getter("foo"),
|
||||
(
|
||||
_TestCase("foo", "foo"),
|
||||
_TestCase("bar", "bar"),
|
||||
),
|
||||
),
|
||||
(
|
||||
Context.setter("input") | {"bar": Context.getter("input")},
|
||||
(
|
||||
_TestCase("foo", {"bar": "foo"}),
|
||||
_TestCase("bar", {"bar": "bar"}),
|
||||
),
|
||||
),
|
||||
(
|
||||
{"bar": Context.setter("input")} | Context.getter("input"),
|
||||
(
|
||||
_TestCase("foo", "foo"),
|
||||
_TestCase("bar", "bar"),
|
||||
),
|
||||
),
|
||||
(
|
||||
(
|
||||
PromptTemplate.from_template("{foo} {bar}")
|
||||
| Context.setter("prompt")
|
||||
| FakeListLLM(responses=["hello"])
|
||||
| StrOutputParser()
|
||||
| {
|
||||
"response": RunnablePassthrough(),
|
||||
"prompt": Context.getter("prompt"),
|
||||
}
|
||||
),
|
||||
(
|
||||
_TestCase(
|
||||
{"foo": "foo", "bar": "bar"},
|
||||
{"response": "hello", "prompt": StringPromptValue(text="foo bar")},
|
||||
),
|
||||
_TestCase(
|
||||
{"foo": "bar", "bar": "foo"},
|
||||
{"response": "hello", "prompt": StringPromptValue(text="bar foo")},
|
||||
),
|
||||
),
|
||||
),
|
||||
(
|
||||
(
|
||||
PromptTemplate.from_template("{foo} {bar}")
|
||||
| Context.setter("prompt", prompt_str=lambda x: x.to_string())
|
||||
| FakeListLLM(responses=["hello"])
|
||||
| StrOutputParser()
|
||||
| {
|
||||
"response": RunnablePassthrough(),
|
||||
"prompt": Context.getter("prompt"),
|
||||
"prompt_str": Context.getter("prompt_str"),
|
||||
}
|
||||
),
|
||||
(
|
||||
_TestCase(
|
||||
{"foo": "foo", "bar": "bar"},
|
||||
{
|
||||
"response": "hello",
|
||||
"prompt": StringPromptValue(text="foo bar"),
|
||||
"prompt_str": "foo bar",
|
||||
},
|
||||
),
|
||||
_TestCase(
|
||||
{"foo": "bar", "bar": "foo"},
|
||||
{
|
||||
"response": "hello",
|
||||
"prompt": StringPromptValue(text="bar foo"),
|
||||
"prompt_str": "bar foo",
|
||||
},
|
||||
),
|
||||
),
|
||||
),
|
||||
(
|
||||
(
|
||||
PromptTemplate.from_template("{foo} {bar}")
|
||||
| Context.setter(prompt_str=lambda x: x.to_string())
|
||||
| FakeListLLM(responses=["hello"])
|
||||
| StrOutputParser()
|
||||
| {
|
||||
"response": RunnablePassthrough(),
|
||||
"prompt_str": Context.getter("prompt_str"),
|
||||
}
|
||||
),
|
||||
(
|
||||
_TestCase(
|
||||
{"foo": "foo", "bar": "bar"},
|
||||
{"response": "hello", "prompt_str": "foo bar"},
|
||||
),
|
||||
_TestCase(
|
||||
{"foo": "bar", "bar": "foo"},
|
||||
{"response": "hello", "prompt_str": "bar foo"},
|
||||
),
|
||||
),
|
||||
),
|
||||
(
|
||||
(
|
||||
PromptTemplate.from_template("{foo} {bar}")
|
||||
| Context.setter("prompt_str", lambda x: x.to_string())
|
||||
| FakeListLLM(responses=["hello"])
|
||||
| StrOutputParser()
|
||||
| {
|
||||
"response": RunnablePassthrough(),
|
||||
"prompt_str": Context.getter("prompt_str"),
|
||||
}
|
||||
),
|
||||
(
|
||||
_TestCase(
|
||||
{"foo": "foo", "bar": "bar"},
|
||||
{"response": "hello", "prompt_str": "foo bar"},
|
||||
),
|
||||
_TestCase(
|
||||
{"foo": "bar", "bar": "foo"},
|
||||
{"response": "hello", "prompt_str": "bar foo"},
|
||||
),
|
||||
),
|
||||
),
|
||||
(
|
||||
(
|
||||
PromptTemplate.from_template("{foo} {bar}")
|
||||
| Context.setter("prompt")
|
||||
| FakeStreamingListLLM(responses=["hello"])
|
||||
| StrOutputParser()
|
||||
| {
|
||||
"response": RunnablePassthrough(),
|
||||
"prompt": Context.getter("prompt"),
|
||||
}
|
||||
),
|
||||
(
|
||||
_TestCase(
|
||||
{"foo": "foo", "bar": "bar"},
|
||||
{"response": "hello", "prompt": StringPromptValue(text="foo bar")},
|
||||
),
|
||||
_TestCase(
|
||||
{"foo": "bar", "bar": "foo"},
|
||||
{"response": "hello", "prompt": StringPromptValue(text="bar foo")},
|
||||
),
|
||||
),
|
||||
),
|
||||
(
|
||||
seq_naive_rag,
|
||||
(
|
||||
_TestCase(
|
||||
"What up",
|
||||
{
|
||||
"result": "hello",
|
||||
"context": [
|
||||
"Hi there!",
|
||||
"How are you?",
|
||||
"What's your name?",
|
||||
],
|
||||
"input": "What up",
|
||||
},
|
||||
),
|
||||
_TestCase(
|
||||
"Howdy",
|
||||
{
|
||||
"result": "hello",
|
||||
"context": [
|
||||
"Hi there!",
|
||||
"How are you?",
|
||||
"What's your name?",
|
||||
],
|
||||
"input": "Howdy",
|
||||
},
|
||||
),
|
||||
),
|
||||
),
|
||||
(
|
||||
seq_naive_rag_alt,
|
||||
(
|
||||
_TestCase(
|
||||
"What up",
|
||||
{
|
||||
"result": "hello",
|
||||
"context": [
|
||||
"Hi there!",
|
||||
"How are you?",
|
||||
"What's your name?",
|
||||
],
|
||||
"input": "What up",
|
||||
},
|
||||
),
|
||||
_TestCase(
|
||||
"Howdy",
|
||||
{
|
||||
"result": "hello",
|
||||
"context": [
|
||||
"Hi there!",
|
||||
"How are you?",
|
||||
"What's your name?",
|
||||
],
|
||||
"input": "Howdy",
|
||||
},
|
||||
),
|
||||
),
|
||||
),
|
||||
(
|
||||
seq_naive_rag_scoped,
|
||||
(
|
||||
_TestCase(
|
||||
"What up",
|
||||
{
|
||||
"result": "hello",
|
||||
"context": [
|
||||
"Hi there!",
|
||||
"How are you?",
|
||||
"What's your name?",
|
||||
],
|
||||
"input": "What up",
|
||||
},
|
||||
),
|
||||
_TestCase(
|
||||
"Howdy",
|
||||
{
|
||||
"result": "hello",
|
||||
"context": [
|
||||
"Hi there!",
|
||||
"How are you?",
|
||||
"What's your name?",
|
||||
],
|
||||
"input": "Howdy",
|
||||
},
|
||||
),
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("runnable, cases", test_cases)
|
||||
def test_context_runnables(
|
||||
runnable: Union[Runnable, Callable[[], Runnable]], cases: list[_TestCase]
|
||||
) -> None:
|
||||
runnable = runnable if isinstance(runnable, Runnable) else runnable()
|
||||
assert runnable.invoke(cases[0].input) == cases[0].output
|
||||
assert runnable.batch([case.input for case in cases]) == [
|
||||
case.output for case in cases
|
||||
]
|
||||
assert add(runnable.stream(cases[0].input)) == cases[0].output
|
||||
|
||||
|
||||
@pytest.mark.parametrize("runnable, cases", test_cases)
|
||||
async def test_context_runnables_async(
|
||||
runnable: Union[Runnable, Callable[[], Runnable]], cases: list[_TestCase]
|
||||
) -> None:
|
||||
runnable = runnable if isinstance(runnable, Runnable) else runnable()
|
||||
assert await runnable.ainvoke(cases[1].input) == cases[1].output
|
||||
assert await runnable.abatch([case.input for case in cases]) == [
|
||||
case.output for case in cases
|
||||
]
|
||||
assert await aadd(runnable.astream(cases[1].input)) == cases[1].output
|
||||
|
||||
|
||||
def test_runnable_context_seq_key_not_found() -> None:
|
||||
seq: Runnable = {"bar": Context.setter("input")} | Context.getter("foo")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
seq.invoke("foo")
|
||||
|
||||
|
||||
def test_runnable_context_seq_key_order() -> None:
|
||||
seq: Runnable = {"bar": Context.getter("foo")} | Context.setter("foo")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
seq.invoke("foo")
|
||||
|
||||
|
||||
def test_runnable_context_deadlock() -> None:
|
||||
seq: Runnable = {
|
||||
"bar": Context.setter("input") | Context.getter("foo"),
|
||||
"foo": Context.setter("foo") | Context.getter("input"),
|
||||
} | RunnablePassthrough()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
seq.invoke("foo")
|
||||
|
||||
|
||||
def test_runnable_context_seq_key_circular_ref() -> None:
|
||||
seq: Runnable = {
|
||||
"bar": Context.setter(input=Context.getter("input"))
|
||||
} | Context.getter("foo")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
seq.invoke("foo")
|
||||
|
||||
|
||||
async def test_runnable_seq_streaming_chunks() -> None:
|
||||
chain: Runnable = (
|
||||
PromptTemplate.from_template("{foo} {bar}")
|
||||
| Context.setter("prompt")
|
||||
| FakeStreamingListLLM(responses=["hello"])
|
||||
| StrOutputParser()
|
||||
| {
|
||||
"response": RunnablePassthrough(),
|
||||
"prompt": Context.getter("prompt"),
|
||||
}
|
||||
)
|
||||
chunks = await asyncio.to_thread(list, chain.stream({"foo": "foo", "bar": "bar"}))
|
||||
achunks = [c async for c in chain.astream({"foo": "foo", "bar": "bar"})]
|
||||
for c in chunks:
|
||||
assert c in achunks
|
||||
for c in achunks:
|
||||
assert c in chunks
|
||||
|
||||
assert len(chunks) == 6
|
||||
assert [c for c in chunks if c.get("response")] == [
|
||||
{"response": "h"},
|
||||
{"response": "e"},
|
||||
{"response": "l"},
|
||||
{"response": "l"},
|
||||
{"response": "o"},
|
||||
]
|
||||
assert [c for c in chunks if c.get("prompt")] == [
|
||||
{"prompt": StringPromptValue(text="foo bar")},
|
||||
]
|
Loading…
Reference in New Issue
Block a user