mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-31 12:09:58 +00:00
Add configurable fields with options (#11601)
<!-- Thank you for contributing to LangChain! Replace this entire comment with: - **Description:** a description of the change, - **Issue:** the issue # it fixes (if applicable), - **Dependencies:** any dependencies required for this change, - **Tag maintainer:** for a quicker response, tag the relevant maintainer (see below), - **Twitter handle:** we announce bigger features on Twitter. If your PR gets announced, and you'd like a mention, we'll gladly shout you out! Please make sure your PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` to check this locally. See contribution guidelines for more information on how to write/run tests, lint, etc: https://github.com/langchain-ai/langchain/blob/master/.github/CONTRIBUTING.md If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/extras` directory. If no one reviews your PR within a few days, please @-mention one of @baskaryan, @eyurtsev, @hwchase17. -->
This commit is contained in:
parent
0ca8d4449c
commit
9a0ed75a95
@ -30,10 +30,16 @@ from langchain.schema.runnable.config import RunnableConfig, patch_config
|
||||
from langchain.schema.runnable.fallbacks import RunnableWithFallbacks
|
||||
from langchain.schema.runnable.passthrough import RunnablePassthrough
|
||||
from langchain.schema.runnable.router import RouterInput, RouterRunnable
|
||||
from langchain.schema.runnable.utils import ConfigurableField
|
||||
from langchain.schema.runnable.utils import (
|
||||
ConfigurableField,
|
||||
ConfigurableFieldMultiOption,
|
||||
ConfigurableFieldSingleOption,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ConfigurableField",
|
||||
"ConfigurableFieldSingleOption",
|
||||
"ConfigurableFieldMultiOption",
|
||||
"GetLocalVar",
|
||||
"patch_config",
|
||||
"PutLocalVar",
|
||||
|
@ -58,6 +58,7 @@ from langchain.schema.runnable.config import (
|
||||
)
|
||||
from langchain.schema.runnable.utils import (
|
||||
AddableDict,
|
||||
AnyConfigurableField,
|
||||
ConfigurableField,
|
||||
ConfigurableFieldSpec,
|
||||
Input,
|
||||
@ -975,7 +976,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
|
||||
class RunnableSerializable(Serializable, Runnable[Input, Output]):
|
||||
def configurable_fields(
|
||||
self, **kwargs: ConfigurableField
|
||||
self, **kwargs: AnyConfigurableField
|
||||
) -> RunnableSerializable[Input, Output]:
|
||||
from langchain.schema.runnable.configurable import RunnableConfigurableFields
|
||||
|
||||
|
@ -23,7 +23,10 @@ from langchain.schema.runnable.config import (
|
||||
get_executor_for_config,
|
||||
)
|
||||
from langchain.schema.runnable.utils import (
|
||||
AnyConfigurableField,
|
||||
ConfigurableField,
|
||||
ConfigurableFieldMultiOption,
|
||||
ConfigurableFieldSingleOption,
|
||||
ConfigurableFieldSpec,
|
||||
Input,
|
||||
Output,
|
||||
@ -193,7 +196,7 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
|
||||
|
||||
|
||||
class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
|
||||
fields: Dict[str, ConfigurableField]
|
||||
fields: Dict[str, AnyConfigurableField]
|
||||
|
||||
@property
|
||||
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
||||
@ -207,11 +210,15 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
|
||||
or self.default.__fields__[field_name].annotation,
|
||||
default=getattr(self.default, field_name),
|
||||
)
|
||||
if isinstance(spec, ConfigurableField)
|
||||
else make_options_spec(
|
||||
spec, self.default.__fields__[field_name].field_info.description
|
||||
)
|
||||
for field_name, spec in self.fields.items()
|
||||
]
|
||||
|
||||
def configurable_fields(
|
||||
self, **kwargs: ConfigurableField
|
||||
self, **kwargs: AnyConfigurableField
|
||||
) -> RunnableSerializable[Input, Output]:
|
||||
return self.default.configurable_fields(**{**self.fields, **kwargs})
|
||||
|
||||
@ -220,10 +227,28 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
|
||||
) -> Runnable[Input, Output]:
|
||||
config = config or {}
|
||||
specs_by_id = {spec.id: (key, spec) for key, spec in self.fields.items()}
|
||||
configurable = {
|
||||
configurable_fields = {
|
||||
specs_by_id[k][0]: v
|
||||
for k, v in config.get("configurable", {}).items()
|
||||
if k in specs_by_id
|
||||
if k in specs_by_id and isinstance(specs_by_id[k][1], ConfigurableField)
|
||||
}
|
||||
configurable_single_options = {
|
||||
k: v.options[(config.get("configurable", {}).get(v.id) or v.default)]
|
||||
for k, v in self.fields.items()
|
||||
if isinstance(v, ConfigurableFieldSingleOption)
|
||||
}
|
||||
configurable_multi_options = {
|
||||
k: [
|
||||
v.options[o]
|
||||
for o in config.get("configurable", {}).get(v.id, v.default)
|
||||
]
|
||||
for k, v in self.fields.items()
|
||||
if isinstance(v, ConfigurableFieldMultiOption)
|
||||
}
|
||||
configurable = {
|
||||
**configurable_fields,
|
||||
**configurable_single_options,
|
||||
**configurable_multi_options,
|
||||
}
|
||||
|
||||
if configurable:
|
||||
@ -262,7 +287,7 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
|
||||
] + [s for alt in self.alternatives.values() for s in alt.config_specs]
|
||||
|
||||
def configurable_fields(
|
||||
self, **kwargs: ConfigurableField
|
||||
self, **kwargs: AnyConfigurableField
|
||||
) -> RunnableSerializable[Input, Output]:
|
||||
return self.__class__(
|
||||
which=self.which,
|
||||
@ -281,3 +306,29 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
|
||||
return self.alternatives[which]
|
||||
else:
|
||||
raise ValueError(f"Unknown alternative: {which}")
|
||||
|
||||
|
||||
def make_options_spec(
|
||||
spec: Union[ConfigurableFieldSingleOption, ConfigurableFieldMultiOption],
|
||||
description: Optional[str],
|
||||
) -> ConfigurableFieldSpec:
|
||||
enum = StrEnum( # type: ignore[call-overload]
|
||||
spec.name or spec.id,
|
||||
((v, v) for v in list(spec.options.keys())),
|
||||
)
|
||||
if isinstance(spec, ConfigurableFieldSingleOption):
|
||||
return ConfigurableFieldSpec(
|
||||
id=spec.id,
|
||||
name=spec.name,
|
||||
description=spec.description or description,
|
||||
annotation=enum,
|
||||
default=spec.default,
|
||||
)
|
||||
else:
|
||||
return ConfigurableFieldSpec(
|
||||
id=spec.id,
|
||||
name=spec.name,
|
||||
description=spec.description or description,
|
||||
annotation=Sequence[enum], # type: ignore[valid-type]
|
||||
default=spec.default,
|
||||
)
|
||||
|
@ -14,6 +14,7 @@ from typing import (
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Mapping,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Protocol,
|
||||
@ -218,11 +219,35 @@ async def aadd(addables: AsyncIterable[Addable]) -> Optional[Addable]:
|
||||
|
||||
class ConfigurableField(NamedTuple):
|
||||
id: str
|
||||
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
annotation: Optional[Any] = None
|
||||
|
||||
|
||||
class ConfigurableFieldSingleOption(NamedTuple):
|
||||
id: str
|
||||
options: Mapping[str, Any]
|
||||
default: str
|
||||
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
class ConfigurableFieldMultiOption(NamedTuple):
|
||||
id: str
|
||||
options: Mapping[str, Any]
|
||||
default: Sequence[str]
|
||||
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
AnyConfigurableField = Union[
|
||||
ConfigurableField, ConfigurableFieldSingleOption, ConfigurableFieldMultiOption
|
||||
]
|
||||
|
||||
|
||||
class ConfigurableFieldSpec(NamedTuple):
|
||||
id: str
|
||||
name: Optional[str]
|
||||
|
@ -59,7 +59,11 @@ from langchain.schema.runnable import (
|
||||
RunnableWithFallbacks,
|
||||
)
|
||||
from langchain.schema.runnable.base import ConfigurableField, RunnableGenerator
|
||||
from langchain.schema.runnable.utils import add
|
||||
from langchain.schema.runnable.utils import (
|
||||
ConfigurableFieldMultiOption,
|
||||
ConfigurableFieldSingleOption,
|
||||
add,
|
||||
)
|
||||
from langchain.tools.base import BaseTool, tool
|
||||
from langchain.tools.json.tool import JsonListKeysTool, JsonSpec
|
||||
|
||||
@ -903,6 +907,18 @@ def test_configurable_fields() -> None:
|
||||
|
||||
|
||||
def test_configurable_fields_example() -> None:
|
||||
fake_chat = FakeListChatModel(responses=["b"]).configurable_fields(
|
||||
responses=ConfigurableFieldMultiOption(
|
||||
id="chat_responses",
|
||||
name="Chat Responses",
|
||||
options={
|
||||
"hello": "A good morning to you!",
|
||||
"bye": "See you later!",
|
||||
"helpful": "How can I help you?",
|
||||
},
|
||||
default=["hello", "bye"],
|
||||
)
|
||||
)
|
||||
fake_llm = (
|
||||
FakeListLLM(responses=["a"])
|
||||
.configurable_fields(
|
||||
@ -914,15 +930,20 @@ def test_configurable_fields_example() -> None:
|
||||
)
|
||||
.configurable_alternatives(
|
||||
ConfigurableField(id="llm", name="LLM"),
|
||||
chat=FakeListChatModel(responses=["b"]) | StrOutputParser(),
|
||||
chat=fake_chat | StrOutputParser(),
|
||||
)
|
||||
)
|
||||
|
||||
prompt = PromptTemplate.from_template("Hello, {name}!").configurable_fields(
|
||||
template=ConfigurableField(
|
||||
template=ConfigurableFieldSingleOption(
|
||||
id="prompt_template",
|
||||
name="Prompt Template",
|
||||
description="The prompt template for this chain",
|
||||
options={
|
||||
"hello": "Hello, {name}!",
|
||||
"good_morning": "A very good morning to you, {name}!",
|
||||
},
|
||||
default="hello",
|
||||
)
|
||||
)
|
||||
|
||||
@ -941,10 +962,28 @@ def test_configurable_fields_example() -> None:
|
||||
"enum": ["chat", "default"],
|
||||
"type": "string",
|
||||
},
|
||||
"Chat_Responses": {
|
||||
"description": "An enumeration.",
|
||||
"enum": ["hello", "bye", "helpful"],
|
||||
"title": "Chat Responses",
|
||||
"type": "string",
|
||||
},
|
||||
"Prompt_Template": {
|
||||
"description": "An enumeration.",
|
||||
"enum": ["hello", "good_morning"],
|
||||
"title": "Prompt Template",
|
||||
"type": "string",
|
||||
},
|
||||
"Configurable": {
|
||||
"title": "Configurable",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"chat_responses": {
|
||||
"default": ["hello", "bye"],
|
||||
"items": {"$ref": "#/definitions/Chat_Responses"},
|
||||
"title": "Chat Responses",
|
||||
"type": "array",
|
||||
},
|
||||
"llm": {
|
||||
"title": "LLM",
|
||||
"default": "default",
|
||||
@ -960,8 +999,8 @@ def test_configurable_fields_example() -> None:
|
||||
"prompt_template": {
|
||||
"title": "Prompt Template",
|
||||
"description": "The prompt template for this chain",
|
||||
"default": "Hello, {name}!",
|
||||
"type": "string",
|
||||
"default": "hello",
|
||||
"allOf": [{"$ref": "#/definitions/Prompt_Template"}],
|
||||
},
|
||||
},
|
||||
},
|
||||
@ -972,7 +1011,14 @@ def test_configurable_fields_example() -> None:
|
||||
chain_configurable.with_config(configurable={"llm": "chat"}).invoke(
|
||||
{"name": "John"}
|
||||
)
|
||||
== "b"
|
||||
== "A good morning to you!"
|
||||
)
|
||||
|
||||
assert (
|
||||
chain_configurable.with_config(
|
||||
configurable={"llm": "chat", "chat_responses": ["helpful"]}
|
||||
).invoke({"name": "John"})
|
||||
== "How can I help you?"
|
||||
)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user