diff --git a/libs/langchain/langchain/schema/runnable/__init__.py b/libs/langchain/langchain/schema/runnable/__init__.py index 90ef56fac14..65da79c489b 100644 --- a/libs/langchain/langchain/schema/runnable/__init__.py +++ b/libs/langchain/langchain/schema/runnable/__init__.py @@ -30,10 +30,16 @@ from langchain.schema.runnable.config import RunnableConfig, patch_config from langchain.schema.runnable.fallbacks import RunnableWithFallbacks from langchain.schema.runnable.passthrough import RunnablePassthrough from langchain.schema.runnable.router import RouterInput, RouterRunnable -from langchain.schema.runnable.utils import ConfigurableField +from langchain.schema.runnable.utils import ( + ConfigurableField, + ConfigurableFieldMultiOption, + ConfigurableFieldSingleOption, +) __all__ = [ "ConfigurableField", + "ConfigurableFieldSingleOption", + "ConfigurableFieldMultiOption", "GetLocalVar", "patch_config", "PutLocalVar", diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 8a429f62e68..612e19d2c70 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -58,6 +58,7 @@ from langchain.schema.runnable.config import ( ) from langchain.schema.runnable.utils import ( AddableDict, + AnyConfigurableField, ConfigurableField, ConfigurableFieldSpec, Input, @@ -975,7 +976,7 @@ class Runnable(Generic[Input, Output], ABC): class RunnableSerializable(Serializable, Runnable[Input, Output]): def configurable_fields( - self, **kwargs: ConfigurableField + self, **kwargs: AnyConfigurableField ) -> RunnableSerializable[Input, Output]: from langchain.schema.runnable.configurable import RunnableConfigurableFields diff --git a/libs/langchain/langchain/schema/runnable/configurable.py b/libs/langchain/langchain/schema/runnable/configurable.py index 7933455eaa7..7c246e6e01c 100644 --- a/libs/langchain/langchain/schema/runnable/configurable.py +++ b/libs/langchain/langchain/schema/runnable/configurable.py @@ -23,7 +23,10 @@ from langchain.schema.runnable.config import ( get_executor_for_config, ) from langchain.schema.runnable.utils import ( + AnyConfigurableField, ConfigurableField, + ConfigurableFieldMultiOption, + ConfigurableFieldSingleOption, ConfigurableFieldSpec, Input, Output, @@ -193,7 +196,7 @@ class DynamicRunnable(RunnableSerializable[Input, Output]): class RunnableConfigurableFields(DynamicRunnable[Input, Output]): - fields: Dict[str, ConfigurableField] + fields: Dict[str, AnyConfigurableField] @property def config_specs(self) -> Sequence[ConfigurableFieldSpec]: @@ -207,11 +210,15 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]): or self.default.__fields__[field_name].annotation, default=getattr(self.default, field_name), ) + if isinstance(spec, ConfigurableField) + else make_options_spec( + spec, self.default.__fields__[field_name].field_info.description + ) for field_name, spec in self.fields.items() ] def configurable_fields( - self, **kwargs: ConfigurableField + self, **kwargs: AnyConfigurableField ) -> RunnableSerializable[Input, Output]: return self.default.configurable_fields(**{**self.fields, **kwargs}) @@ -220,10 +227,28 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]): ) -> Runnable[Input, Output]: config = config or {} specs_by_id = {spec.id: (key, spec) for key, spec in self.fields.items()} - configurable = { + configurable_fields = { specs_by_id[k][0]: v for k, v in config.get("configurable", {}).items() - if k in specs_by_id + if k in specs_by_id and isinstance(specs_by_id[k][1], ConfigurableField) + } + configurable_single_options = { + k: v.options[(config.get("configurable", {}).get(v.id) or v.default)] + for k, v in self.fields.items() + if isinstance(v, ConfigurableFieldSingleOption) + } + configurable_multi_options = { + k: [ + v.options[o] + for o in config.get("configurable", {}).get(v.id, v.default) + ] + for k, v in self.fields.items() + if isinstance(v, ConfigurableFieldMultiOption) + } + configurable = { + **configurable_fields, + **configurable_single_options, + **configurable_multi_options, } if configurable: @@ -262,7 +287,7 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]): ] + [s for alt in self.alternatives.values() for s in alt.config_specs] def configurable_fields( - self, **kwargs: ConfigurableField + self, **kwargs: AnyConfigurableField ) -> RunnableSerializable[Input, Output]: return self.__class__( which=self.which, @@ -281,3 +306,29 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]): return self.alternatives[which] else: raise ValueError(f"Unknown alternative: {which}") + + +def make_options_spec( + spec: Union[ConfigurableFieldSingleOption, ConfigurableFieldMultiOption], + description: Optional[str], +) -> ConfigurableFieldSpec: + enum = StrEnum( # type: ignore[call-overload] + spec.name or spec.id, + ((v, v) for v in list(spec.options.keys())), + ) + if isinstance(spec, ConfigurableFieldSingleOption): + return ConfigurableFieldSpec( + id=spec.id, + name=spec.name, + description=spec.description or description, + annotation=enum, + default=spec.default, + ) + else: + return ConfigurableFieldSpec( + id=spec.id, + name=spec.name, + description=spec.description or description, + annotation=Sequence[enum], # type: ignore[valid-type] + default=spec.default, + ) diff --git a/libs/langchain/langchain/schema/runnable/utils.py b/libs/langchain/langchain/schema/runnable/utils.py index 8be6d0756e1..1d62b50dc89 100644 --- a/libs/langchain/langchain/schema/runnable/utils.py +++ b/libs/langchain/langchain/schema/runnable/utils.py @@ -14,6 +14,7 @@ from typing import ( Dict, Iterable, List, + Mapping, NamedTuple, Optional, Protocol, @@ -218,11 +219,35 @@ async def aadd(addables: AsyncIterable[Addable]) -> Optional[Addable]: class ConfigurableField(NamedTuple): id: str + name: Optional[str] = None description: Optional[str] = None annotation: Optional[Any] = None +class ConfigurableFieldSingleOption(NamedTuple): + id: str + options: Mapping[str, Any] + default: str + + name: Optional[str] = None + description: Optional[str] = None + + +class ConfigurableFieldMultiOption(NamedTuple): + id: str + options: Mapping[str, Any] + default: Sequence[str] + + name: Optional[str] = None + description: Optional[str] = None + + +AnyConfigurableField = Union[ + ConfigurableField, ConfigurableFieldSingleOption, ConfigurableFieldMultiOption +] + + class ConfigurableFieldSpec(NamedTuple): id: str name: Optional[str] diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index f9ebb98245a..45cc12ba955 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -59,7 +59,11 @@ from langchain.schema.runnable import ( RunnableWithFallbacks, ) from langchain.schema.runnable.base import ConfigurableField, RunnableGenerator -from langchain.schema.runnable.utils import add +from langchain.schema.runnable.utils import ( + ConfigurableFieldMultiOption, + ConfigurableFieldSingleOption, + add, +) from langchain.tools.base import BaseTool, tool from langchain.tools.json.tool import JsonListKeysTool, JsonSpec @@ -903,6 +907,18 @@ def test_configurable_fields() -> None: def test_configurable_fields_example() -> None: + fake_chat = FakeListChatModel(responses=["b"]).configurable_fields( + responses=ConfigurableFieldMultiOption( + id="chat_responses", + name="Chat Responses", + options={ + "hello": "A good morning to you!", + "bye": "See you later!", + "helpful": "How can I help you?", + }, + default=["hello", "bye"], + ) + ) fake_llm = ( FakeListLLM(responses=["a"]) .configurable_fields( @@ -914,15 +930,20 @@ def test_configurable_fields_example() -> None: ) .configurable_alternatives( ConfigurableField(id="llm", name="LLM"), - chat=FakeListChatModel(responses=["b"]) | StrOutputParser(), + chat=fake_chat | StrOutputParser(), ) ) prompt = PromptTemplate.from_template("Hello, {name}!").configurable_fields( - template=ConfigurableField( + template=ConfigurableFieldSingleOption( id="prompt_template", name="Prompt Template", description="The prompt template for this chain", + options={ + "hello": "Hello, {name}!", + "good_morning": "A very good morning to you, {name}!", + }, + default="hello", ) ) @@ -941,10 +962,28 @@ def test_configurable_fields_example() -> None: "enum": ["chat", "default"], "type": "string", }, + "Chat_Responses": { + "description": "An enumeration.", + "enum": ["hello", "bye", "helpful"], + "title": "Chat Responses", + "type": "string", + }, + "Prompt_Template": { + "description": "An enumeration.", + "enum": ["hello", "good_morning"], + "title": "Prompt Template", + "type": "string", + }, "Configurable": { "title": "Configurable", "type": "object", "properties": { + "chat_responses": { + "default": ["hello", "bye"], + "items": {"$ref": "#/definitions/Chat_Responses"}, + "title": "Chat Responses", + "type": "array", + }, "llm": { "title": "LLM", "default": "default", @@ -960,8 +999,8 @@ def test_configurable_fields_example() -> None: "prompt_template": { "title": "Prompt Template", "description": "The prompt template for this chain", - "default": "Hello, {name}!", - "type": "string", + "default": "hello", + "allOf": [{"$ref": "#/definitions/Prompt_Template"}], }, }, }, @@ -972,7 +1011,14 @@ def test_configurable_fields_example() -> None: chain_configurable.with_config(configurable={"llm": "chat"}).invoke( {"name": "John"} ) - == "b" + == "A good morning to you!" + ) + + assert ( + chain_configurable.with_config( + configurable={"llm": "chat", "chat_responses": ["helpful"]} + ).invoke({"name": "John"}) + == "How can I help you?" )