mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-24 16:37:46 +00:00
Use an enum for configurable_alternatives to make the generated json schema nicer (#11350)
This commit is contained in:
parent
b499de2926
commit
b0893c7c6a
@ -129,9 +129,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
||||
return []
|
||||
|
||||
def config_schema(
|
||||
self, *, include: Optional[Sequence[str]] = None
|
||||
) -> Type[BaseModel]:
|
||||
def config_schema(self, *, include: Sequence[str]) -> Type[BaseModel]:
|
||||
class _Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@ -150,7 +148,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
for spec in config_specs
|
||||
},
|
||||
)
|
||||
if config_specs
|
||||
if config_specs and "configurable" in include
|
||||
else None
|
||||
)
|
||||
|
||||
@ -161,7 +159,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
**{
|
||||
field_name: (field_type, None)
|
||||
for field_name, field_type in RunnableConfig.__annotations__.items()
|
||||
if field_name in include
|
||||
if field_name in [i for i in include if i != "configurable"]
|
||||
},
|
||||
)
|
||||
|
||||
@ -873,7 +871,7 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]):
|
||||
"available keys are {self.__fields__.keys()}"
|
||||
)
|
||||
|
||||
return RunnableConfigurableFields(bound=self, fields=kwargs)
|
||||
return RunnableConfigurableFields(default=self, fields=kwargs)
|
||||
|
||||
def configurable_alternatives(
|
||||
self,
|
||||
@ -885,7 +883,7 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]):
|
||||
)
|
||||
|
||||
return RunnableConfigurableAlternatives(
|
||||
which=which, bound=self, alternatives=kwargs
|
||||
which=which, default=self, alternatives=kwargs
|
||||
)
|
||||
|
||||
|
||||
@ -2051,9 +2049,7 @@ class RunnableEach(RunnableSerializable[List[Input], List[Output]]):
|
||||
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
||||
return self.bound.config_specs
|
||||
|
||||
def config_schema(
|
||||
self, *, include: Optional[Sequence[str]] = None
|
||||
) -> Type[BaseModel]:
|
||||
def config_schema(self, *, include: Sequence[str]) -> Type[BaseModel]:
|
||||
return self.bound.config_schema(include=include)
|
||||
|
||||
@classmethod
|
||||
@ -2132,9 +2128,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
||||
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
||||
return self.bound.config_specs
|
||||
|
||||
def config_schema(
|
||||
self, *, include: Optional[Sequence[str]] = None
|
||||
) -> Type[BaseModel]:
|
||||
def config_schema(self, *, include: Sequence[str]) -> Type[BaseModel]:
|
||||
return self.bound.config_schema(include=include)
|
||||
|
||||
@classmethod
|
||||
|
@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
from abc import abstractmethod
|
||||
from typing import (
|
||||
Any,
|
||||
@ -7,7 +8,6 @@ from typing import (
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Type,
|
||||
@ -32,7 +32,7 @@ from langchain.schema.runnable.utils import (
|
||||
|
||||
|
||||
class DynamicRunnable(RunnableSerializable[Input, Output]):
|
||||
bound: RunnableSerializable[Input, Output]
|
||||
default: RunnableSerializable[Input, Output]
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
@ -47,19 +47,19 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
|
||||
|
||||
@property
|
||||
def InputType(self) -> Type[Input]:
|
||||
return self.bound.InputType
|
||||
return self.default.InputType
|
||||
|
||||
@property
|
||||
def OutputType(self) -> Type[Output]:
|
||||
return self.bound.OutputType
|
||||
return self.default.OutputType
|
||||
|
||||
@property
|
||||
def input_schema(self) -> Type[BaseModel]:
|
||||
return self.bound.input_schema
|
||||
return self.default.input_schema
|
||||
|
||||
@property
|
||||
def output_schema(self) -> Type[BaseModel]:
|
||||
return self.bound.output_schema
|
||||
return self.default.output_schema
|
||||
|
||||
@abstractmethod
|
||||
def _prepare(
|
||||
@ -88,8 +88,8 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
|
||||
configs = get_config_list(config, len(inputs))
|
||||
prepared = [self._prepare(c) for c in configs]
|
||||
|
||||
if all(p is self.bound for p in prepared):
|
||||
return self.bound.batch(
|
||||
if all(p is self.default for p in prepared):
|
||||
return self.default.batch(
|
||||
inputs, config, return_exceptions=return_exceptions, **kwargs
|
||||
)
|
||||
|
||||
@ -131,8 +131,8 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
|
||||
configs = get_config_list(config, len(inputs))
|
||||
prepared = [self._prepare(c) for c in configs]
|
||||
|
||||
if all(p is self.bound for p in prepared):
|
||||
return await self.bound.abatch(
|
||||
if all(p is self.default for p in prepared):
|
||||
return await self.default.abatch(
|
||||
inputs, config, return_exceptions=return_exceptions, **kwargs
|
||||
)
|
||||
|
||||
@ -202,10 +202,10 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
|
||||
id=spec.id,
|
||||
name=spec.name,
|
||||
description=spec.description
|
||||
or self.bound.__fields__[field_name].field_info.description,
|
||||
or self.default.__fields__[field_name].field_info.description,
|
||||
annotation=spec.annotation
|
||||
or self.bound.__fields__[field_name].annotation,
|
||||
default=getattr(self.bound, field_name),
|
||||
or self.default.__fields__[field_name].annotation,
|
||||
default=getattr(self.default, field_name),
|
||||
)
|
||||
for field_name, spec in self.fields.items()
|
||||
]
|
||||
@ -213,7 +213,7 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
|
||||
def configurable_fields(
|
||||
self, **kwargs: ConfigurableField
|
||||
) -> RunnableSerializable[Input, Output]:
|
||||
return self.bound.configurable_fields(**{**self.fields, **kwargs})
|
||||
return self.default.configurable_fields(**{**self.fields, **kwargs})
|
||||
|
||||
def _prepare(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
@ -227,9 +227,14 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
|
||||
}
|
||||
|
||||
if configurable:
|
||||
return self.bound.__class__(**{**self.bound.dict(), **configurable})
|
||||
return self.default.__class__(**{**self.default.dict(), **configurable})
|
||||
else:
|
||||
return self.bound
|
||||
return self.default
|
||||
|
||||
|
||||
# Before Python 3.11 native StrEnum is not available
|
||||
class StrEnum(str, enum.Enum):
|
||||
pass
|
||||
|
||||
|
||||
class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
|
||||
@ -237,21 +242,23 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
|
||||
|
||||
alternatives: Dict[str, RunnableSerializable[Input, Output]]
|
||||
|
||||
default_key: str = "default"
|
||||
|
||||
@property
|
||||
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
||||
alt_keys = self.alternatives.keys()
|
||||
which_keys = tuple(Literal[k] for k in alt_keys) + ( # type: ignore
|
||||
Literal["default"],
|
||||
which_enum = StrEnum( # type: ignore[call-overload]
|
||||
self.which.name or self.which.id,
|
||||
((v, v) for v in list(self.alternatives.keys()) + [self.default_key]),
|
||||
)
|
||||
return [
|
||||
ConfigurableFieldSpec(
|
||||
id=self.which.id,
|
||||
name=self.which.name,
|
||||
description=self.which.description,
|
||||
annotation=Union[which_keys], # type: ignore
|
||||
default="default",
|
||||
annotation=which_enum,
|
||||
default=self.default_key,
|
||||
),
|
||||
*self.bound.config_specs,
|
||||
*self.default.config_specs,
|
||||
] + [s for alt in self.alternatives.values() for s in alt.config_specs]
|
||||
|
||||
def configurable_fields(
|
||||
@ -259,7 +266,7 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
|
||||
) -> RunnableSerializable[Input, Output]:
|
||||
return self.__class__(
|
||||
which=self.which,
|
||||
bound=self.bound.configurable_fields(**kwargs),
|
||||
default=self.default.configurable_fields(**kwargs),
|
||||
alternatives=self.alternatives,
|
||||
)
|
||||
|
||||
@ -267,9 +274,9 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Runnable[Input, Output]:
|
||||
config = config or {}
|
||||
which = config.get("configurable", {}).get(self.which.id)
|
||||
if not which:
|
||||
return self.bound
|
||||
which = str(config.get("configurable", {}).get(self.which.id, self.default_key))
|
||||
if which == self.default_key:
|
||||
return self.default
|
||||
elif which in self.alternatives:
|
||||
return self.alternatives[which]
|
||||
else:
|
||||
|
@ -69,9 +69,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
for spec in step.config_specs
|
||||
)
|
||||
|
||||
def config_schema(
|
||||
self, *, include: Optional[Sequence[str]] = None
|
||||
) -> Type[BaseModel]:
|
||||
def config_schema(self, *, include: Sequence[str]) -> Type[BaseModel]:
|
||||
return self.runnable.config_schema(include=include)
|
||||
|
||||
@classmethod
|
||||
|
@ -563,7 +563,7 @@ def test_configurable_fields() -> None:
|
||||
|
||||
assert fake_llm_configurable.invoke("...") == "a"
|
||||
|
||||
assert fake_llm_configurable.config_schema().schema() == {
|
||||
assert fake_llm_configurable.config_schema(include=["configurable"]).schema() == {
|
||||
"title": "RunnableConfigurableFieldsConfig",
|
||||
"type": "object",
|
||||
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
|
||||
@ -606,7 +606,7 @@ def test_configurable_fields() -> None:
|
||||
text="Hello, John!"
|
||||
)
|
||||
|
||||
assert prompt_configurable.config_schema().schema() == {
|
||||
assert prompt_configurable.config_schema(include=["configurable"]).schema() == {
|
||||
"title": "RunnableConfigurableFieldsConfig",
|
||||
"type": "object",
|
||||
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
|
||||
@ -638,7 +638,7 @@ def test_configurable_fields() -> None:
|
||||
|
||||
assert chain_configurable.invoke({"name": "John"}) == "a"
|
||||
|
||||
assert chain_configurable.config_schema().schema() == {
|
||||
assert chain_configurable.config_schema(include=["configurable"]).schema() == {
|
||||
"title": "RunnableSequenceConfig",
|
||||
"type": "object",
|
||||
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
|
||||
@ -690,7 +690,9 @@ def test_configurable_fields() -> None:
|
||||
"llm3": "a",
|
||||
}
|
||||
|
||||
assert chain_with_map_configurable.config_schema().schema() == {
|
||||
assert chain_with_map_configurable.config_schema(
|
||||
include=["configurable"]
|
||||
).schema() == {
|
||||
"title": "RunnableSequenceConfig",
|
||||
"type": "object",
|
||||
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
|
||||
@ -760,11 +762,17 @@ def test_configurable_fields_example() -> None:
|
||||
|
||||
assert chain_configurable.invoke({"name": "John"}) == "a"
|
||||
|
||||
assert chain_configurable.config_schema().schema() == {
|
||||
assert chain_configurable.config_schema(include=["configurable"]).schema() == {
|
||||
"title": "RunnableSequenceConfig",
|
||||
"type": "object",
|
||||
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
|
||||
"definitions": {
|
||||
"LLM": {
|
||||
"title": "LLM",
|
||||
"description": "An enumeration.",
|
||||
"enum": ["chat", "default"],
|
||||
"type": "string",
|
||||
},
|
||||
"Configurable": {
|
||||
"title": "Configurable",
|
||||
"type": "object",
|
||||
@ -772,10 +780,7 @@ def test_configurable_fields_example() -> None:
|
||||
"llm": {
|
||||
"title": "LLM",
|
||||
"default": "default",
|
||||
"anyOf": [
|
||||
{"enum": ["chat"], "type": "string"},
|
||||
{"enum": ["default"], "type": "string"},
|
||||
],
|
||||
"allOf": [{"$ref": "#/definitions/LLM"}],
|
||||
},
|
||||
"llm_responses": {
|
||||
"title": "LLM Responses",
|
||||
@ -791,7 +796,7 @@ def test_configurable_fields_example() -> None:
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user