mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-25 00:47:39 +00:00
Use an enum for configurable_alternatives to make the generated json schema nicer (#11350)
This commit is contained in:
parent
b499de2926
commit
b0893c7c6a
@ -129,9 +129,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def config_schema(
|
def config_schema(self, *, include: Sequence[str]) -> Type[BaseModel]:
|
||||||
self, *, include: Optional[Sequence[str]] = None
|
|
||||||
) -> Type[BaseModel]:
|
|
||||||
class _Config:
|
class _Config:
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
@ -150,7 +148,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
for spec in config_specs
|
for spec in config_specs
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
if config_specs
|
if config_specs and "configurable" in include
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -161,7 +159,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
**{
|
**{
|
||||||
field_name: (field_type, None)
|
field_name: (field_type, None)
|
||||||
for field_name, field_type in RunnableConfig.__annotations__.items()
|
for field_name, field_type in RunnableConfig.__annotations__.items()
|
||||||
if field_name in include
|
if field_name in [i for i in include if i != "configurable"]
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -873,7 +871,7 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]):
|
|||||||
"available keys are {self.__fields__.keys()}"
|
"available keys are {self.__fields__.keys()}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return RunnableConfigurableFields(bound=self, fields=kwargs)
|
return RunnableConfigurableFields(default=self, fields=kwargs)
|
||||||
|
|
||||||
def configurable_alternatives(
|
def configurable_alternatives(
|
||||||
self,
|
self,
|
||||||
@ -885,7 +883,7 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return RunnableConfigurableAlternatives(
|
return RunnableConfigurableAlternatives(
|
||||||
which=which, bound=self, alternatives=kwargs
|
which=which, default=self, alternatives=kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -2051,9 +2049,7 @@ class RunnableEach(RunnableSerializable[List[Input], List[Output]]):
|
|||||||
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
||||||
return self.bound.config_specs
|
return self.bound.config_specs
|
||||||
|
|
||||||
def config_schema(
|
def config_schema(self, *, include: Sequence[str]) -> Type[BaseModel]:
|
||||||
self, *, include: Optional[Sequence[str]] = None
|
|
||||||
) -> Type[BaseModel]:
|
|
||||||
return self.bound.config_schema(include=include)
|
return self.bound.config_schema(include=include)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -2132,9 +2128,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
|||||||
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
||||||
return self.bound.config_specs
|
return self.bound.config_specs
|
||||||
|
|
||||||
def config_schema(
|
def config_schema(self, *, include: Sequence[str]) -> Type[BaseModel]:
|
||||||
self, *, include: Optional[Sequence[str]] = None
|
|
||||||
) -> Type[BaseModel]:
|
|
||||||
return self.bound.config_schema(include=include)
|
return self.bound.config_schema(include=include)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import enum
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
@ -7,7 +8,6 @@ from typing import (
|
|||||||
Dict,
|
Dict,
|
||||||
Iterator,
|
Iterator,
|
||||||
List,
|
List,
|
||||||
Literal,
|
|
||||||
Optional,
|
Optional,
|
||||||
Sequence,
|
Sequence,
|
||||||
Type,
|
Type,
|
||||||
@ -32,7 +32,7 @@ from langchain.schema.runnable.utils import (
|
|||||||
|
|
||||||
|
|
||||||
class DynamicRunnable(RunnableSerializable[Input, Output]):
|
class DynamicRunnable(RunnableSerializable[Input, Output]):
|
||||||
bound: RunnableSerializable[Input, Output]
|
default: RunnableSerializable[Input, Output]
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
@ -47,19 +47,19 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def InputType(self) -> Type[Input]:
|
def InputType(self) -> Type[Input]:
|
||||||
return self.bound.InputType
|
return self.default.InputType
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def OutputType(self) -> Type[Output]:
|
def OutputType(self) -> Type[Output]:
|
||||||
return self.bound.OutputType
|
return self.default.OutputType
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def input_schema(self) -> Type[BaseModel]:
|
def input_schema(self) -> Type[BaseModel]:
|
||||||
return self.bound.input_schema
|
return self.default.input_schema
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def output_schema(self) -> Type[BaseModel]:
|
def output_schema(self) -> Type[BaseModel]:
|
||||||
return self.bound.output_schema
|
return self.default.output_schema
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _prepare(
|
def _prepare(
|
||||||
@ -88,8 +88,8 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
|
|||||||
configs = get_config_list(config, len(inputs))
|
configs = get_config_list(config, len(inputs))
|
||||||
prepared = [self._prepare(c) for c in configs]
|
prepared = [self._prepare(c) for c in configs]
|
||||||
|
|
||||||
if all(p is self.bound for p in prepared):
|
if all(p is self.default for p in prepared):
|
||||||
return self.bound.batch(
|
return self.default.batch(
|
||||||
inputs, config, return_exceptions=return_exceptions, **kwargs
|
inputs, config, return_exceptions=return_exceptions, **kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -131,8 +131,8 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
|
|||||||
configs = get_config_list(config, len(inputs))
|
configs = get_config_list(config, len(inputs))
|
||||||
prepared = [self._prepare(c) for c in configs]
|
prepared = [self._prepare(c) for c in configs]
|
||||||
|
|
||||||
if all(p is self.bound for p in prepared):
|
if all(p is self.default for p in prepared):
|
||||||
return await self.bound.abatch(
|
return await self.default.abatch(
|
||||||
inputs, config, return_exceptions=return_exceptions, **kwargs
|
inputs, config, return_exceptions=return_exceptions, **kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -202,10 +202,10 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
|
|||||||
id=spec.id,
|
id=spec.id,
|
||||||
name=spec.name,
|
name=spec.name,
|
||||||
description=spec.description
|
description=spec.description
|
||||||
or self.bound.__fields__[field_name].field_info.description,
|
or self.default.__fields__[field_name].field_info.description,
|
||||||
annotation=spec.annotation
|
annotation=spec.annotation
|
||||||
or self.bound.__fields__[field_name].annotation,
|
or self.default.__fields__[field_name].annotation,
|
||||||
default=getattr(self.bound, field_name),
|
default=getattr(self.default, field_name),
|
||||||
)
|
)
|
||||||
for field_name, spec in self.fields.items()
|
for field_name, spec in self.fields.items()
|
||||||
]
|
]
|
||||||
@ -213,7 +213,7 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
|
|||||||
def configurable_fields(
|
def configurable_fields(
|
||||||
self, **kwargs: ConfigurableField
|
self, **kwargs: ConfigurableField
|
||||||
) -> RunnableSerializable[Input, Output]:
|
) -> RunnableSerializable[Input, Output]:
|
||||||
return self.bound.configurable_fields(**{**self.fields, **kwargs})
|
return self.default.configurable_fields(**{**self.fields, **kwargs})
|
||||||
|
|
||||||
def _prepare(
|
def _prepare(
|
||||||
self, config: Optional[RunnableConfig] = None
|
self, config: Optional[RunnableConfig] = None
|
||||||
@ -227,9 +227,14 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
|
|||||||
}
|
}
|
||||||
|
|
||||||
if configurable:
|
if configurable:
|
||||||
return self.bound.__class__(**{**self.bound.dict(), **configurable})
|
return self.default.__class__(**{**self.default.dict(), **configurable})
|
||||||
else:
|
else:
|
||||||
return self.bound
|
return self.default
|
||||||
|
|
||||||
|
|
||||||
|
# Before Python 3.11 native StrEnum is not available
|
||||||
|
class StrEnum(str, enum.Enum):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
|
class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
|
||||||
@ -237,21 +242,23 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
|
|||||||
|
|
||||||
alternatives: Dict[str, RunnableSerializable[Input, Output]]
|
alternatives: Dict[str, RunnableSerializable[Input, Output]]
|
||||||
|
|
||||||
|
default_key: str = "default"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
||||||
alt_keys = self.alternatives.keys()
|
which_enum = StrEnum( # type: ignore[call-overload]
|
||||||
which_keys = tuple(Literal[k] for k in alt_keys) + ( # type: ignore
|
self.which.name or self.which.id,
|
||||||
Literal["default"],
|
((v, v) for v in list(self.alternatives.keys()) + [self.default_key]),
|
||||||
)
|
)
|
||||||
return [
|
return [
|
||||||
ConfigurableFieldSpec(
|
ConfigurableFieldSpec(
|
||||||
id=self.which.id,
|
id=self.which.id,
|
||||||
name=self.which.name,
|
name=self.which.name,
|
||||||
description=self.which.description,
|
description=self.which.description,
|
||||||
annotation=Union[which_keys], # type: ignore
|
annotation=which_enum,
|
||||||
default="default",
|
default=self.default_key,
|
||||||
),
|
),
|
||||||
*self.bound.config_specs,
|
*self.default.config_specs,
|
||||||
] + [s for alt in self.alternatives.values() for s in alt.config_specs]
|
] + [s for alt in self.alternatives.values() for s in alt.config_specs]
|
||||||
|
|
||||||
def configurable_fields(
|
def configurable_fields(
|
||||||
@ -259,7 +266,7 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
|
|||||||
) -> RunnableSerializable[Input, Output]:
|
) -> RunnableSerializable[Input, Output]:
|
||||||
return self.__class__(
|
return self.__class__(
|
||||||
which=self.which,
|
which=self.which,
|
||||||
bound=self.bound.configurable_fields(**kwargs),
|
default=self.default.configurable_fields(**kwargs),
|
||||||
alternatives=self.alternatives,
|
alternatives=self.alternatives,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -267,9 +274,9 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
|
|||||||
self, config: Optional[RunnableConfig] = None
|
self, config: Optional[RunnableConfig] = None
|
||||||
) -> Runnable[Input, Output]:
|
) -> Runnable[Input, Output]:
|
||||||
config = config or {}
|
config = config or {}
|
||||||
which = config.get("configurable", {}).get(self.which.id)
|
which = str(config.get("configurable", {}).get(self.which.id, self.default_key))
|
||||||
if not which:
|
if which == self.default_key:
|
||||||
return self.bound
|
return self.default
|
||||||
elif which in self.alternatives:
|
elif which in self.alternatives:
|
||||||
return self.alternatives[which]
|
return self.alternatives[which]
|
||||||
else:
|
else:
|
||||||
|
@ -69,9 +69,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
|||||||
for spec in step.config_specs
|
for spec in step.config_specs
|
||||||
)
|
)
|
||||||
|
|
||||||
def config_schema(
|
def config_schema(self, *, include: Sequence[str]) -> Type[BaseModel]:
|
||||||
self, *, include: Optional[Sequence[str]] = None
|
|
||||||
) -> Type[BaseModel]:
|
|
||||||
return self.runnable.config_schema(include=include)
|
return self.runnable.config_schema(include=include)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -563,7 +563,7 @@ def test_configurable_fields() -> None:
|
|||||||
|
|
||||||
assert fake_llm_configurable.invoke("...") == "a"
|
assert fake_llm_configurable.invoke("...") == "a"
|
||||||
|
|
||||||
assert fake_llm_configurable.config_schema().schema() == {
|
assert fake_llm_configurable.config_schema(include=["configurable"]).schema() == {
|
||||||
"title": "RunnableConfigurableFieldsConfig",
|
"title": "RunnableConfigurableFieldsConfig",
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
|
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
|
||||||
@ -606,7 +606,7 @@ def test_configurable_fields() -> None:
|
|||||||
text="Hello, John!"
|
text="Hello, John!"
|
||||||
)
|
)
|
||||||
|
|
||||||
assert prompt_configurable.config_schema().schema() == {
|
assert prompt_configurable.config_schema(include=["configurable"]).schema() == {
|
||||||
"title": "RunnableConfigurableFieldsConfig",
|
"title": "RunnableConfigurableFieldsConfig",
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
|
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
|
||||||
@ -638,7 +638,7 @@ def test_configurable_fields() -> None:
|
|||||||
|
|
||||||
assert chain_configurable.invoke({"name": "John"}) == "a"
|
assert chain_configurable.invoke({"name": "John"}) == "a"
|
||||||
|
|
||||||
assert chain_configurable.config_schema().schema() == {
|
assert chain_configurable.config_schema(include=["configurable"]).schema() == {
|
||||||
"title": "RunnableSequenceConfig",
|
"title": "RunnableSequenceConfig",
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
|
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
|
||||||
@ -690,7 +690,9 @@ def test_configurable_fields() -> None:
|
|||||||
"llm3": "a",
|
"llm3": "a",
|
||||||
}
|
}
|
||||||
|
|
||||||
assert chain_with_map_configurable.config_schema().schema() == {
|
assert chain_with_map_configurable.config_schema(
|
||||||
|
include=["configurable"]
|
||||||
|
).schema() == {
|
||||||
"title": "RunnableSequenceConfig",
|
"title": "RunnableSequenceConfig",
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
|
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
|
||||||
@ -760,11 +762,17 @@ def test_configurable_fields_example() -> None:
|
|||||||
|
|
||||||
assert chain_configurable.invoke({"name": "John"}) == "a"
|
assert chain_configurable.invoke({"name": "John"}) == "a"
|
||||||
|
|
||||||
assert chain_configurable.config_schema().schema() == {
|
assert chain_configurable.config_schema(include=["configurable"]).schema() == {
|
||||||
"title": "RunnableSequenceConfig",
|
"title": "RunnableSequenceConfig",
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
|
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
|
||||||
"definitions": {
|
"definitions": {
|
||||||
|
"LLM": {
|
||||||
|
"title": "LLM",
|
||||||
|
"description": "An enumeration.",
|
||||||
|
"enum": ["chat", "default"],
|
||||||
|
"type": "string",
|
||||||
|
},
|
||||||
"Configurable": {
|
"Configurable": {
|
||||||
"title": "Configurable",
|
"title": "Configurable",
|
||||||
"type": "object",
|
"type": "object",
|
||||||
@ -772,10 +780,7 @@ def test_configurable_fields_example() -> None:
|
|||||||
"llm": {
|
"llm": {
|
||||||
"title": "LLM",
|
"title": "LLM",
|
||||||
"default": "default",
|
"default": "default",
|
||||||
"anyOf": [
|
"allOf": [{"$ref": "#/definitions/LLM"}],
|
||||||
{"enum": ["chat"], "type": "string"},
|
|
||||||
{"enum": ["default"], "type": "string"},
|
|
||||||
],
|
|
||||||
},
|
},
|
||||||
"llm_responses": {
|
"llm_responses": {
|
||||||
"title": "LLM Responses",
|
"title": "LLM Responses",
|
||||||
@ -791,7 +796,7 @@ def test_configurable_fields_example() -> None:
|
|||||||
"type": "string",
|
"type": "string",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user