From 4df31910924c61683526ab5af288e6e44ef32dc9 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Mon, 2 Oct 2023 21:18:36 +0100 Subject: [PATCH] Add .configurable_fields() and .configurable_alternatives() to expose fields of a Runnable to be configured at runtime (#11282) --- .../langchain/schema/runnable/base.py | 72 ++++- .../langchain/schema/runnable/branch.py | 19 +- .../langchain/schema/runnable/config.py | 11 + .../langchain/schema/runnable/configurable.py | 276 ++++++++++++++++++ .../langchain/schema/runnable/fallbacks.py | 15 +- .../langchain/schema/runnable/passthrough.py | 7 +- .../langchain/schema/runnable/router.py | 13 +- .../langchain/schema/runnable/utils.py | 39 +++ .../schema/runnable/test_runnable.py | 267 ++++++++++++++++- 9 files changed, 713 insertions(+), 6 deletions(-) create mode 100644 libs/langchain/langchain/schema/runnable/configurable.py diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 0d3d38aa8e1..77d9b472ec6 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -57,6 +57,8 @@ from langchain.schema.runnable.config import ( ) from langchain.schema.runnable.utils import ( AddableDict, + ConfigurableField, + ConfigurableFieldSpec, Input, Output, accepts_config, @@ -64,6 +66,7 @@ from langchain.schema.runnable.utils import ( gather_with_concurrency, get_function_first_arg_dict_keys, get_lambda_source, + get_unique_config_specs, indent_lines_after_first, ) from langchain.utils.aiter import atee, py_anext @@ -122,6 +125,10 @@ class Runnable(Generic[Input, Output], ABC): self.__class__.__name__ + "Output", __root__=(root_type, None) ) + @property + def config_specs(self) -> Sequence[ConfigurableFieldSpec]: + return [] + def config_schema( self, *, include: Optional[Sequence[str]] = None ) -> Type[BaseModel]: @@ -129,10 +136,28 @@ class Runnable(Generic[Input, Output], ABC): arbitrary_types_allowed = True include = include or [] + config_specs = self.config_specs + configurable = ( + create_model( # type: ignore[call-overload] + "Configurable", + **{ + spec.id: ( + spec.annotation, + Field( + spec.default, title=spec.name, description=spec.description + ), + ) + for spec in config_specs + }, + ) + if config_specs + else None + ) return create_model( # type: ignore[call-overload] self.__class__.__name__ + "Config", __config__=_Config, + **({"configurable": (configurable, None)} if configurable else {}), **{ field_name: (field_type, None) for field_name, field_type in RunnableConfig.__annotations__.items() @@ -836,7 +861,32 @@ class Runnable(Generic[Input, Output], ABC): class RunnableSerializable(Serializable, Runnable[Input, Output]): - pass + def configurable_fields( + self, **kwargs: ConfigurableField + ) -> RunnableSerializable[Input, Output]: + from langchain.schema.runnable.configurable import RunnableConfigurableFields + + for key in kwargs: + if key not in self.__fields__: + raise ValueError( + f"Configuration key {key} not found in {self}: " + "available keys are {self.__fields__.keys()}" + ) + + return RunnableConfigurableFields(bound=self, fields=kwargs) + + def configurable_alternatives( + self, + which: ConfigurableField, + **kwargs: Runnable[Input, Output], + ) -> RunnableSerializable[Input, Output]: + from langchain.schema.runnable.configurable import ( + RunnableConfigurableAlternatives, + ) + + return RunnableConfigurableAlternatives( + which=which, bound=self, alternatives=kwargs + ) class RunnableSequence(RunnableSerializable[Input, Output]): @@ -879,6 +929,12 @@ class RunnableSequence(RunnableSerializable[Input, Output]): def output_schema(self) -> Type[BaseModel]: return self.last.output_schema + @property + def config_specs(self) -> Sequence[ConfigurableFieldSpec]: + return get_unique_config_specs( + spec for step in self.steps for spec in step.config_specs + ) + def __repr__(self) -> str: return "\n| ".join( repr(s) if i == 0 else indent_lines_after_first(repr(s), "| ") @@ -1388,6 +1444,12 @@ class RunnableMap(RunnableSerializable[Input, Dict[str, Any]]): **{k: (v.OutputType, None) for k, v in self.steps.items()}, ) + @property + def config_specs(self) -> Sequence[ConfigurableFieldSpec]: + return get_unique_config_specs( + spec for step in self.steps.values() for spec in step.config_specs + ) + def __repr__(self) -> str: map_for_repr = ",\n ".join( f"{k}: {indent_lines_after_first(repr(v), ' ' + k + ': ')}" @@ -1985,6 +2047,10 @@ class RunnableEach(RunnableSerializable[List[Input], List[Output]]): ), ) + @property + def config_specs(self) -> Sequence[ConfigurableFieldSpec]: + return self.bound.config_specs + def config_schema( self, *, include: Optional[Sequence[str]] = None ) -> Type[BaseModel]: @@ -2062,6 +2128,10 @@ class RunnableBinding(RunnableSerializable[Input, Output]): def output_schema(self) -> Type[BaseModel]: return self.bound.output_schema + @property + def config_specs(self) -> Sequence[ConfigurableFieldSpec]: + return self.bound.config_specs + def config_schema( self, *, include: Optional[Sequence[str]] = None ) -> Type[BaseModel]: diff --git a/libs/langchain/langchain/schema/runnable/branch.py b/libs/langchain/langchain/schema/runnable/branch.py index 105582f0c2f..18e7aa1767f 100644 --- a/libs/langchain/langchain/schema/runnable/branch.py +++ b/libs/langchain/langchain/schema/runnable/branch.py @@ -26,7 +26,12 @@ from langchain.schema.runnable.config import ( get_callback_manager_for_config, patch_config, ) -from langchain.schema.runnable.utils import Input, Output +from langchain.schema.runnable.utils import ( + ConfigurableFieldSpec, + Input, + Output, + get_unique_config_specs, +) class RunnableBranch(RunnableSerializable[Input, Output]): @@ -139,6 +144,18 @@ class RunnableBranch(RunnableSerializable[Input, Output]): return super().input_schema + @property + def config_specs(self) -> Sequence[ConfigurableFieldSpec]: + return get_unique_config_specs( + spec + for step in ( + [self.default] + + [r for _, r in self.branches] + + [r for r, _ in self.branches] + ) + for spec in step.config_specs + ) + def invoke( self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Output: diff --git a/libs/langchain/langchain/schema/runnable/config.py b/libs/langchain/langchain/schema/runnable/config.py index 06d979cff08..c7bf9ddc2bf 100644 --- a/libs/langchain/langchain/schema/runnable/config.py +++ b/libs/langchain/langchain/schema/runnable/config.py @@ -34,6 +34,10 @@ if TYPE_CHECKING: ) +class EmptyDict(TypedDict, total=False): + pass + + class RunnableConfig(TypedDict, total=False): """Configuration for a Runnable.""" @@ -78,6 +82,13 @@ class RunnableConfig(TypedDict, total=False): Maximum number of times a call can recurse. If not provided, defaults to 10. """ + configurable: Dict[str, Any] + """ + Runtime values for attributes previously made configurable by this Runnable, + or sub-Runnables, through .make_configurable(). Check .output_schema for + a description of the attributes that have been made configurable. + """ + def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig: empty = RunnableConfig( diff --git a/libs/langchain/langchain/schema/runnable/configurable.py b/libs/langchain/langchain/schema/runnable/configurable.py new file mode 100644 index 00000000000..1afdcc0ef9b --- /dev/null +++ b/libs/langchain/langchain/schema/runnable/configurable.py @@ -0,0 +1,276 @@ +from __future__ import annotations + +from abc import abstractmethod +from typing import ( + Any, + AsyncIterator, + Dict, + Iterator, + List, + Literal, + Optional, + Sequence, + Type, + Union, + cast, +) + +from langchain.pydantic_v1 import BaseModel +from langchain.schema.runnable.base import Runnable, RunnableSerializable +from langchain.schema.runnable.config import ( + RunnableConfig, + get_config_list, + get_executor_for_config, +) +from langchain.schema.runnable.utils import ( + ConfigurableField, + ConfigurableFieldSpec, + Input, + Output, + gather_with_concurrency, +) + + +class DynamicRunnable(RunnableSerializable[Input, Output]): + bound: RunnableSerializable[Input, Output] + + class Config: + arbitrary_types_allowed = True + + @classmethod + def is_lc_serializable(cls) -> bool: + return True + + @classmethod + def get_lc_namespace(cls) -> List[str]: + return cls.__module__.split(".")[:-1] + + @property + def InputType(self) -> Type[Input]: + return self.bound.InputType + + @property + def OutputType(self) -> Type[Output]: + return self.bound.OutputType + + @property + def input_schema(self) -> Type[BaseModel]: + return self.bound.input_schema + + @property + def output_schema(self) -> Type[BaseModel]: + return self.bound.output_schema + + @abstractmethod + def _prepare( + self, config: Optional[RunnableConfig] = None + ) -> Runnable[Input, Output]: + ... + + def invoke( + self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any + ) -> Output: + return self._prepare(config).invoke(input, config, **kwargs) + + async def ainvoke( + self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any + ) -> Output: + return await self._prepare(config).ainvoke(input, config, **kwargs) + + def batch( + self, + inputs: List[Input], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, + **kwargs: Optional[Any], + ) -> List[Output]: + configs = get_config_list(config, len(inputs)) + prepared = [self._prepare(c) for c in configs] + + if all(p is self.bound for p in prepared): + return self.bound.batch( + inputs, config, return_exceptions=return_exceptions, **kwargs + ) + + if not inputs: + return [] + + configs = get_config_list(config, len(inputs)) + + def invoke( + bound: Runnable[Input, Output], + input: Input, + config: RunnableConfig, + ) -> Union[Output, Exception]: + if return_exceptions: + try: + return bound.invoke(input, config, **kwargs) + except Exception as e: + return e + else: + return bound.invoke(input, config, **kwargs) + + # If there's only one input, don't bother with the executor + if len(inputs) == 1: + return cast(List[Output], [invoke(prepared[0], inputs[0], configs[0])]) + + with get_executor_for_config(configs[0]) as executor: + return cast( + List[Output], list(executor.map(invoke, prepared, inputs, configs)) + ) + + async def abatch( + self, + inputs: List[Input], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, + **kwargs: Optional[Any], + ) -> List[Output]: + configs = get_config_list(config, len(inputs)) + prepared = [self._prepare(c) for c in configs] + + if all(p is self.bound for p in prepared): + return await self.bound.abatch( + inputs, config, return_exceptions=return_exceptions, **kwargs + ) + + if not inputs: + return [] + + configs = get_config_list(config, len(inputs)) + + async def ainvoke( + bound: Runnable[Input, Output], + input: Input, + config: RunnableConfig, + ) -> Union[Output, Exception]: + if return_exceptions: + try: + return await bound.ainvoke(input, config, **kwargs) + except Exception as e: + return e + else: + return await bound.ainvoke(input, config, **kwargs) + + coros = map(ainvoke, prepared, inputs, configs) + return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros) + + def stream( + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], + ) -> Iterator[Output]: + return self._prepare(config).stream(input, config, **kwargs) + + async def astream( + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], + ) -> AsyncIterator[Output]: + async for chunk in self._prepare(config).astream(input, config, **kwargs): + yield chunk + + def transform( + self, + input: Iterator[Input], + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], + ) -> Iterator[Output]: + return self._prepare(config).transform(input, config, **kwargs) + + async def atransform( + self, + input: AsyncIterator[Input], + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], + ) -> AsyncIterator[Output]: + async for chunk in self._prepare(config).atransform(input, config, **kwargs): + yield chunk + + +class RunnableConfigurableFields(DynamicRunnable[Input, Output]): + fields: Dict[str, ConfigurableField] + + @property + def config_specs(self) -> Sequence[ConfigurableFieldSpec]: + return [ + ConfigurableFieldSpec( + id=spec.id, + name=spec.name, + description=spec.description + or self.bound.__fields__[field_name].field_info.description, + annotation=spec.annotation + or self.bound.__fields__[field_name].annotation, + default=getattr(self.bound, field_name), + ) + for field_name, spec in self.fields.items() + ] + + def configurable_fields( + self, **kwargs: ConfigurableField + ) -> RunnableSerializable[Input, Output]: + return self.bound.configurable_fields(**{**self.fields, **kwargs}) + + def _prepare( + self, config: Optional[RunnableConfig] = None + ) -> Runnable[Input, Output]: + config = config or {} + specs_by_id = {spec.id: (key, spec) for key, spec in self.fields.items()} + configurable = { + specs_by_id[k][0]: v + for k, v in config.get("configurable", {}).items() + if k in specs_by_id + } + + if configurable: + return self.bound.__class__(**{**self.bound.dict(), **configurable}) + else: + return self.bound + + +class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]): + which: ConfigurableField + + alternatives: Dict[str, RunnableSerializable[Input, Output]] + + @property + def config_specs(self) -> Sequence[ConfigurableFieldSpec]: + alt_keys = self.alternatives.keys() + which_keys = tuple(Literal[k] for k in alt_keys) + ( # type: ignore + Literal["default"], + ) + return [ + ConfigurableFieldSpec( + id=self.which.id, + name=self.which.name, + description=self.which.description, + annotation=Union[which_keys], # type: ignore + default="default", + ), + *self.bound.config_specs, + ] + [s for alt in self.alternatives.values() for s in alt.config_specs] + + def configurable_fields( + self, **kwargs: ConfigurableField + ) -> RunnableSerializable[Input, Output]: + return self.__class__( + which=self.which, + bound=self.bound.configurable_fields(**kwargs), + alternatives=self.alternatives, + ) + + def _prepare( + self, config: Optional[RunnableConfig] = None + ) -> Runnable[Input, Output]: + config = config or {} + which = config.get("configurable", {}).get(self.which.id) + if not which: + return self.bound + elif which in self.alternatives: + return self.alternatives[which] + else: + raise ValueError(f"Unknown alternative: {which}") diff --git a/libs/langchain/langchain/schema/runnable/fallbacks.py b/libs/langchain/langchain/schema/runnable/fallbacks.py index bba8d9a9e11..239800e1bb3 100644 --- a/libs/langchain/langchain/schema/runnable/fallbacks.py +++ b/libs/langchain/langchain/schema/runnable/fallbacks.py @@ -22,7 +22,12 @@ from langchain.schema.runnable.config import ( get_config_list, patch_config, ) -from langchain.schema.runnable.utils import Input, Output +from langchain.schema.runnable.utils import ( + ConfigurableFieldSpec, + Input, + Output, + get_unique_config_specs, +) if TYPE_CHECKING: from langchain.callbacks.manager import AsyncCallbackManagerForChainRun @@ -56,6 +61,14 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): def output_schema(self) -> Type[BaseModel]: return self.runnable.output_schema + @property + def config_specs(self) -> Sequence[ConfigurableFieldSpec]: + return get_unique_config_specs( + spec + for step in [self.runnable, *self.fallbacks] + for spec in step.config_specs + ) + def config_schema( self, *, include: Optional[Sequence[str]] = None ) -> Type[BaseModel]: diff --git a/libs/langchain/langchain/schema/runnable/passthrough.py b/libs/langchain/langchain/schema/runnable/passthrough.py index 1d1b046a572..874123c3940 100644 --- a/libs/langchain/langchain/schema/runnable/passthrough.py +++ b/libs/langchain/langchain/schema/runnable/passthrough.py @@ -11,6 +11,7 @@ from typing import ( List, Mapping, Optional, + Sequence, Type, Union, cast, @@ -24,7 +25,7 @@ from langchain.schema.runnable.base import ( RunnableSerializable, ) from langchain.schema.runnable.config import RunnableConfig, get_executor_for_config -from langchain.schema.runnable.utils import AddableDict +from langchain.schema.runnable.utils import AddableDict, ConfigurableFieldSpec from langchain.utils.aiter import atee, py_anext from langchain.utils.iter import safetee @@ -160,6 +161,10 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): return super().output_schema + @property + def config_specs(self) -> Sequence[ConfigurableFieldSpec]: + return self.mapper.config_specs + def invoke( self, input: Dict[str, Any], diff --git a/libs/langchain/langchain/schema/runnable/router.py b/libs/langchain/langchain/schema/runnable/router.py index 9638235fc87..3f44ad4279b 100644 --- a/libs/langchain/langchain/schema/runnable/router.py +++ b/libs/langchain/langchain/schema/runnable/router.py @@ -8,6 +8,7 @@ from typing import ( List, Mapping, Optional, + Sequence, Union, cast, ) @@ -26,7 +27,11 @@ from langchain.schema.runnable.config import ( get_config_list, get_executor_for_config, ) -from langchain.schema.runnable.utils import gather_with_concurrency +from langchain.schema.runnable.utils import ( + ConfigurableFieldSpec, + gather_with_concurrency, + get_unique_config_specs, +) class RouterInput(TypedDict): @@ -49,6 +54,12 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]): runnables: Mapping[str, Runnable[Any, Output]] + @property + def config_specs(self) -> Sequence[ConfigurableFieldSpec]: + return get_unique_config_specs( + spec for step in self.runnables.values() for spec in step.config_specs + ) + def __init__( self, runnables: Mapping[str, Union[Runnable[Any, Output], Callable[[Any], Output]]], diff --git a/libs/langchain/langchain/schema/runnable/utils.py b/libs/langchain/langchain/schema/runnable/utils.py index 37403f8ea31..8be6d0756e1 100644 --- a/libs/langchain/langchain/schema/runnable/utils.py +++ b/libs/langchain/langchain/schema/runnable/utils.py @@ -5,6 +5,7 @@ import asyncio import inspect import textwrap from inspect import signature +from itertools import groupby from typing import ( Any, AsyncIterable, @@ -13,8 +14,10 @@ from typing import ( Dict, Iterable, List, + NamedTuple, Optional, Protocol, + Sequence, Set, TypeVar, Union, @@ -211,3 +214,39 @@ async def aadd(addables: AsyncIterable[Addable]) -> Optional[Addable]: else: final = final + chunk return final + + +class ConfigurableField(NamedTuple): + id: str + name: Optional[str] = None + description: Optional[str] = None + annotation: Optional[Any] = None + + +class ConfigurableFieldSpec(NamedTuple): + id: str + name: Optional[str] + description: Optional[str] + + default: Any + annotation: Any + + +def get_unique_config_specs( + specs: Iterable[ConfigurableFieldSpec], +) -> Sequence[ConfigurableFieldSpec]: + grouped = groupby(sorted(specs, key=lambda s: s.id), lambda s: s.id) + unique: List[ConfigurableFieldSpec] = [] + for id, dupes in grouped: + first = next(dupes) + others = list(dupes) + if len(others) == 0: + unique.append(first) + elif all(o == first for o in others): + unique.append(first) + else: + raise ValueError( + "RunnableSequence contains conflicting config specs" + f"for {id}: {[first] + others}" + ) + return unique diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index 2031ec21cba..c389b426482 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -30,6 +30,7 @@ from langchain.llms.fake import FakeListLLM, FakeStreamingListLLM from langchain.load.dump import dumpd, dumps from langchain.output_parsers.list import CommaSeparatedListOutputParser from langchain.prompts import PromptTemplate +from langchain.prompts.base import StringPromptValue from langchain.prompts.chat import ( ChatPromptTemplate, ChatPromptValue, @@ -56,7 +57,7 @@ from langchain.schema.runnable import ( RunnableSequence, RunnableWithFallbacks, ) -from langchain.schema.runnable.base import RunnableGenerator +from langchain.schema.runnable.base import ConfigurableField, RunnableGenerator from langchain.schema.runnable.utils import add from langchain.tools.base import BaseTool, tool from langchain.tools.json.tool import JsonListKeysTool, JsonSpec @@ -143,6 +144,15 @@ def test_schemas(snapshot: SnapshotAssertion) -> None: "title": "FakeRunnableOutput", "type": "integer", } + assert fake.config_schema(include=["tags", "metadata", "run_name"]).schema() == { + "title": "FakeRunnableConfig", + "type": "object", + "properties": { + "metadata": {"title": "Metadata", "type": "object"}, + "run_name": {"title": "Run Name", "type": "string"}, + "tags": {"items": {"type": "string"}, "title": "Tags", "type": "array"}, + }, + } fake_bound = FakeRunnable().bind(a="b") # str -> int @@ -538,6 +548,261 @@ def test_schema_chains() -> None: } +def test_configurable_fields() -> None: + fake_llm = FakeListLLM(responses=["a"]) # str -> List[List[str]] + + assert fake_llm.invoke("...") == "a" + + fake_llm_configurable = fake_llm.configurable_fields( + responses=ConfigurableField( + id="llm_responses", + name="LLM Responses", + description="A list of fake responses for this LLM", + ) + ) + + assert fake_llm_configurable.invoke("...") == "a" + + assert fake_llm_configurable.config_schema().schema() == { + "title": "RunnableConfigurableFieldsConfig", + "type": "object", + "properties": {"configurable": {"$ref": "#/definitions/Configurable"}}, + "definitions": { + "Configurable": { + "title": "Configurable", + "type": "object", + "properties": { + "llm_responses": { + "title": "LLM Responses", + "description": "A list of fake responses for this LLM", + "default": ["a"], + "type": "array", + "items": {"type": "string"}, + } + }, + } + }, + } + + fake_llm_configured = fake_llm_configurable.with_config( + configurable={"llm_responses": ["b"]} + ) + + assert fake_llm_configured.invoke("...") == "b" + + prompt = PromptTemplate.from_template("Hello, {name}!") + + assert prompt.invoke({"name": "John"}) == StringPromptValue(text="Hello, John!") + + prompt_configurable = prompt.configurable_fields( + template=ConfigurableField( + id="prompt_template", + name="Prompt Template", + description="The prompt template for this chain", + ) + ) + + assert prompt_configurable.invoke({"name": "John"}) == StringPromptValue( + text="Hello, John!" + ) + + assert prompt_configurable.config_schema().schema() == { + "title": "RunnableConfigurableFieldsConfig", + "type": "object", + "properties": {"configurable": {"$ref": "#/definitions/Configurable"}}, + "definitions": { + "Configurable": { + "title": "Configurable", + "type": "object", + "properties": { + "prompt_template": { + "title": "Prompt Template", + "description": "The prompt template for this chain", + "default": "Hello, {name}!", + "type": "string", + } + }, + } + }, + } + + prompt_configured = prompt_configurable.with_config( + configurable={"prompt_template": "Hello, {name}! {name}!"} + ) + + assert prompt_configured.invoke({"name": "John"}) == StringPromptValue( + text="Hello, John! John!" + ) + + chain_configurable = prompt_configurable | fake_llm_configurable | StrOutputParser() + + assert chain_configurable.invoke({"name": "John"}) == "a" + + assert chain_configurable.config_schema().schema() == { + "title": "RunnableSequenceConfig", + "type": "object", + "properties": {"configurable": {"$ref": "#/definitions/Configurable"}}, + "definitions": { + "Configurable": { + "title": "Configurable", + "type": "object", + "properties": { + "llm_responses": { + "title": "LLM Responses", + "description": "A list of fake responses for this LLM", + "default": ["a"], + "type": "array", + "items": {"type": "string"}, + }, + "prompt_template": { + "title": "Prompt Template", + "description": "The prompt template for this chain", + "default": "Hello, {name}!", + "type": "string", + }, + }, + } + }, + } + + assert ( + chain_configurable.with_config( + configurable={ + "prompt_template": "A very good morning to you, {name}!", + "llm_responses": ["c"], + } + ).invoke({"name": "John"}) + == "c" + ) + + chain_with_map_configurable: Runnable = prompt_configurable | { + "llm1": fake_llm_configurable | StrOutputParser(), + "llm2": fake_llm_configurable | StrOutputParser(), + "llm3": fake_llm.configurable_fields( + responses=ConfigurableField("other_responses") + ) + | StrOutputParser(), + } + + assert chain_with_map_configurable.invoke({"name": "John"}) == { + "llm1": "a", + "llm2": "a", + "llm3": "a", + } + + assert chain_with_map_configurable.config_schema().schema() == { + "title": "RunnableSequenceConfig", + "type": "object", + "properties": {"configurable": {"$ref": "#/definitions/Configurable"}}, + "definitions": { + "Configurable": { + "title": "Configurable", + "type": "object", + "properties": { + "llm_responses": { + "title": "LLM Responses", + "description": "A list of fake responses for this LLM", + "default": ["a"], + "type": "array", + "items": {"type": "string"}, + }, + "other_responses": { + "title": "Other Responses", + "default": ["a"], + "type": "array", + "items": {"type": "string"}, + }, + "prompt_template": { + "title": "Prompt Template", + "description": "The prompt template for this chain", + "default": "Hello, {name}!", + "type": "string", + }, + }, + } + }, + } + + assert chain_with_map_configurable.with_config( + configurable={ + "prompt_template": "A very good morning to you, {name}!", + "llm_responses": ["c"], + "other_responses": ["d"], + } + ).invoke({"name": "John"}) == {"llm1": "c", "llm2": "c", "llm3": "d"} + + +def test_configurable_fields_example() -> None: + fake_llm = ( + FakeListLLM(responses=["a"]) + .configurable_fields( + responses=ConfigurableField( + id="llm_responses", + name="LLM Responses", + description="A list of fake responses for this LLM", + ) + ) + .configurable_alternatives( + ConfigurableField(id="llm", name="LLM"), + chat=FakeListChatModel(responses=["b"]) | StrOutputParser(), + ) + ) + + prompt = PromptTemplate.from_template("Hello, {name}!").configurable_fields( + template=ConfigurableField( + id="prompt_template", + name="Prompt Template", + description="The prompt template for this chain", + ) + ) + + chain_configurable = prompt | fake_llm + + assert chain_configurable.invoke({"name": "John"}) == "a" + + assert chain_configurable.config_schema().schema() == { + "title": "RunnableSequenceConfig", + "type": "object", + "properties": {"configurable": {"$ref": "#/definitions/Configurable"}}, + "definitions": { + "Configurable": { + "title": "Configurable", + "type": "object", + "properties": { + "llm": { + "title": "LLM", + "default": "default", + "anyOf": [ + {"enum": ["chat"], "type": "string"}, + {"enum": ["default"], "type": "string"}, + ], + }, + "llm_responses": { + "title": "LLM Responses", + "description": "A list of fake responses for this LLM", + "default": ["a"], + "type": "array", + "items": {"type": "string"}, + }, + "prompt_template": { + "title": "Prompt Template", + "description": "The prompt template for this chain", + "default": "Hello, {name}!", + "type": "string", + }, + }, + } + }, + } + + assert ( + chain_configurable.with_config(configurable={"llm": "chat"}).invoke( + {"name": "John"} + ) + == "b" + ) + + @pytest.mark.asyncio async def test_with_config(mocker: MockerFixture) -> None: fake = FakeRunnable()