Use an enum for configurable_alternatives to make the generated json schema nicer (#11350)

This commit is contained in:
Nuno Campos 2023-10-04 16:32:41 +01:00 committed by GitHub
parent b499de2926
commit b0893c7c6a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 56 additions and 52 deletions

View File

@ -129,9 +129,7 @@ class Runnable(Generic[Input, Output], ABC):
def config_specs(self) -> Sequence[ConfigurableFieldSpec]: def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
return [] return []
def config_schema( def config_schema(self, *, include: Sequence[str]) -> Type[BaseModel]:
self, *, include: Optional[Sequence[str]] = None
) -> Type[BaseModel]:
class _Config: class _Config:
arbitrary_types_allowed = True arbitrary_types_allowed = True
@ -150,7 +148,7 @@ class Runnable(Generic[Input, Output], ABC):
for spec in config_specs for spec in config_specs
}, },
) )
if config_specs if config_specs and "configurable" in include
else None else None
) )
@ -161,7 +159,7 @@ class Runnable(Generic[Input, Output], ABC):
**{ **{
field_name: (field_type, None) field_name: (field_type, None)
for field_name, field_type in RunnableConfig.__annotations__.items() for field_name, field_type in RunnableConfig.__annotations__.items()
if field_name in include if field_name in [i for i in include if i != "configurable"]
}, },
) )
@ -873,7 +871,7 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]):
"available keys are {self.__fields__.keys()}" "available keys are {self.__fields__.keys()}"
) )
return RunnableConfigurableFields(bound=self, fields=kwargs) return RunnableConfigurableFields(default=self, fields=kwargs)
def configurable_alternatives( def configurable_alternatives(
self, self,
@ -885,7 +883,7 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]):
) )
return RunnableConfigurableAlternatives( return RunnableConfigurableAlternatives(
which=which, bound=self, alternatives=kwargs which=which, default=self, alternatives=kwargs
) )
@ -2051,9 +2049,7 @@ class RunnableEach(RunnableSerializable[List[Input], List[Output]]):
def config_specs(self) -> Sequence[ConfigurableFieldSpec]: def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
return self.bound.config_specs return self.bound.config_specs
def config_schema( def config_schema(self, *, include: Sequence[str]) -> Type[BaseModel]:
self, *, include: Optional[Sequence[str]] = None
) -> Type[BaseModel]:
return self.bound.config_schema(include=include) return self.bound.config_schema(include=include)
@classmethod @classmethod
@ -2132,9 +2128,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
def config_specs(self) -> Sequence[ConfigurableFieldSpec]: def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
return self.bound.config_specs return self.bound.config_specs
def config_schema( def config_schema(self, *, include: Sequence[str]) -> Type[BaseModel]:
self, *, include: Optional[Sequence[str]] = None
) -> Type[BaseModel]:
return self.bound.config_schema(include=include) return self.bound.config_schema(include=include)
@classmethod @classmethod

View File

