mirror of
https://github.com/hwchase17/langchain.git
synced 2026-01-22 20:59:05 +00:00
Add option to prefix config keys in configurable_alts (#13714)
This commit is contained in:
@@ -13,7 +13,7 @@ tests:
|
||||
poetry run pytest $(TEST_FILE)
|
||||
|
||||
test_watch:
|
||||
poetry run ptw --snapshot-update --now . -- -x tests/unit_tests
|
||||
poetry run ptw --snapshot-update --now . -- -vv -x tests/unit_tests
|
||||
|
||||
|
||||
######################
|
||||
|
||||
@@ -1204,7 +1204,9 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]):
|
||||
def configurable_alternatives(
|
||||
self,
|
||||
which: ConfigurableField,
|
||||
*,
|
||||
default_key: str = "default",
|
||||
prefix_keys: bool = False,
|
||||
**kwargs: Union[Runnable[Input, Output], Callable[[], Runnable[Input, Output]]],
|
||||
) -> RunnableSerializable[Input, Output]:
|
||||
from langchain_core.runnables.configurable import (
|
||||
@@ -1212,7 +1214,11 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]):
|
||||
)
|
||||
|
||||
return RunnableConfigurableAlternatives(
|
||||
which=which, default=self, alternatives=kwargs, default_key=default_key
|
||||
which=which,
|
||||
default=self,
|
||||
alternatives=kwargs,
|
||||
default_key=default_key,
|
||||
prefix_keys=prefix_keys,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -220,6 +220,7 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
|
||||
annotation=spec.annotation
|
||||
or self.default.__fields__[field_name].annotation,
|
||||
default=getattr(self.default, field_name),
|
||||
is_shared=spec.is_shared,
|
||||
)
|
||||
if isinstance(spec, ConfigurableField)
|
||||
else make_options_spec(
|
||||
@@ -298,6 +299,12 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
|
||||
]
|
||||
|
||||
default_key: str = "default"
|
||||
"""The enum value to use for the default option. Defaults to "default"."""
|
||||
|
||||
prefix_keys: bool
|
||||
"""Whether to prefix configurable fields of each alternative with a namespace
|
||||
of the form <which.id>==<alternative_key>, eg. a key named "temperature" used by
|
||||
the alternative named "gpt3" becomes "model==gpt3/temperature"."""
|
||||
|
||||
@property
|
||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||
@@ -313,21 +320,37 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
|
||||
),
|
||||
)
|
||||
_enums_for_spec[self.which] = cast(Type[StrEnum], which_enum)
|
||||
return [
|
||||
ConfigurableFieldSpec(
|
||||
id=self.which.id,
|
||||
name=self.which.name,
|
||||
description=self.which.description,
|
||||
annotation=which_enum,
|
||||
default=self.default_key,
|
||||
),
|
||||
*self.default.config_specs,
|
||||
] + [
|
||||
s
|
||||
for alt in self.alternatives.values()
|
||||
if isinstance(alt, RunnableSerializable)
|
||||
for s in alt.config_specs
|
||||
]
|
||||
return get_unique_config_specs(
|
||||
# which alternative
|
||||
[
|
||||
ConfigurableFieldSpec(
|
||||
id=self.which.id,
|
||||
name=self.which.name,
|
||||
description=self.which.description,
|
||||
annotation=which_enum,
|
||||
default=self.default_key,
|
||||
is_shared=self.which.is_shared,
|
||||
),
|
||||
]
|
||||
# config specs of the default option
|
||||
+ (
|
||||
[
|
||||
prefix_config_spec(s, f"{self.which.id}=={self.default_key}")
|
||||
for s in self.default.config_specs
|
||||
]
|
||||
if self.prefix_keys
|
||||
else self.default.config_specs
|
||||
)
|
||||
# config specs of the alternatives
|
||||
+ [
|
||||
prefix_config_spec(s, f"{self.which.id}=={alt_key}")
|
||||
if self.prefix_keys
|
||||
else s
|
||||
for alt_key, alt in self.alternatives.items()
|
||||
if isinstance(alt, RunnableSerializable)
|
||||
for s in alt.config_specs
|
||||
]
|
||||
)
|
||||
|
||||
def configurable_fields(
|
||||
self, **kwargs: AnyConfigurableField
|
||||
@@ -355,6 +378,23 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
|
||||
raise ValueError(f"Unknown alternative: {which}")
|
||||
|
||||
|
||||
def prefix_config_spec(
|
||||
spec: ConfigurableFieldSpec, prefix: str
|
||||
) -> ConfigurableFieldSpec:
|
||||
return (
|
||||
ConfigurableFieldSpec(
|
||||
id=f"{prefix}/{spec.id}",
|
||||
name=spec.name,
|
||||
description=spec.description,
|
||||
annotation=spec.annotation,
|
||||
default=spec.default,
|
||||
is_shared=spec.is_shared,
|
||||
)
|
||||
if not spec.is_shared
|
||||
else spec
|
||||
)
|
||||
|
||||
|
||||
def make_options_spec(
|
||||
spec: Union[ConfigurableFieldSingleOption, ConfigurableFieldMultiOption],
|
||||
description: Optional[str],
|
||||
@@ -377,6 +417,7 @@ def make_options_spec(
|
||||
description=spec.description or description,
|
||||
annotation=enum,
|
||||
default=spec.default,
|
||||
is_shared=spec.is_shared,
|
||||
)
|
||||
else:
|
||||
return ConfigurableFieldSpec(
|
||||
@@ -385,4 +426,5 @@ def make_options_spec(
|
||||
description=spec.description or description,
|
||||
annotation=Sequence[enum], # type: ignore[valid-type]
|
||||
default=spec.default,
|
||||
is_shared=spec.is_shared,
|
||||
)
|
||||
|
||||
@@ -169,6 +169,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
||||
name="Session ID",
|
||||
description="Unique identifier for a session.",
|
||||
default="",
|
||||
is_shared=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -257,6 +257,7 @@ class ConfigurableField(NamedTuple):
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
annotation: Optional[Any] = None
|
||||
is_shared: bool = False
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash((self.id, self.annotation))
|
||||
@@ -271,6 +272,7 @@ class ConfigurableFieldSingleOption(NamedTuple):
|
||||
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
is_shared: bool = False
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash((self.id, tuple(self.options.keys()), self.default))
|
||||
@@ -285,6 +287,7 @@ class ConfigurableFieldMultiOption(NamedTuple):
|
||||
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
is_shared: bool = False
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash((self.id, tuple(self.options.keys()), tuple(self.default)))
|
||||
@@ -299,12 +302,13 @@ class ConfigurableFieldSpec(NamedTuple):
|
||||
"""A field that can be configured by the user. It is a specification of a field."""
|
||||
|
||||
id: str
|
||||
name: Optional[str]
|
||||
description: Optional[str]
|
||||
|
||||
default: Any
|
||||
annotation: Any
|
||||
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
default: Any = None
|
||||
is_shared: bool = False
|
||||
|
||||
|
||||
def get_unique_config_specs(
|
||||
specs: Iterable[ConfigurableFieldSpec],
|
||||
|
||||
@@ -1020,6 +1020,118 @@ def test_configurable_alts_factory() -> None:
|
||||
assert fake_llm.with_config(configurable={"llm": "chat"}).invoke("...") == "b"
|
||||
|
||||
|
||||
def test_configurable_fields_prefix_keys() -> None:
|
||||
fake_chat = FakeListChatModel(responses=["b"]).configurable_fields(
|
||||
responses=ConfigurableFieldMultiOption(
|
||||
id="responses",
|
||||
name="Chat Responses",
|
||||
options={
|
||||
"hello": "A good morning to you!",
|
||||
"bye": "See you later!",
|
||||
"helpful": "How can I help you?",
|
||||
},
|
||||
default=["hello", "bye"],
|
||||
),
|
||||
# (sleep is a configurable field in FakeListChatModel)
|
||||
sleep=ConfigurableField(
|
||||
id="chat_sleep",
|
||||
is_shared=True,
|
||||
),
|
||||
)
|
||||
fake_llm = (
|
||||
FakeListLLM(responses=["a"])
|
||||
.configurable_fields(
|
||||
responses=ConfigurableField(
|
||||
id="responses",
|
||||
name="LLM Responses",
|
||||
description="A list of fake responses for this LLM",
|
||||
)
|
||||
)
|
||||
.configurable_alternatives(
|
||||
ConfigurableField(id="llm", name="LLM"),
|
||||
chat=fake_chat | StrOutputParser(),
|
||||
prefix_keys=True,
|
||||
)
|
||||
)
|
||||
prompt = PromptTemplate.from_template("Hello, {name}!").configurable_fields(
|
||||
template=ConfigurableFieldSingleOption(
|
||||
id="prompt_template",
|
||||
name="Prompt Template",
|
||||
description="The prompt template for this chain",
|
||||
options={
|
||||
"hello": "Hello, {name}!",
|
||||
"good_morning": "A very good morning to you, {name}!",
|
||||
},
|
||||
default="hello",
|
||||
)
|
||||
)
|
||||
|
||||
chain = prompt | fake_llm
|
||||
|
||||
assert chain.config_schema().schema() == {
|
||||
"title": "RunnableSequenceConfig",
|
||||
"type": "object",
|
||||
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
|
||||
"definitions": {
|
||||
"LLM": {
|
||||
"title": "LLM",
|
||||
"description": "An enumeration.",
|
||||
"enum": ["chat", "default"],
|
||||
"type": "string",
|
||||
},
|
||||
"Chat_Responses": {
|
||||
"title": "Chat Responses",
|
||||
"description": "An enumeration.",
|
||||
"enum": ["hello", "bye", "helpful"],
|
||||
"type": "string",
|
||||
},
|
||||
"Prompt_Template": {
|
||||
"title": "Prompt Template",
|
||||
"description": "An enumeration.",
|
||||
"enum": ["hello", "good_morning"],
|
||||
"type": "string",
|
||||
},
|
||||
"Configurable": {
|
||||
"title": "Configurable",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"prompt_template": {
|
||||
"title": "Prompt Template",
|
||||
"description": "The prompt template for this chain",
|
||||
"default": "hello",
|
||||
"allOf": [{"$ref": "#/definitions/Prompt_Template"}],
|
||||
},
|
||||
"llm": {
|
||||
"title": "LLM",
|
||||
"default": "default",
|
||||
"allOf": [{"$ref": "#/definitions/LLM"}],
|
||||
},
|
||||
# not prefixed because marked as shared
|
||||
"chat_sleep": {
|
||||
"title": "Chat Sleep",
|
||||
"type": "number",
|
||||
},
|
||||
# prefixed for "chat" option
|
||||
"llm==chat/responses": {
|
||||
"title": "Chat Responses",
|
||||
"default": ["hello", "bye"],
|
||||
"type": "array",
|
||||
"items": {"$ref": "#/definitions/Chat_Responses"},
|
||||
},
|
||||
# prefixed for "default" option
|
||||
"llm==default/responses": {
|
||||
"title": "LLM Responses",
|
||||
"description": "A list of fake responses for this LLM",
|
||||
"default": ["a"],
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def test_configurable_fields_example() -> None:
|
||||
fake_chat = FakeListChatModel(responses=["b"]).configurable_fields(
|
||||
responses=ConfigurableFieldMultiOption(
|
||||
|
||||
Reference in New Issue
Block a user