Add configurable fields with options (#11601)

<!-- Thank you for contributing to LangChain!

Replace this entire comment with:
  - **Description:** a description of the change, 
  - **Issue:** the issue # it fixes (if applicable),
  - **Dependencies:** any dependencies required for this change,
- **Tag maintainer:** for a quicker response, tag the relevant
maintainer (see below),
- **Twitter handle:** we announce bigger features on Twitter. If your PR
gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc:

https://github.com/langchain-ai/langchain/blob/master/.github/CONTRIBUTING.md

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in `docs/extras`
directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->
This commit is contained in:
Nuno Campos 2023-10-10 22:17:22 +01:00 committed by GitHub
parent 0ca8d4449c
commit 9a0ed75a95
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 142 additions and 13 deletions

View File

@ -30,10 +30,16 @@ from langchain.schema.runnable.config import RunnableConfig, patch_config
from langchain.schema.runnable.fallbacks import RunnableWithFallbacks from langchain.schema.runnable.fallbacks import RunnableWithFallbacks
from langchain.schema.runnable.passthrough import RunnablePassthrough from langchain.schema.runnable.passthrough import RunnablePassthrough
from langchain.schema.runnable.router import RouterInput, RouterRunnable from langchain.schema.runnable.router import RouterInput, RouterRunnable
from langchain.schema.runnable.utils import ConfigurableField from langchain.schema.runnable.utils import (
ConfigurableField,
ConfigurableFieldMultiOption,
ConfigurableFieldSingleOption,
)
__all__ = [ __all__ = [
"ConfigurableField", "ConfigurableField",
"ConfigurableFieldSingleOption",
"ConfigurableFieldMultiOption",
"GetLocalVar", "GetLocalVar",
"patch_config", "patch_config",
"PutLocalVar", "PutLocalVar",

View File

@ -58,6 +58,7 @@ from langchain.schema.runnable.config import (
) )
from langchain.schema.runnable.utils import ( from langchain.schema.runnable.utils import (
AddableDict, AddableDict,
AnyConfigurableField,
ConfigurableField, ConfigurableField,
ConfigurableFieldSpec, ConfigurableFieldSpec,
Input, Input,
@ -975,7 +976,7 @@ class Runnable(Generic[Input, Output], ABC):
class RunnableSerializable(Serializable, Runnable[Input, Output]): class RunnableSerializable(Serializable, Runnable[Input, Output]):
def configurable_fields( def configurable_fields(
self, **kwargs: ConfigurableField self, **kwargs: AnyConfigurableField
) -> RunnableSerializable[Input, Output]: ) -> RunnableSerializable[Input, Output]:
from langchain.schema.runnable.configurable import RunnableConfigurableFields from langchain.schema.runnable.configurable import RunnableConfigurableFields

View File

@ -23,7 +23,10 @@ from langchain.schema.runnable.config import (
get_executor_for_config, get_executor_for_config,
) )
from langchain.schema.runnable.utils import ( from langchain.schema.runnable.utils import (
AnyConfigurableField,
ConfigurableField, ConfigurableField,
ConfigurableFieldMultiOption,
ConfigurableFieldSingleOption,
ConfigurableFieldSpec, ConfigurableFieldSpec,
Input, Input,
Output, Output,
@ -193,7 +196,7 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
class RunnableConfigurableFields(DynamicRunnable[Input, Output]): class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
fields: Dict[str, ConfigurableField] fields: Dict[str, AnyConfigurableField]
@property @property
def config_specs(self) -> Sequence[ConfigurableFieldSpec]: def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
@ -207,11 +210,15 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
or self.default.__fields__[field_name].annotation, or self.default.__fields__[field_name].annotation,
default=getattr(self.default, field_name), default=getattr(self.default, field_name),
) )
if isinstance(spec, ConfigurableField)
else make_options_spec(
spec, self.default.__fields__[field_name].field_info.description
)
for field_name, spec in self.fields.items() for field_name, spec in self.fields.items()
] ]
def configurable_fields( def configurable_fields(
self, **kwargs: ConfigurableField self, **kwargs: AnyConfigurableField
) -> RunnableSerializable[Input, Output]: ) -> RunnableSerializable[Input, Output]:
return self.default.configurable_fields(**{**self.fields, **kwargs}) return self.default.configurable_fields(**{**self.fields, **kwargs})
@ -220,10 +227,28 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
) -> Runnable[Input, Output]: ) -> Runnable[Input, Output]:
config = config or {} config = config or {}
specs_by_id = {spec.id: (key, spec) for key, spec in self.fields.items()} specs_by_id = {spec.id: (key, spec) for key, spec in self.fields.items()}
configurable = { configurable_fields = {
specs_by_id[k][0]: v specs_by_id[k][0]: v
for k, v in config.get("configurable", {}).items() for k, v in config.get("configurable", {}).items()
if k in specs_by_id if k in specs_by_id and isinstance(specs_by_id[k][1], ConfigurableField)
}
configurable_single_options = {
k: v.options[(config.get("configurable", {}).get(v.id) or v.default)]
for k, v in self.fields.items()
if isinstance(v, ConfigurableFieldSingleOption)
}
configurable_multi_options = {
k: [
v.options[o]
for o in config.get("configurable", {}).get(v.id, v.default)
]
for k, v in self.fields.items()
if isinstance(v, ConfigurableFieldMultiOption)
}
configurable = {
**configurable_fields,
**configurable_single_options,
**configurable_multi_options,
} }
if configurable: if configurable:
@ -262,7 +287,7 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
] + [s for alt in self.alternatives.values() for s in alt.config_specs] ] + [s for alt in self.alternatives.values() for s in alt.config_specs]
def configurable_fields( def configurable_fields(
self, **kwargs: ConfigurableField self, **kwargs: AnyConfigurableField
) -> RunnableSerializable[Input, Output]: ) -> RunnableSerializable[Input, Output]:
return self.__class__( return self.__class__(
which=self.which, which=self.which,
@ -281,3 +306,29 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
return self.alternatives[which] return self.alternatives[which]
else: else:
raise ValueError(f"Unknown alternative: {which}") raise ValueError(f"Unknown alternative: {which}")
def make_options_spec(
spec: Union[ConfigurableFieldSingleOption, ConfigurableFieldMultiOption],
description: Optional[str],
) -> ConfigurableFieldSpec:
enum = StrEnum( # type: ignore[call-overload]
spec.name or spec.id,
((v, v) for v in list(spec.options.keys())),
)
if isinstance(spec, ConfigurableFieldSingleOption):
return ConfigurableFieldSpec(
id=spec.id,
name=spec.name,
description=spec.description or description,
annotation=enum,
default=spec.default,
)
else:
return ConfigurableFieldSpec(
id=spec.id,
name=spec.name,
description=spec.description or description,
annotation=Sequence[enum], # type: ignore[valid-type]
default=spec.default,
)

View File

@ -14,6 +14,7 @@ from typing import (
Dict, Dict,
Iterable, Iterable,
List, List,
Mapping,
NamedTuple, NamedTuple,
Optional, Optional,
Protocol, Protocol,
@ -218,11 +219,35 @@ async def aadd(addables: AsyncIterable[Addable]) -> Optional[Addable]:
class ConfigurableField(NamedTuple): class ConfigurableField(NamedTuple):
id: str id: str
name: Optional[str] = None name: Optional[str] = None
description: Optional[str] = None description: Optional[str] = None
annotation: Optional[Any] = None annotation: Optional[Any] = None
class ConfigurableFieldSingleOption(NamedTuple):
id: str
options: Mapping[str, Any]
default: str
name: Optional[str] = None
description: Optional[str] = None
class ConfigurableFieldMultiOption(NamedTuple):
id: str
options: Mapping[str, Any]
default: Sequence[str]
name: Optional[str] = None
description: Optional[str] = None
AnyConfigurableField = Union[
ConfigurableField, ConfigurableFieldSingleOption, ConfigurableFieldMultiOption
]
class ConfigurableFieldSpec(NamedTuple): class ConfigurableFieldSpec(NamedTuple):
id: str id: str
name: Optional[str] name: Optional[str]

View File

@ -59,7 +59,11 @@ from langchain.schema.runnable import (
RunnableWithFallbacks, RunnableWithFallbacks,
) )
from langchain.schema.runnable.base import ConfigurableField, RunnableGenerator from langchain.schema.runnable.base import ConfigurableField, RunnableGenerator
from langchain.schema.runnable.utils import add from langchain.schema.runnable.utils import (
ConfigurableFieldMultiOption,
ConfigurableFieldSingleOption,
add,
)
from langchain.tools.base import BaseTool, tool from langchain.tools.base import BaseTool, tool
from langchain.tools.json.tool import JsonListKeysTool, JsonSpec from langchain.tools.json.tool import JsonListKeysTool, JsonSpec
@ -903,6 +907,18 @@ def test_configurable_fields() -> None:
def test_configurable_fields_example() -> None: def test_configurable_fields_example() -> None:
fake_chat = FakeListChatModel(responses=["b"]).configurable_fields(
responses=ConfigurableFieldMultiOption(
id="chat_responses",
name="Chat Responses",
options={
"hello": "A good morning to you!",
"bye": "See you later!",
"helpful": "How can I help you?",
},
default=["hello", "bye"],
)
)
fake_llm = ( fake_llm = (
FakeListLLM(responses=["a"]) FakeListLLM(responses=["a"])
.configurable_fields( .configurable_fields(
@ -914,15 +930,20 @@ def test_configurable_fields_example() -> None:
) )
.configurable_alternatives( .configurable_alternatives(
ConfigurableField(id="llm", name="LLM"), ConfigurableField(id="llm", name="LLM"),
chat=FakeListChatModel(responses=["b"]) | StrOutputParser(), chat=fake_chat | StrOutputParser(),
) )
) )
prompt = PromptTemplate.from_template("Hello, {name}!").configurable_fields( prompt = PromptTemplate.from_template("Hello, {name}!").configurable_fields(
template=ConfigurableField( template=ConfigurableFieldSingleOption(
id="prompt_template", id="prompt_template",
name="Prompt Template", name="Prompt Template",
description="The prompt template for this chain", description="The prompt template for this chain",
options={
"hello": "Hello, {name}!",
"good_morning": "A very good morning to you, {name}!",
},
default="hello",
) )
) )
@ -941,10 +962,28 @@ def test_configurable_fields_example() -> None:
"enum": ["chat", "default"], "enum": ["chat", "default"],
"type": "string", "type": "string",
}, },
"Chat_Responses": {
"description": "An enumeration.",
"enum": ["hello", "bye", "helpful"],
"title": "Chat Responses",
"type": "string",
},
"Prompt_Template": {
"description": "An enumeration.",
"enum": ["hello", "good_morning"],
"title": "Prompt Template",
"type": "string",
},
"Configurable": { "Configurable": {
"title": "Configurable", "title": "Configurable",
"type": "object", "type": "object",
"properties": { "properties": {
"chat_responses": {
"default": ["hello", "bye"],
"items": {"$ref": "#/definitions/Chat_Responses"},
"title": "Chat Responses",
"type": "array",
},
"llm": { "llm": {
"title": "LLM", "title": "LLM",
"default": "default", "default": "default",
@ -960,8 +999,8 @@ def test_configurable_fields_example() -> None:
"prompt_template": { "prompt_template": {
"title": "Prompt Template", "title": "Prompt Template",
"description": "The prompt template for this chain", "description": "The prompt template for this chain",
"default": "Hello, {name}!", "default": "hello",
"type": "string", "allOf": [{"$ref": "#/definitions/Prompt_Template"}],
}, },
}, },
}, },
@ -972,7 +1011,14 @@ def test_configurable_fields_example() -> None:
chain_configurable.with_config(configurable={"llm": "chat"}).invoke( chain_configurable.with_config(configurable={"llm": "chat"}).invoke(
{"name": "John"} {"name": "John"}
) )
== "b" == "A good morning to you!"
)
assert (
chain_configurable.with_config(
configurable={"llm": "chat", "chat_responses": ["helpful"]}
).invoke({"name": "John"})
== "How can I help you?"
) )