mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-20 05:43:55 +00:00
Add .configurable_fields() and .configurable_alternatives() to expose fields of a Runnable to be configured at runtime (#11282)
This commit is contained in:
parent
5e2d5047af
commit
4df3191092
@ -57,6 +57,8 @@ from langchain.schema.runnable.config import (
|
||||
)
|
||||
from langchain.schema.runnable.utils import (
|
||||
AddableDict,
|
||||
ConfigurableField,
|
||||
ConfigurableFieldSpec,
|
||||
Input,
|
||||
Output,
|
||||
accepts_config,
|
||||
@ -64,6 +66,7 @@ from langchain.schema.runnable.utils import (
|
||||
gather_with_concurrency,
|
||||
get_function_first_arg_dict_keys,
|
||||
get_lambda_source,
|
||||
get_unique_config_specs,
|
||||
indent_lines_after_first,
|
||||
)
|
||||
from langchain.utils.aiter import atee, py_anext
|
||||
@ -122,6 +125,10 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
self.__class__.__name__ + "Output", __root__=(root_type, None)
|
||||
)
|
||||
|
||||
@property
|
||||
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
||||
return []
|
||||
|
||||
def config_schema(
|
||||
self, *, include: Optional[Sequence[str]] = None
|
||||
) -> Type[BaseModel]:
|
||||
@ -129,10 +136,28 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
include = include or []
|
||||
config_specs = self.config_specs
|
||||
configurable = (
|
||||
create_model( # type: ignore[call-overload]
|
||||
"Configurable",
|
||||
**{
|
||||
spec.id: (
|
||||
spec.annotation,
|
||||
Field(
|
||||
spec.default, title=spec.name, description=spec.description
|
||||
),
|
||||
)
|
||||
for spec in config_specs
|
||||
},
|
||||
)
|
||||
if config_specs
|
||||
else None
|
||||
)
|
||||
|
||||
return create_model( # type: ignore[call-overload]
|
||||
self.__class__.__name__ + "Config",
|
||||
__config__=_Config,
|
||||
**({"configurable": (configurable, None)} if configurable else {}),
|
||||
**{
|
||||
field_name: (field_type, None)
|
||||
for field_name, field_type in RunnableConfig.__annotations__.items()
|
||||
@ -836,7 +861,32 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
|
||||
|
||||
class RunnableSerializable(Serializable, Runnable[Input, Output]):
|
||||
pass
|
||||
def configurable_fields(
|
||||
self, **kwargs: ConfigurableField
|
||||
) -> RunnableSerializable[Input, Output]:
|
||||
from langchain.schema.runnable.configurable import RunnableConfigurableFields
|
||||
|
||||
for key in kwargs:
|
||||
if key not in self.__fields__:
|
||||
raise ValueError(
|
||||
f"Configuration key {key} not found in {self}: "
|
||||
"available keys are {self.__fields__.keys()}"
|
||||
)
|
||||
|
||||
return RunnableConfigurableFields(bound=self, fields=kwargs)
|
||||
|
||||
def configurable_alternatives(
|
||||
self,
|
||||
which: ConfigurableField,
|
||||
**kwargs: Runnable[Input, Output],
|
||||
) -> RunnableSerializable[Input, Output]:
|
||||
from langchain.schema.runnable.configurable import (
|
||||
RunnableConfigurableAlternatives,
|
||||
)
|
||||
|
||||
return RunnableConfigurableAlternatives(
|
||||
which=which, bound=self, alternatives=kwargs
|
||||
)
|
||||
|
||||
|
||||
class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
@ -879,6 +929,12 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
def output_schema(self) -> Type[BaseModel]:
|
||||
return self.last.output_schema
|
||||
|
||||
@property
|
||||
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
||||
return get_unique_config_specs(
|
||||
spec for step in self.steps for spec in step.config_specs
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "\n| ".join(
|
||||
repr(s) if i == 0 else indent_lines_after_first(repr(s), "| ")
|
||||
@ -1388,6 +1444,12 @@ class RunnableMap(RunnableSerializable[Input, Dict[str, Any]]):
|
||||
**{k: (v.OutputType, None) for k, v in self.steps.items()},
|
||||
)
|
||||
|
||||
@property
|
||||
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
||||
return get_unique_config_specs(
|
||||
spec for step in self.steps.values() for spec in step.config_specs
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
map_for_repr = ",\n ".join(
|
||||
f"{k}: {indent_lines_after_first(repr(v), ' ' + k + ': ')}"
|
||||
@ -1985,6 +2047,10 @@ class RunnableEach(RunnableSerializable[List[Input], List[Output]]):
|
||||
),
|
||||
)
|
||||
|
||||
@property
|
||||
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
||||
return self.bound.config_specs
|
||||
|
||||
def config_schema(
|
||||
self, *, include: Optional[Sequence[str]] = None
|
||||
) -> Type[BaseModel]:
|
||||
@ -2062,6 +2128,10 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
||||
def output_schema(self) -> Type[BaseModel]:
|
||||
return self.bound.output_schema
|
||||
|
||||
@property
|
||||
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
||||
return self.bound.config_specs
|
||||
|
||||
def config_schema(
|
||||
self, *, include: Optional[Sequence[str]] = None
|
||||
) -> Type[BaseModel]:
|
||||
|
@ -26,7 +26,12 @@ from langchain.schema.runnable.config import (
|
||||
get_callback_manager_for_config,
|
||||
patch_config,
|
||||
)
|
||||
from langchain.schema.runnable.utils import Input, Output
|
||||
from langchain.schema.runnable.utils import (
|
||||
ConfigurableFieldSpec,
|
||||
Input,
|
||||
Output,
|
||||
get_unique_config_specs,
|
||||
)
|
||||
|
||||
|
||||
class RunnableBranch(RunnableSerializable[Input, Output]):
|
||||
@ -139,6 +144,18 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
|
||||
|
||||
return super().input_schema
|
||||
|
||||
@property
|
||||
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
||||
return get_unique_config_specs(
|
||||
spec
|
||||
for step in (
|
||||
[self.default]
|
||||
+ [r for _, r in self.branches]
|
||||
+ [r for r, _ in self.branches]
|
||||
)
|
||||
for spec in step.config_specs
|
||||
)
|
||||
|
||||
def invoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> Output:
|
||||
|
@ -34,6 +34,10 @@ if TYPE_CHECKING:
|
||||
)
|
||||
|
||||
|
||||
class EmptyDict(TypedDict, total=False):
|
||||
pass
|
||||
|
||||
|
||||
class RunnableConfig(TypedDict, total=False):
|
||||
"""Configuration for a Runnable."""
|
||||
|
||||
@ -78,6 +82,13 @@ class RunnableConfig(TypedDict, total=False):
|
||||
Maximum number of times a call can recurse. If not provided, defaults to 10.
|
||||
"""
|
||||
|
||||
configurable: Dict[str, Any]
|
||||
"""
|
||||
Runtime values for attributes previously made configurable by this Runnable,
|
||||
or sub-Runnables, through .make_configurable(). Check .output_schema for
|
||||
a description of the attributes that have been made configurable.
|
||||
"""
|
||||
|
||||
|
||||
def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
|
||||
empty = RunnableConfig(
|
||||
|
276
libs/langchain/langchain/schema/runnable/configurable.py
Normal file
276
libs/langchain/langchain/schema/runnable/configurable.py
Normal file
@ -0,0 +1,276 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from langchain.pydantic_v1 import BaseModel
|
||||
from langchain.schema.runnable.base import Runnable, RunnableSerializable
|
||||
from langchain.schema.runnable.config import (
|
||||
RunnableConfig,
|
||||
get_config_list,
|
||||
get_executor_for_config,
|
||||
)
|
||||
from langchain.schema.runnable.utils import (
|
||||
ConfigurableField,
|
||||
ConfigurableFieldSpec,
|
||||
Input,
|
||||
Output,
|
||||
gather_with_concurrency,
|
||||
)
|
||||
|
||||
|
||||
class DynamicRunnable(RunnableSerializable[Input, Output]):
|
||||
bound: RunnableSerializable[Input, Output]
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
return cls.__module__.split(".")[:-1]
|
||||
|
||||
@property
|
||||
def InputType(self) -> Type[Input]:
|
||||
return self.bound.InputType
|
||||
|
||||
@property
|
||||
def OutputType(self) -> Type[Output]:
|
||||
return self.bound.OutputType
|
||||
|
||||
@property
|
||||
def input_schema(self) -> Type[BaseModel]:
|
||||
return self.bound.input_schema
|
||||
|
||||
@property
|
||||
def output_schema(self) -> Type[BaseModel]:
|
||||
return self.bound.output_schema
|
||||
|
||||
@abstractmethod
|
||||
def _prepare(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Runnable[Input, Output]:
|
||||
...
|
||||
|
||||
def invoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> Output:
|
||||
return self._prepare(config).invoke(input, config, **kwargs)
|
||||
|
||||
async def ainvoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> Output:
|
||||
return await self._prepare(config).ainvoke(input, config, **kwargs)
|
||||
|
||||
def batch(
|
||||
self,
|
||||
inputs: List[Input],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Optional[Any],
|
||||
) -> List[Output]:
|
||||
configs = get_config_list(config, len(inputs))
|
||||
prepared = [self._prepare(c) for c in configs]
|
||||
|
||||
if all(p is self.bound for p in prepared):
|
||||
return self.bound.batch(
|
||||
inputs, config, return_exceptions=return_exceptions, **kwargs
|
||||
)
|
||||
|
||||
if not inputs:
|
||||
return []
|
||||
|
||||
configs = get_config_list(config, len(inputs))
|
||||
|
||||
def invoke(
|
||||
bound: Runnable[Input, Output],
|
||||
input: Input,
|
||||
config: RunnableConfig,
|
||||
) -> Union[Output, Exception]:
|
||||
if return_exceptions:
|
||||
try:
|
||||
return bound.invoke(input, config, **kwargs)
|
||||
except Exception as e:
|
||||
return e
|
||||
else:
|
||||
return bound.invoke(input, config, **kwargs)
|
||||
|
||||
# If there's only one input, don't bother with the executor
|
||||
if len(inputs) == 1:
|
||||
return cast(List[Output], [invoke(prepared[0], inputs[0], configs[0])])
|
||||
|
||||
with get_executor_for_config(configs[0]) as executor:
|
||||
return cast(
|
||||
List[Output], list(executor.map(invoke, prepared, inputs, configs))
|
||||
)
|
||||
|
||||
async def abatch(
|
||||
self,
|
||||
inputs: List[Input],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Optional[Any],
|
||||
) -> List[Output]:
|
||||
configs = get_config_list(config, len(inputs))
|
||||
prepared = [self._prepare(c) for c in configs]
|
||||
|
||||
if all(p is self.bound for p in prepared):
|
||||
return await self.bound.abatch(
|
||||
inputs, config, return_exceptions=return_exceptions, **kwargs
|
||||
)
|
||||
|
||||
if not inputs:
|
||||
return []
|
||||
|
||||
configs = get_config_list(config, len(inputs))
|
||||
|
||||
async def ainvoke(
|
||||
bound: Runnable[Input, Output],
|
||||
input: Input,
|
||||
config: RunnableConfig,
|
||||
) -> Union[Output, Exception]:
|
||||
if return_exceptions:
|
||||
try:
|
||||
return await bound.ainvoke(input, config, **kwargs)
|
||||
except Exception as e:
|
||||
return e
|
||||
else:
|
||||
return await bound.ainvoke(input, config, **kwargs)
|
||||
|
||||
coros = map(ainvoke, prepared, inputs, configs)
|
||||
return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros)
|
||||
|
||||
def stream(
|
||||
self,
|
||||
input: Input,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> Iterator[Output]:
|
||||
return self._prepare(config).stream(input, config, **kwargs)
|
||||
|
||||
async def astream(
|
||||
self,
|
||||
input: Input,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> AsyncIterator[Output]:
|
||||
async for chunk in self._prepare(config).astream(input, config, **kwargs):
|
||||
yield chunk
|
||||
|
||||
def transform(
|
||||
self,
|
||||
input: Iterator[Input],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> Iterator[Output]:
|
||||
return self._prepare(config).transform(input, config, **kwargs)
|
||||
|
||||
async def atransform(
|
||||
self,
|
||||
input: AsyncIterator[Input],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> AsyncIterator[Output]:
|
||||
async for chunk in self._prepare(config).atransform(input, config, **kwargs):
|
||||
yield chunk
|
||||
|
||||
|
||||
class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
|
||||
fields: Dict[str, ConfigurableField]
|
||||
|
||||
@property
|
||||
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
||||
return [
|
||||
ConfigurableFieldSpec(
|
||||
id=spec.id,
|
||||
name=spec.name,
|
||||
description=spec.description
|
||||
or self.bound.__fields__[field_name].field_info.description,
|
||||
annotation=spec.annotation
|
||||
or self.bound.__fields__[field_name].annotation,
|
||||
default=getattr(self.bound, field_name),
|
||||
)
|
||||
for field_name, spec in self.fields.items()
|
||||
]
|
||||
|
||||
def configurable_fields(
|
||||
self, **kwargs: ConfigurableField
|
||||
) -> RunnableSerializable[Input, Output]:
|
||||
return self.bound.configurable_fields(**{**self.fields, **kwargs})
|
||||
|
||||
def _prepare(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Runnable[Input, Output]:
|
||||
config = config or {}
|
||||
specs_by_id = {spec.id: (key, spec) for key, spec in self.fields.items()}
|
||||
configurable = {
|
||||
specs_by_id[k][0]: v
|
||||
for k, v in config.get("configurable", {}).items()
|
||||
if k in specs_by_id
|
||||
}
|
||||
|
||||
if configurable:
|
||||
return self.bound.__class__(**{**self.bound.dict(), **configurable})
|
||||
else:
|
||||
return self.bound
|
||||
|
||||
|
||||
class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
|
||||
which: ConfigurableField
|
||||
|
||||
alternatives: Dict[str, RunnableSerializable[Input, Output]]
|
||||
|
||||
@property
|
||||
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
||||
alt_keys = self.alternatives.keys()
|
||||
which_keys = tuple(Literal[k] for k in alt_keys) + ( # type: ignore
|
||||
Literal["default"],
|
||||
)
|
||||
return [
|
||||
ConfigurableFieldSpec(
|
||||
id=self.which.id,
|
||||
name=self.which.name,
|
||||
description=self.which.description,
|
||||
annotation=Union[which_keys], # type: ignore
|
||||
default="default",
|
||||
),
|
||||
*self.bound.config_specs,
|
||||
] + [s for alt in self.alternatives.values() for s in alt.config_specs]
|
||||
|
||||
def configurable_fields(
|
||||
self, **kwargs: ConfigurableField
|
||||
) -> RunnableSerializable[Input, Output]:
|
||||
return self.__class__(
|
||||
which=self.which,
|
||||
bound=self.bound.configurable_fields(**kwargs),
|
||||
alternatives=self.alternatives,
|
||||
)
|
||||
|
||||
def _prepare(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Runnable[Input, Output]:
|
||||
config = config or {}
|
||||
which = config.get("configurable", {}).get(self.which.id)
|
||||
if not which:
|
||||
return self.bound
|
||||
elif which in self.alternatives:
|
||||
return self.alternatives[which]
|
||||
else:
|
||||
raise ValueError(f"Unknown alternative: {which}")
|
@ -22,7 +22,12 @@ from langchain.schema.runnable.config import (
|
||||
get_config_list,
|
||||
patch_config,
|
||||
)
|
||||
from langchain.schema.runnable.utils import Input, Output
|
||||
from langchain.schema.runnable.utils import (
|
||||
ConfigurableFieldSpec,
|
||||
Input,
|
||||
Output,
|
||||
get_unique_config_specs,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain.callbacks.manager import AsyncCallbackManagerForChainRun
|
||||
@ -56,6 +61,14 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
def output_schema(self) -> Type[BaseModel]:
|
||||
return self.runnable.output_schema
|
||||
|
||||
@property
|
||||
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
||||
return get_unique_config_specs(
|
||||
spec
|
||||
for step in [self.runnable, *self.fallbacks]
|
||||
for spec in step.config_specs
|
||||
)
|
||||
|
||||
def config_schema(
|
||||
self, *, include: Optional[Sequence[str]] = None
|
||||
) -> Type[BaseModel]:
|
||||
|
@ -11,6 +11,7 @@ from typing import (
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
@ -24,7 +25,7 @@ from langchain.schema.runnable.base import (
|
||||
RunnableSerializable,
|
||||
)
|
||||
from langchain.schema.runnable.config import RunnableConfig, get_executor_for_config
|
||||
from langchain.schema.runnable.utils import AddableDict
|
||||
from langchain.schema.runnable.utils import AddableDict, ConfigurableFieldSpec
|
||||
from langchain.utils.aiter import atee, py_anext
|
||||
from langchain.utils.iter import safetee
|
||||
|
||||
@ -160,6 +161,10 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||
|
||||
return super().output_schema
|
||||
|
||||
@property
|
||||
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
||||
return self.mapper.config_specs
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
input: Dict[str, Any],
|
||||
|
@ -8,6 +8,7 @@ from typing import (
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
@ -26,7 +27,11 @@ from langchain.schema.runnable.config import (
|
||||
get_config_list,
|
||||
get_executor_for_config,
|
||||
)
|
||||
from langchain.schema.runnable.utils import gather_with_concurrency
|
||||
from langchain.schema.runnable.utils import (
|
||||
ConfigurableFieldSpec,
|
||||
gather_with_concurrency,
|
||||
get_unique_config_specs,
|
||||
)
|
||||
|
||||
|
||||
class RouterInput(TypedDict):
|
||||
@ -49,6 +54,12 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]):
|
||||
|
||||
runnables: Mapping[str, Runnable[Any, Output]]
|
||||
|
||||
@property
|
||||
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
||||
return get_unique_config_specs(
|
||||
spec for step in self.runnables.values() for spec in step.config_specs
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
runnables: Mapping[str, Union[Runnable[Any, Output], Callable[[Any], Output]]],
|
||||
|
@ -5,6 +5,7 @@ import asyncio
|
||||
import inspect
|
||||
import textwrap
|
||||
from inspect import signature
|
||||
from itertools import groupby
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterable,
|
||||
@ -13,8 +14,10 @@ from typing import (
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Protocol,
|
||||
Sequence,
|
||||
Set,
|
||||
TypeVar,
|
||||
Union,
|
||||
@ -211,3 +214,39 @@ async def aadd(addables: AsyncIterable[Addable]) -> Optional[Addable]:
|
||||
else:
|
||||
final = final + chunk
|
||||
return final
|
||||
|
||||
|
||||
class ConfigurableField(NamedTuple):
|
||||
id: str
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
annotation: Optional[Any] = None
|
||||
|
||||
|
||||
class ConfigurableFieldSpec(NamedTuple):
|
||||
id: str
|
||||
name: Optional[str]
|
||||
description: Optional[str]
|
||||
|
||||
default: Any
|
||||
annotation: Any
|
||||
|
||||
|
||||
def get_unique_config_specs(
|
||||
specs: Iterable[ConfigurableFieldSpec],
|
||||
) -> Sequence[ConfigurableFieldSpec]:
|
||||
grouped = groupby(sorted(specs, key=lambda s: s.id), lambda s: s.id)
|
||||
unique: List[ConfigurableFieldSpec] = []
|
||||
for id, dupes in grouped:
|
||||
first = next(dupes)
|
||||
others = list(dupes)
|
||||
if len(others) == 0:
|
||||
unique.append(first)
|
||||
elif all(o == first for o in others):
|
||||
unique.append(first)
|
||||
else:
|
||||
raise ValueError(
|
||||
"RunnableSequence contains conflicting config specs"
|
||||
f"for {id}: {[first] + others}"
|
||||
)
|
||||
return unique
|
||||
|
@ -30,6 +30,7 @@ from langchain.llms.fake import FakeListLLM, FakeStreamingListLLM
|
||||
from langchain.load.dump import dumpd, dumps
|
||||
from langchain.output_parsers.list import CommaSeparatedListOutputParser
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.prompts.base import StringPromptValue
|
||||
from langchain.prompts.chat import (
|
||||
ChatPromptTemplate,
|
||||
ChatPromptValue,
|
||||
@ -56,7 +57,7 @@ from langchain.schema.runnable import (
|
||||
RunnableSequence,
|
||||
RunnableWithFallbacks,
|
||||
)
|
||||
from langchain.schema.runnable.base import RunnableGenerator
|
||||
from langchain.schema.runnable.base import ConfigurableField, RunnableGenerator
|
||||
from langchain.schema.runnable.utils import add
|
||||
from langchain.tools.base import BaseTool, tool
|
||||
from langchain.tools.json.tool import JsonListKeysTool, JsonSpec
|
||||
@ -143,6 +144,15 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
|
||||
"title": "FakeRunnableOutput",
|
||||
"type": "integer",
|
||||
}
|
||||
assert fake.config_schema(include=["tags", "metadata", "run_name"]).schema() == {
|
||||
"title": "FakeRunnableConfig",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"metadata": {"title": "Metadata", "type": "object"},
|
||||
"run_name": {"title": "Run Name", "type": "string"},
|
||||
"tags": {"items": {"type": "string"}, "title": "Tags", "type": "array"},
|
||||
},
|
||||
}
|
||||
|
||||
fake_bound = FakeRunnable().bind(a="b") # str -> int
|
||||
|
||||
@ -538,6 +548,261 @@ def test_schema_chains() -> None:
|
||||
}
|
||||
|
||||
|
||||
def test_configurable_fields() -> None:
|
||||
fake_llm = FakeListLLM(responses=["a"]) # str -> List[List[str]]
|
||||
|
||||
assert fake_llm.invoke("...") == "a"
|
||||
|
||||
fake_llm_configurable = fake_llm.configurable_fields(
|
||||
responses=ConfigurableField(
|
||||
id="llm_responses",
|
||||
name="LLM Responses",
|
||||
description="A list of fake responses for this LLM",
|
||||
)
|
||||
)
|
||||
|
||||
assert fake_llm_configurable.invoke("...") == "a"
|
||||
|
||||
assert fake_llm_configurable.config_schema().schema() == {
|
||||
"title": "RunnableConfigurableFieldsConfig",
|
||||
"type": "object",
|
||||
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
|
||||
"definitions": {
|
||||
"Configurable": {
|
||||
"title": "Configurable",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"llm_responses": {
|
||||
"title": "LLM Responses",
|
||||
"description": "A list of fake responses for this LLM",
|
||||
"default": ["a"],
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
fake_llm_configured = fake_llm_configurable.with_config(
|
||||
configurable={"llm_responses": ["b"]}
|
||||
)
|
||||
|
||||
assert fake_llm_configured.invoke("...") == "b"
|
||||
|
||||
prompt = PromptTemplate.from_template("Hello, {name}!")
|
||||
|
||||
assert prompt.invoke({"name": "John"}) == StringPromptValue(text="Hello, John!")
|
||||
|
||||
prompt_configurable = prompt.configurable_fields(
|
||||
template=ConfigurableField(
|
||||
id="prompt_template",
|
||||
name="Prompt Template",
|
||||
description="The prompt template for this chain",
|
||||
)
|
||||
)
|
||||
|
||||
assert prompt_configurable.invoke({"name": "John"}) == StringPromptValue(
|
||||
text="Hello, John!"
|
||||
)
|
||||
|
||||
assert prompt_configurable.config_schema().schema() == {
|
||||
"title": "RunnableConfigurableFieldsConfig",
|
||||
"type": "object",
|
||||
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
|
||||
"definitions": {
|
||||
"Configurable": {
|
||||
"title": "Configurable",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"prompt_template": {
|
||||
"title": "Prompt Template",
|
||||
"description": "The prompt template for this chain",
|
||||
"default": "Hello, {name}!",
|
||||
"type": "string",
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
prompt_configured = prompt_configurable.with_config(
|
||||
configurable={"prompt_template": "Hello, {name}! {name}!"}
|
||||
)
|
||||
|
||||
assert prompt_configured.invoke({"name": "John"}) == StringPromptValue(
|
||||
text="Hello, John! John!"
|
||||
)
|
||||
|
||||
chain_configurable = prompt_configurable | fake_llm_configurable | StrOutputParser()
|
||||
|
||||
assert chain_configurable.invoke({"name": "John"}) == "a"
|
||||
|
||||
assert chain_configurable.config_schema().schema() == {
|
||||
"title": "RunnableSequenceConfig",
|
||||
"type": "object",
|
||||
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
|
||||
"definitions": {
|
||||
"Configurable": {
|
||||
"title": "Configurable",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"llm_responses": {
|
||||
"title": "LLM Responses",
|
||||
"description": "A list of fake responses for this LLM",
|
||||
"default": ["a"],
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
},
|
||||
"prompt_template": {
|
||||
"title": "Prompt Template",
|
||||
"description": "The prompt template for this chain",
|
||||
"default": "Hello, {name}!",
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
assert (
|
||||
chain_configurable.with_config(
|
||||
configurable={
|
||||
"prompt_template": "A very good morning to you, {name}!",
|
||||
"llm_responses": ["c"],
|
||||
}
|
||||
).invoke({"name": "John"})
|
||||
== "c"
|
||||
)
|
||||
|
||||
chain_with_map_configurable: Runnable = prompt_configurable | {
|
||||
"llm1": fake_llm_configurable | StrOutputParser(),
|
||||
"llm2": fake_llm_configurable | StrOutputParser(),
|
||||
"llm3": fake_llm.configurable_fields(
|
||||
responses=ConfigurableField("other_responses")
|
||||
)
|
||||
| StrOutputParser(),
|
||||
}
|
||||
|
||||
assert chain_with_map_configurable.invoke({"name": "John"}) == {
|
||||
"llm1": "a",
|
||||
"llm2": "a",
|
||||
"llm3": "a",
|
||||
}
|
||||
|
||||
assert chain_with_map_configurable.config_schema().schema() == {
|
||||
"title": "RunnableSequenceConfig",
|
||||
"type": "object",
|
||||
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
|
||||
"definitions": {
|
||||
"Configurable": {
|
||||
"title": "Configurable",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"llm_responses": {
|
||||
"title": "LLM Responses",
|
||||
"description": "A list of fake responses for this LLM",
|
||||
"default": ["a"],
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
},
|
||||
"other_responses": {
|
||||
"title": "Other Responses",
|
||||
"default": ["a"],
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
},
|
||||
"prompt_template": {
|
||||
"title": "Prompt Template",
|
||||
"description": "The prompt template for this chain",
|
||||
"default": "Hello, {name}!",
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
assert chain_with_map_configurable.with_config(
|
||||
configurable={
|
||||
"prompt_template": "A very good morning to you, {name}!",
|
||||
"llm_responses": ["c"],
|
||||
"other_responses": ["d"],
|
||||
}
|
||||
).invoke({"name": "John"}) == {"llm1": "c", "llm2": "c", "llm3": "d"}
|
||||
|
||||
|
||||
def test_configurable_fields_example() -> None:
|
||||
fake_llm = (
|
||||
FakeListLLM(responses=["a"])
|
||||
.configurable_fields(
|
||||
responses=ConfigurableField(
|
||||
id="llm_responses",
|
||||
name="LLM Responses",
|
||||
description="A list of fake responses for this LLM",
|
||||
)
|
||||
)
|
||||
.configurable_alternatives(
|
||||
ConfigurableField(id="llm", name="LLM"),
|
||||
chat=FakeListChatModel(responses=["b"]) | StrOutputParser(),
|
||||
)
|
||||
)
|
||||
|
||||
prompt = PromptTemplate.from_template("Hello, {name}!").configurable_fields(
|
||||
template=ConfigurableField(
|
||||
id="prompt_template",
|
||||
name="Prompt Template",
|
||||
description="The prompt template for this chain",
|
||||
)
|
||||
)
|
||||
|
||||
chain_configurable = prompt | fake_llm
|
||||
|
||||
assert chain_configurable.invoke({"name": "John"}) == "a"
|
||||
|
||||
assert chain_configurable.config_schema().schema() == {
|
||||
"title": "RunnableSequenceConfig",
|
||||
"type": "object",
|
||||
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
|
||||
"definitions": {
|
||||
"Configurable": {
|
||||
"title": "Configurable",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"llm": {
|
||||
"title": "LLM",
|
||||
"default": "default",
|
||||
"anyOf": [
|
||||
{"enum": ["chat"], "type": "string"},
|
||||
{"enum": ["default"], "type": "string"},
|
||||
],
|
||||
},
|
||||
"llm_responses": {
|
||||
"title": "LLM Responses",
|
||||
"description": "A list of fake responses for this LLM",
|
||||
"default": ["a"],
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
},
|
||||
"prompt_template": {
|
||||
"title": "Prompt Template",
|
||||
"description": "The prompt template for this chain",
|
||||
"default": "Hello, {name}!",
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
assert (
|
||||
chain_configurable.with_config(configurable={"llm": "chat"}).invoke(
|
||||
{"name": "John"}
|
||||
)
|
||||
== "b"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_with_config(mocker: MockerFixture) -> None:
|
||||
fake = FakeRunnable()
|
||||
|
Loading…
Reference in New Issue
Block a user