@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import enum
from abc import abstractmethod from abc import abstractmethod
from typing import ( from typing import (
Any, Any,
@ -7,7 +8,6 @@ from typing import (
Dict, Dict,
Iterator, Iterator,
List, List,
Literal,
Optional, Optional,
Sequence, Sequence,
Type, Type,
@ -32,7 +32,7 @@ from langchain.schema.runnable.utils import (
class DynamicRunnable(RunnableSerializable[Input, Output]): class DynamicRunnable(RunnableSerializable[Input, Output]):
bound: RunnableSerializable[Input, Output] default: RunnableSerializable[Input, Output]
class Config: class Config:
arbitrary_types_allowed = True arbitrary_types_allowed = True
@ -47,19 +47,19 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
@property @property
def InputType(self) -> Type[Input]: def InputType(self) -> Type[Input]:
return self.bound.InputType return self.default.InputType
@property @property
def OutputType(self) -> Type[Output]: def OutputType(self) -> Type[Output]:
return self.bound.OutputType return self.default.OutputType
@property @property
def input_schema(self) -> Type[BaseModel]: def input_schema(self) -> Type[BaseModel]:
return self.bound.input_schema return self.default.input_schema
@property @property
def output_schema(self) -> Type[BaseModel]: def output_schema(self) -> Type[BaseModel]:
return self.bound.output_schema return self.default.output_schema
@abstractmethod @abstractmethod
def _prepare( def _prepare(
@ -88,8 +88,8 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
configs = get_config_list(config, len(inputs)) configs = get_config_list(config, len(inputs))
prepared = [self._prepare(c) for c in configs] prepared = [self._prepare(c) for c in configs]
if all(p is self.bound for p in prepared): if all(p is self.default for p in prepared):
return self.bound.batch( return self.default.batch(
inputs, config, return_exceptions=return_exceptions, **kwargs inputs, config, return_exceptions=return_exceptions, **kwargs
) )
@ -131,8 +131,8 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
configs = get_config_list(config, len(inputs)) configs = get_config_list(config, len(inputs))
prepared = [self._prepare(c) for c in configs] prepared = [self._prepare(c) for c in configs]
if all(p is self.bound for p in prepared): if all(p is self.default for p in prepared):
return await self.bound.abatch( return await self.default.abatch(
inputs, config, return_exceptions=return_exceptions, **kwargs inputs, config, return_exceptions=return_exceptions, **kwargs
) )
@ -202,10 +202,10 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
id=spec.id, id=spec.id,
name=spec.name, name=spec.name,
description=spec.description description=spec.description
or self.bound.__fields__[field_name].field_info.description, or self.default.__fields__[field_name].field_info.description,
annotation=spec.annotation annotation=spec.annotation
or self.bound.__fields__[field_name].annotation, or self.default.__fields__[field_name].annotation,
default=getattr(self.bound, field_name), default=getattr(self.default, field_name),
) )
for field_name, spec in self.fields.items() for field_name, spec in self.fields.items()
] ]
@ -213,7 +213,7 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
def configurable_fields( def configurable_fields(
self, **kwargs: ConfigurableField self, **kwargs: ConfigurableField
) -> RunnableSerializable[Input, Output]: ) -> RunnableSerializable[Input, Output]:
return self.bound.configurable_fields(**{**self.fields, **kwargs}) return self.default.configurable_fields(**{**self.fields, **kwargs})
def _prepare( def _prepare(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
@ -227,9 +227,14 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
} }
if configurable: if configurable:
return self.bound.__class__(**{**self.bound.dict(), **configurable}) return self.default.__class__(**{**self.default.dict(), **configurable})
else: else:
return self.bound return self.default
# Before Python 3.11 native StrEnum is not available
class StrEnum(str, enum.Enum):
pass
class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]): class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
@ -237,21 +242,23 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
alternatives: Dict[str, RunnableSerializable[Input, Output]] alternatives: Dict[str, RunnableSerializable[Input, Output]]
default_key: str = "default"
@property @property
def config_specs(self) -> Sequence[ConfigurableFieldSpec]: def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
alt_keys = self.alternatives.keys() which_enum = StrEnum( # type: ignore[call-overload]
which_keys = tuple(Literal[k] for k in alt_keys) + ( # type: ignore self.which.name or self.which.id,
Literal["default"], ((v, v) for v in list(self.alternatives.keys()) + [self.default_key]),
) )
return [ return [
ConfigurableFieldSpec( ConfigurableFieldSpec(
id=self.which.id, id=self.which.id,
name=self.which.name, name=self.which.name,
description=self.which.description, description=self.which.description,
annotation=Union[which_keys], # type: ignore annotation=which_enum,
default="default", default=self.default_key,
), ),
*self.bound.config_specs, *self.default.config_specs,
] + [s for alt in self.alternatives.values() for s in alt.config_specs] ] + [s for alt in self.alternatives.values() for s in alt.config_specs]
def configurable_fields( def configurable_fields(
@ -259,7 +266,7 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
) -> RunnableSerializable[Input, Output]: ) -> RunnableSerializable[Input, Output]:
return self.__class__( return self.__class__(
which=self.which, which=self.which,
bound=self.bound.configurable_fields(**kwargs), default=self.default.configurable_fields(**kwargs),
alternatives=self.alternatives, alternatives=self.alternatives,
) )
@ -267,9 +274,9 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
) -> Runnable[Input, Output]: ) -> Runnable[Input, Output]:
config = config or {} config = config or {}
which = config.get("configurable", {}).get(self.which.id) which = str(config.get("configurable", {}).get(self.which.id, self.default_key))
if not which: if which == self.default_key:
return self.bound return self.default
elif which in self.alternatives: elif which in self.alternatives:
return self.alternatives[which] return self.alternatives[which]
else: else:

