Add .configurable_fields() and .configurable_alternatives() to expose fields of a Runnable to be configured at runtime (#11282)

Nuno Campos 2023-10-02 21:18:36 +01:00 committed by GitHub
parent 5e2d5047af
commit 4df3191092
9 changed files with 713 additions and 6 deletions
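For orientation, a minimal usage sketch of the two new methods, adapted from the tests added in this commit (the fake LLM/chat models are test fixtures; the FakeListChatModel import path is an assumption based on the test module):

from langchain.chat_models.fake import FakeListChatModel  # assumed path
from langchain.llms.fake import FakeListLLM
from langchain.prompts import PromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable.base import ConfigurableField

# Expose an existing field of a Runnable for configuration at runtime.
llm = FakeListLLM(responses=["a"]).configurable_fields(
    responses=ConfigurableField(
        id="llm_responses",
        name="LLM Responses",
        description="A list of fake responses for this LLM",
    )
)

# Register whole alternative Runnables, selectable at runtime via the "llm" key.
llm = llm.configurable_alternatives(
    ConfigurableField(id="llm", name="LLM"),
    chat=FakeListChatModel(responses=["b"]) | StrOutputParser(),
)

prompt = PromptTemplate.from_template("Hello, {name}!").configurable_fields(
    template=ConfigurableField(id="prompt_template", name="Prompt Template")
)
chain = prompt | llm

chain.invoke({"name": "John"})  # -> "a" (defaults)
chain.with_config(configurable={"llm": "chat"}).invoke({"name": "John"})  # -> "b"
chain.with_config(
    configurable={"prompt_template": "Good morning, {name}!", "llm_responses": ["c"]}
).invoke({"name": "John"})  # -> "c"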

View File

@@ -57,6 +57,8 @@ from langchain.schema.runnable.config import (
)
from langchain.schema.runnable.utils import (
AddableDict,
ConfigurableField,
ConfigurableFieldSpec,
Input,
Output,
accepts_config,
@@ -64,6 +66,7 @@ from langchain.schema.runnable.utils import (
gather_with_concurrency,
get_function_first_arg_dict_keys,
get_lambda_source,
get_unique_config_specs,
indent_lines_after_first,
)
from langchain.utils.aiter import atee, py_anext
@@ -122,6 +125,10 @@ class Runnable(Generic[Input, Output], ABC):
self.__class__.__name__ + "Output", __root__=(root_type, None)
)
@property
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
return []
def config_schema(
self, *, include: Optional[Sequence[str]] = None
) -> Type[BaseModel]:
@@ -129,10 +136,28 @@ class Runnable(Generic[Input, Output], ABC):
arbitrary_types_allowed = True
include = include or []
config_specs = self.config_specs
configurable = (
create_model( # type: ignore[call-overload]
"Configurable",
**{
spec.id: (
spec.annotation,
Field(
spec.default, title=spec.name, description=spec.description
),
)
for spec in config_specs
},
)
if config_specs
else None
)
return create_model( # type: ignore[call-overload]
self.__class__.__name__ + "Config",
__config__=_Config,
**({"configurable": (configurable, None)} if configurable else {}),
**{
field_name: (field_type, None)
for field_name, field_type in RunnableConfig.__annotations__.items()
@@ -836,7 +861,32 @@ class Runnable(Generic[Input, Output], ABC):
class RunnableSerializable(Serializable, Runnable[Input, Output]):
def configurable_fields(
self, **kwargs: ConfigurableField
) -> RunnableSerializable[Input, Output]:
from langchain.schema.runnable.configurable import RunnableConfigurableFields
for key in kwargs:
if key not in self.__fields__:
raise ValueError(
f"Configuration key {key} not found in {self}: "
"available keys are {self.__fields__.keys()}"
)
return RunnableConfigurableFields(bound=self, fields=kwargs)
def configurable_alternatives(
self,
which: ConfigurableField,
**kwargs: Runnable[Input, Output],
) -> RunnableSerializable[Input, Output]:
from langchain.schema.runnable.configurable import (
RunnableConfigurableAlternatives,
)
return RunnableConfigurableAlternatives(
which=which, bound=self, alternatives=kwargs
)
class RunnableSequence(RunnableSerializable[Input, Output]):
@@ -879,6 +929,12 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
def output_schema(self) -> Type[BaseModel]:
return self.last.output_schema
@property
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
return get_unique_config_specs(
spec for step in self.steps for spec in step.config_specs
)
def __repr__(self) -> str:
return "\n| ".join(
repr(s) if i == 0 else indent_lines_after_first(repr(s), "| ")
@@ -1388,6 +1444,12 @@ class RunnableMap(RunnableSerializable[Input, Dict[str, Any]]):
**{k: (v.OutputType, None) for k, v in self.steps.items()},
)
@property
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
return get_unique_config_specs(
spec for step in self.steps.values() for spec in step.config_specs
)
def __repr__(self) -> str:
map_for_repr = ",\n ".join(
f"{k}: {indent_lines_after_first(repr(v), ' ' + k + ': ')}"
@@ -1985,6 +2047,10 @@ class RunnableEach(RunnableSerializable[List[Input], List[Output]]):
),
)
@property
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
return self.bound.config_specs
def config_schema(
self, *, include: Optional[Sequence[str]] = None
) -> Type[BaseModel]:
@@ -2062,6 +2128,10 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
def output_schema(self) -> Type[BaseModel]:
return self.bound.output_schema
@property
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
return self.bound.config_specs
def config_schema(
self, *, include: Optional[Sequence[str]] = None
) -> Type[BaseModel]:

View File

@@ -26,7 +26,12 @@ from langchain.schema.runnable.config import (
get_callback_manager_for_config,
patch_config,
)
from langchain.schema.runnable.utils import (
ConfigurableFieldSpec,
Input,
Output,
get_unique_config_specs,
)
class RunnableBranch(RunnableSerializable[Input, Output]):
@@ -139,6 +144,18 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
return super().input_schema
@property
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
return get_unique_config_specs(
spec
for step in (
[self.default]
+ [r for _, r in self.branches]
+ [r for r, _ in self.branches]
)
for spec in step.config_specs
)
def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output:

View File

@@ -34,6 +34,10 @@ if TYPE_CHECKING:
)
class EmptyDict(TypedDict, total=False):
pass
class RunnableConfig(TypedDict, total=False):
"""Configuration for a Runnable."""
@@ -78,6 +82,13 @@ class RunnableConfig(TypedDict, total=False):
Maximum number of times a call can recurse. If not provided, defaults to 10.
"""
configurable: Dict[str, Any]
"""
Runtime values for attributes previously made configurable on this Runnable,
or sub-Runnables, through .configurable_fields() or .configurable_alternatives().
Check .config_schema() for a description of the attributes that have been made
configurable.
"""
def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
empty = RunnableConfig(
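
The new `configurable` entry above can also be supplied per call, not only via `.with_config()`. A minimal sketch, reusing the illustrative `llm_responses` id from the sketch near the top of this page:

from langchain.llms.fake import FakeListLLM
from langchain.schema.runnable.base import ConfigurableField

llm = FakeListLLM(responses=["a"]).configurable_fields(
    responses=ConfigurableField(id="llm_responses", name="LLM Responses")
)

# The "configurable" entry of the config dict carries the runtime overrides.
llm.invoke("any prompt", {"configurable": {"llm_responses": ["b"]}})  # -> "b"

# config_schema() describes which ids are accepted and their defaults.
print(llm.config_schema().schema()["definitions"]["Configurable"]["properties"])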

View File

@@ -0,0 +1,276 @@
from __future__ import annotations
from abc import abstractmethod
from typing import (
Any,
AsyncIterator,
Dict,
Iterator,
List,
Literal,
Optional,
Sequence,
Type,
Union,
cast,
)
from langchain.pydantic_v1 import BaseModel
from langchain.schema.runnable.base import Runnable, RunnableSerializable
from langchain.schema.runnable.config import (
RunnableConfig,
get_config_list,
get_executor_for_config,
)
from langchain.schema.runnable.utils import (
ConfigurableField,
ConfigurableFieldSpec,
Input,
Output,
gather_with_concurrency,
)
class DynamicRunnable(RunnableSerializable[Input, Output]):
bound: RunnableSerializable[Input, Output]
class Config:
arbitrary_types_allowed = True
@classmethod
def is_lc_serializable(cls) -> bool:
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
return cls.__module__.split(".")[:-1]
@property
def InputType(self) -> Type[Input]:
return self.bound.InputType
@property
def OutputType(self) -> Type[Output]:
return self.bound.OutputType
@property
def input_schema(self) -> Type[BaseModel]:
return self.bound.input_schema
@property
def output_schema(self) -> Type[BaseModel]:
return self.bound.output_schema
@abstractmethod
def _prepare(
self, config: Optional[RunnableConfig] = None
) -> Runnable[Input, Output]:
...
def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output:
return self._prepare(config).invoke(input, config, **kwargs)
async def ainvoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output:
return await self._prepare(config).ainvoke(input, config, **kwargs)
def batch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Output]:
configs = get_config_list(config, len(inputs))
prepared = [self._prepare(c) for c in configs]
if all(p is self.bound for p in prepared):
return self.bound.batch(
inputs, config, return_exceptions=return_exceptions, **kwargs
)
if not inputs:
return []
configs = get_config_list(config, len(inputs))
def invoke(
bound: Runnable[Input, Output],
input: Input,
config: RunnableConfig,
) -> Union[Output, Exception]:
if return_exceptions:
try:
return bound.invoke(input, config, **kwargs)
except Exception as e:
return e
else:
return bound.invoke(input, config, **kwargs)
# If there's only one input, don't bother with the executor
if len(inputs) == 1:
return cast(List[Output], [invoke(prepared[0], inputs[0], configs[0])])
with get_executor_for_config(configs[0]) as executor:
return cast(
List[Output], list(executor.map(invoke, prepared, inputs, configs))
)
async def abatch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Output]:
configs = get_config_list(config, len(inputs))
prepared = [self._prepare(c) for c in configs]
if all(p is self.bound for p in prepared):
return await self.bound.abatch(
inputs, config, return_exceptions=return_exceptions, **kwargs
)
if not inputs:
return []
configs = get_config_list(config, len(inputs))
async def ainvoke(
bound: Runnable[Input, Output],
input: Input,
config: RunnableConfig,
) -> Union[Output, Exception]:
if return_exceptions:
try:
return await bound.ainvoke(input, config, **kwargs)
except Exception as e:
return e
else:
return await bound.ainvoke(input, config, **kwargs)
coros = map(ainvoke, prepared, inputs, configs)
return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros)
def stream(
self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Iterator[Output]:
return self._prepare(config).stream(input, config, **kwargs)
async def astream(
self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> AsyncIterator[Output]:
async for chunk in self._prepare(config).astream(input, config, **kwargs):
yield chunk
def transform(
self,
input: Iterator[Input],
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Iterator[Output]:
return self._prepare(config).transform(input, config, **kwargs)
async def atransform(
self,
input: AsyncIterator[Input],
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> AsyncIterator[Output]:
async for chunk in self._prepare(config).atransform(input, config, **kwargs):
yield chunk
class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
fields: Dict[str, ConfigurableField]
@property
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
return [
ConfigurableFieldSpec(
id=spec.id,
name=spec.name,
description=spec.description
or self.bound.__fields__[field_name].field_info.description,
annotation=spec.annotation
or self.bound.__fields__[field_name].annotation,
default=getattr(self.bound, field_name),
)
for field_name, spec in self.fields.items()
]
def configurable_fields(
self, **kwargs: ConfigurableField
) -> RunnableSerializable[Input, Output]:
return self.bound.configurable_fields(**{**self.fields, **kwargs})
def _prepare(
self, config: Optional[RunnableConfig] = None
) -> Runnable[Input, Output]:
config = config or {}
specs_by_id = {spec.id: (key, spec) for key, spec in self.fields.items()}
configurable = {
specs_by_id[k][0]: v
for k, v in config.get("configurable", {}).items()
if k in specs_by_id
}
if configurable:
return self.bound.__class__(**{**self.bound.dict(), **configurable})
else:
return self.bound
class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
which: ConfigurableField
alternatives: Dict[str, RunnableSerializable[Input, Output]]
@property
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
alt_keys = self.alternatives.keys()
which_keys = tuple(Literal[k] for k in alt_keys) + ( # type: ignore
Literal["default"],
)
return [
ConfigurableFieldSpec(
id=self.which.id,
name=self.which.name,
description=self.which.description,
annotation=Union[which_keys], # type: ignore
default="default",
),
*self.bound.config_specs,
] + [s for alt in self.alternatives.values() for s in alt.config_specs]
def configurable_fields(
self, **kwargs: ConfigurableField
) -> RunnableSerializable[Input, Output]:
return self.__class__(
which=self.which,
bound=self.bound.configurable_fields(**kwargs),
alternatives=self.alternatives,
)
def _prepare(
self, config: Optional[RunnableConfig] = None
) -> Runnable[Input, Output]:
config = config or {}
which = config.get("configurable", {}).get(self.which.id)
if not which:
return self.bound
elif which in self.alternatives:
return self.alternatives[which]
else:
raise ValueError(f"Unknown alternative: {which}")

View File

@@ -22,7 +22,12 @@ from langchain.schema.runnable.config import (
get_config_list,
patch_config,
)
from langchain.schema.runnable.utils import (
ConfigurableFieldSpec,
Input,
Output,
get_unique_config_specs,
)
if TYPE_CHECKING:
from langchain.callbacks.manager import AsyncCallbackManagerForChainRun
@@ -56,6 +61,14 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
def output_schema(self) -> Type[BaseModel]:
return self.runnable.output_schema
@property
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
return get_unique_config_specs(
spec
for step in [self.runnable, *self.fallbacks]
for spec in step.config_specs
)
def config_schema(
self, *, include: Optional[Sequence[str]] = None
) -> Type[BaseModel]:

View File

@@ -11,6 +11,7 @@ from typing import (
List,
Mapping,
Optional,
Sequence,
Type,
Union,
cast,
@@ -24,7 +25,7 @@ from langchain.schema.runnable.base import (
RunnableSerializable,
)
from langchain.schema.runnable.config import RunnableConfig, get_executor_for_config
from langchain.schema.runnable.utils import AddableDict, ConfigurableFieldSpec
from langchain.utils.aiter import atee, py_anext
from langchain.utils.iter import safetee
@@ -160,6 +161,10 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
return super().output_schema
@property
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
return self.mapper.config_specs
def invoke(
self,
input: Dict[str, Any],

View File

@@ -8,6 +8,7 @@ from typing import (
List,
Mapping,
Optional,
Sequence,
Union,
cast,
)
@@ -26,7 +27,11 @@ from langchain.schema.runnable.config import (
get_config_list,
get_executor_for_config,
)
from langchain.schema.runnable.utils import (
ConfigurableFieldSpec,
gather_with_concurrency,
get_unique_config_specs,
)
class RouterInput(TypedDict):
@@ -49,6 +54,12 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]):
runnables: Mapping[str, Runnable[Any, Output]]
@property
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
return get_unique_config_specs(
spec for step in self.runnables.values() for spec in step.config_specs
)
def __init__(
self,
runnables: Mapping[str, Union[Runnable[Any, Output], Callable[[Any], Output]]],

View File

@@ -5,6 +5,7 @@ import asyncio
import inspect
import textwrap
from inspect import signature
from itertools import groupby
from typing import (
Any,
AsyncIterable,
@@ -13,8 +14,10 @@ from typing import (
Dict,
Iterable,
List,
NamedTuple,
Optional,
Protocol,
Sequence,
Set,
TypeVar,
Union,
@@ -211,3 +214,39 @@ async def aadd(addables: AsyncIterable[Addable]) -> Optional[Addable]:
else:
final = final + chunk
return final
class ConfigurableField(NamedTuple):
id: str
name: Optional[str] = None
description: Optional[str] = None
annotation: Optional[Any] = None
class ConfigurableFieldSpec(NamedTuple):
id: str
name: Optional[str]
description: Optional[str]
default: Any
annotation: Any
def get_unique_config_specs(
specs: Iterable[ConfigurableFieldSpec],
) -> Sequence[ConfigurableFieldSpec]:
grouped = groupby(sorted(specs, key=lambda s: s.id), lambda s: s.id)
unique: List[ConfigurableFieldSpec] = []
for id, dupes in grouped:
first = next(dupes)
others = list(dupes)
if len(others) == 0:
unique.append(first)
elif all(o == first for o in others):
unique.append(first)
else:
raise ValueError(
"RunnableSequence contains conflicting config specs"
f"for {id}: {[first] + others}"
)
return unique
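
A small sketch of how `get_unique_config_specs` above behaves (the spec values are illustrative):

from langchain.schema.runnable.utils import ConfigurableFieldSpec, get_unique_config_specs

spec_a = ConfigurableFieldSpec(
    id="llm_responses", name="LLM Responses", description=None, default=["a"], annotation=list
)

# The same spec appearing twice (e.g. the same LLM used in two branches) is kept once.
assert get_unique_config_specs([spec_a, spec_a]) == [spec_a]

# Two different specs claiming the same id is ambiguous and raises ValueError.
conflicting = spec_a._replace(default=["b"])
try:
    get_unique_config_specs([spec_a, conflicting])
except ValueError:
    print("conflicting specs for id 'llm_responses'")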

View File

@@ -30,6 +30,7 @@ from langchain.llms.fake import FakeListLLM, FakeStreamingListLLM
from langchain.load.dump import dumpd, dumps
from langchain.output_parsers.list import CommaSeparatedListOutputParser
from langchain.prompts import PromptTemplate
from langchain.prompts.base import StringPromptValue
from langchain.prompts.chat import (
ChatPromptTemplate,
ChatPromptValue,
@@ -56,7 +57,7 @@ from langchain.schema.runnable import (
RunnableSequence,
RunnableWithFallbacks,
)
from langchain.schema.runnable.base import ConfigurableField, RunnableGenerator
from langchain.schema.runnable.utils import add
from langchain.tools.base import BaseTool, tool
from langchain.tools.json.tool import JsonListKeysTool, JsonSpec
@@ -143,6 +144,15 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
"title": "FakeRunnableOutput",
"type": "integer",
}
assert fake.config_schema(include=["tags", "metadata", "run_name"]).schema() == {
"title": "FakeRunnableConfig",
"type": "object",
"properties": {
"metadata": {"title": "Metadata", "type": "object"},
"run_name": {"title": "Run Name", "type": "string"},
"tags": {"items": {"type": "string"}, "title": "Tags", "type": "array"},
},
}
fake_bound = FakeRunnable().bind(a="b") # str -> int
@@ -538,6 +548,261 @@ def test_schema_chains() -> None:
}
def test_configurable_fields() -> None:
fake_llm = FakeListLLM(responses=["a"]) # str -> List[List[str]]
assert fake_llm.invoke("...") == "a"
fake_llm_configurable = fake_llm.configurable_fields(
responses=ConfigurableField(
id="llm_responses",
name="LLM Responses",
description="A list of fake responses for this LLM",
)
)
assert fake_llm_configurable.invoke("...") == "a"
assert fake_llm_configurable.config_schema().schema() == {
"title": "RunnableConfigurableFieldsConfig",
"type": "object",
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
"definitions": {
"Configurable": {
"title": "Configurable",
"type": "object",
"properties": {
"llm_responses": {
"title": "LLM Responses",
"description": "A list of fake responses for this LLM",
"default": ["a"],
"type": "array",
"items": {"type": "string"},
}
},
}
},
}
fake_llm_configured = fake_llm_configurable.with_config(
configurable={"llm_responses": ["b"]}
)
assert fake_llm_configured.invoke("...") == "b"
prompt = PromptTemplate.from_template("Hello, {name}!")
assert prompt.invoke({"name": "John"}) == StringPromptValue(text="Hello, John!")
prompt_configurable = prompt.configurable_fields(
template=ConfigurableField(
id="prompt_template",
name="Prompt Template",
description="The prompt template for this chain",
)
)
assert prompt_configurable.invoke({"name": "John"}) == StringPromptValue(
text="Hello, John!"
)
assert prompt_configurable.config_schema().schema() == {
"title": "RunnableConfigurableFieldsConfig",
"type": "object",
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
"definitions": {
"Configurable": {
"title": "Configurable",
"type": "object",
"properties": {
"prompt_template": {
"title": "Prompt Template",
"description": "The prompt template for this chain",
"default": "Hello, {name}!",
"type": "string",
}
},
}
},
}
prompt_configured = prompt_configurable.with_config(
configurable={"prompt_template": "Hello, {name}! {name}!"}
)
assert prompt_configured.invoke({"name": "John"}) == StringPromptValue(
text="Hello, John! John!"
)
chain_configurable = prompt_configurable | fake_llm_configurable | StrOutputParser()
assert chain_configurable.invoke({"name": "John"}) == "a"
assert chain_configurable.config_schema().schema() == {
"title": "RunnableSequenceConfig",
"type": "object",
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
"definitions": {
"Configurable": {
"title": "Configurable",
"type": "object",
"properties": {
"llm_responses": {
"title": "LLM Responses",
"description": "A list of fake responses for this LLM",
"default": ["a"],
"type": "array",
"items": {"type": "string"},
},
"prompt_template": {
"title": "Prompt Template",
"description": "The prompt template for this chain",
"default": "Hello, {name}!",
"type": "string",
},
},
}
},
}
assert (
chain_configurable.with_config(
configurable={
"prompt_template": "A very good morning to you, {name}!",
"llm_responses": ["c"],
}
).invoke({"name": "John"})
== "c"
)
chain_with_map_configurable: Runnable = prompt_configurable | {
"llm1": fake_llm_configurable | StrOutputParser(),
"llm2": fake_llm_configurable | StrOutputParser(),
"llm3": fake_llm.configurable_fields(
responses=ConfigurableField("other_responses")
)
| StrOutputParser(),
}
assert chain_with_map_configurable.invoke({"name": "John"}) == {
"llm1": "a",
"llm2": "a",
"llm3": "a",
}
assert chain_with_map_configurable.config_schema().schema() == {
"title": "RunnableSequenceConfig",
"type": "object",
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
"definitions": {
"Configurable": {
"title": "Configurable",
"type": "object",
"properties": {
"llm_responses": {
"title": "LLM Responses",
"description": "A list of fake responses for this LLM",
"default": ["a"],
"type": "array",
"items": {"type": "string"},
},
"other_responses": {
"title": "Other Responses",
"default": ["a"],
"type": "array",
"items": {"type": "string"},
},
"prompt_template": {
"title": "Prompt Template",
"description": "The prompt template for this chain",
"default": "Hello, {name}!",
"type": "string",
},
},
}
},
}
assert chain_with_map_configurable.with_config(
configurable={
"prompt_template": "A very good morning to you, {name}!",
"llm_responses": ["c"],
"other_responses": ["d"],
}
).invoke({"name": "John"}) == {"llm1": "c", "llm2": "c", "llm3": "d"}
def test_configurable_fields_example() -> None:
fake_llm = (
FakeListLLM(responses=["a"])
.configurable_fields(
responses=ConfigurableField(
id="llm_responses",
name="LLM Responses",
description="A list of fake responses for this LLM",
)
)
.configurable_alternatives(
ConfigurableField(id="llm", name="LLM"),
chat=FakeListChatModel(responses=["b"]) | StrOutputParser(),
)
)
prompt = PromptTemplate.from_template("Hello, {name}!").configurable_fields(
template=ConfigurableField(
id="prompt_template",
name="Prompt Template",
description="The prompt template for this chain",
)
)
chain_configurable = prompt | fake_llm
assert chain_configurable.invoke({"name": "John"}) == "a"
assert chain_configurable.config_schema().schema() == {
"title": "RunnableSequenceConfig",
"type": "object",
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
"definitions": {
"Configurable": {
"title": "Configurable",
"type": "object",
"properties": {
"llm": {
"title": "LLM",
"default": "default",
"anyOf": [
{"enum": ["chat"], "type": "string"},
{"enum": ["default"], "type": "string"},
],
},
"llm_responses": {
"title": "LLM Responses",
"description": "A list of fake responses for this LLM",
"default": ["a"],
"type": "array",
"items": {"type": "string"},
},
"prompt_template": {
"title": "Prompt Template",
"description": "The prompt template for this chain",
"default": "Hello, {name}!",
"type": "string",
},
},
}
},
}
assert (
chain_configurable.with_config(configurable={"llm": "chat"}).invoke(
{"name": "John"}
)
== "b"
)
@pytest.mark.asyncio
async def test_with_config(mocker: MockerFixture) -> None:
fake = FakeRunnable()