View File

@ -69,9 +69,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
for spec in step.config_specs for spec in step.config_specs
) )
def config_schema( def config_schema(self, *, include: Sequence[str]) -> Type[BaseModel]:
self, *, include: Optional[Sequence[str]] = None
) -> Type[BaseModel]:
return self.runnable.config_schema(include=include) return self.runnable.config_schema(include=include)
@classmethod @classmethod

View File

@ -563,7 +563,7 @@ def test_configurable_fields() -> None:
assert fake_llm_configurable.invoke("...") == "a" assert fake_llm_configurable.invoke("...") == "a"
assert fake_llm_configurable.config_schema().schema() == { assert fake_llm_configurable.config_schema(include=["configurable"]).schema() == {
"title": "RunnableConfigurableFieldsConfig", "title": "RunnableConfigurableFieldsConfig",
"type": "object", "type": "object",
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}}, "properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
@ -606,7 +606,7 @@ def test_configurable_fields() -> None:
text="Hello, John!" text="Hello, John!"
) )
assert prompt_configurable.config_schema().schema() == { assert prompt_configurable.config_schema(include=["configurable"]).schema() == {
"title": "RunnableConfigurableFieldsConfig", "title": "RunnableConfigurableFieldsConfig",
"type": "object", "type": "object",
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}}, "properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
@ -638,7 +638,7 @@ def test_configurable_fields() -> None:
assert chain_configurable.invoke({"name": "John"}) == "a" assert chain_configurable.invoke({"name": "John"}) == "a"
assert chain_configurable.config_schema().schema() == { assert chain_configurable.config_schema(include=["configurable"]).schema() == {
"title": "RunnableSequenceConfig", "title": "RunnableSequenceConfig",
"type": "object", "type": "object",
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}}, "properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
@ -690,7 +690,9 @@ def test_configurable_fields() -> None:
"llm3": "a", "llm3": "a",
} }
assert chain_with_map_configurable.config_schema().schema() == { assert chain_with_map_configurable.config_schema(
include=["configurable"]
).schema() == {
"title": "RunnableSequenceConfig", "title": "RunnableSequenceConfig",
"type": "object", "type": "object",
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}}, "properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
@ -760,11 +762,17 @@ def test_configurable_fields_example() -> None:
assert chain_configurable.invoke({"name": "John"}) == "a" assert chain_configurable.invoke({"name": "John"}) == "a"
assert chain_configurable.config_schema().schema() == { assert chain_configurable.config_schema(include=["configurable"]).schema() == {
"title": "RunnableSequenceConfig", "title": "RunnableSequenceConfig",
"type": "object", "type": "object",
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}}, "properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
"definitions": { "definitions": {
"LLM": {
"title": "LLM",
"description": "An enumeration.",
"enum": ["chat", "default"],
"type": "string",
},
"Configurable": { "Configurable": {
"title": "Configurable", "title": "Configurable",
"type": "object", "type": "object",
@ -772,10 +780,7 @@ def test_configurable_fields_example() -> None:
"llm": { "llm": {
"title": "LLM", "title": "LLM",
"default": "default", "default": "default",
"anyOf": [ "allOf": [{"$ref": "#/definitions/LLM"}],
{"enum": ["chat"], "type": "string"},
{"enum": ["default"], "type": "string"},
],
}, },
"llm_responses": { "llm_responses": {
"title": "LLM Responses", "title": "LLM Responses",
@ -791,7 +796,7 @@ def test_configurable_fields_example() -> None:
"type": "string", "type": "string",
}, },
}, },
} },
}, },
} }