mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-01 04:29:09 +00:00
Add .configurable_fields() and .configurable_alternatives() to expose fields of a Runnable to be configured at runtime (#11282)
This commit is contained in:
parent
5e2d5047af
commit
4df3191092
@ -57,6 +57,8 @@ from langchain.schema.runnable.config import (
|
|||||||
)
|
)
|
||||||
from langchain.schema.runnable.utils import (
|
from langchain.schema.runnable.utils import (
|
||||||
AddableDict,
|
AddableDict,
|
||||||
|
ConfigurableField,
|
||||||
|
ConfigurableFieldSpec,
|
||||||
Input,
|
Input,
|
||||||
Output,
|
Output,
|
||||||
accepts_config,
|
accepts_config,
|
||||||
@ -64,6 +66,7 @@ from langchain.schema.runnable.utils import (
|
|||||||
gather_with_concurrency,
|
gather_with_concurrency,
|
||||||
get_function_first_arg_dict_keys,
|
get_function_first_arg_dict_keys,
|
||||||
get_lambda_source,
|
get_lambda_source,
|
||||||
|
get_unique_config_specs,
|
||||||
indent_lines_after_first,
|
indent_lines_after_first,
|
||||||
)
|
)
|
||||||
from langchain.utils.aiter import atee, py_anext
|
from langchain.utils.aiter import atee, py_anext
|
||||||
@ -122,6 +125,10 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
self.__class__.__name__ + "Output", __root__=(root_type, None)
|
self.__class__.__name__ + "Output", __root__=(root_type, None)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
||||||
|
return []
|
||||||
|
|
||||||
def config_schema(
|
def config_schema(
|
||||||
self, *, include: Optional[Sequence[str]] = None
|
self, *, include: Optional[Sequence[str]] = None
|
||||||
) -> Type[BaseModel]:
|
) -> Type[BaseModel]:
|
||||||
@ -129,10 +136,28 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
include = include or []
|
include = include or []
|
||||||
|
config_specs = self.config_specs
|
||||||
|
configurable = (
|
||||||
|
create_model( # type: ignore[call-overload]
|
||||||
|
"Configurable",
|
||||||
|
**{
|
||||||
|
spec.id: (
|
||||||
|
spec.annotation,
|
||||||
|
Field(
|
||||||
|
spec.default, title=spec.name, description=spec.description
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for spec in config_specs
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if config_specs
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
return create_model( # type: ignore[call-overload]
|
return create_model( # type: ignore[call-overload]
|
||||||
self.__class__.__name__ + "Config",
|
self.__class__.__name__ + "Config",
|
||||||
__config__=_Config,
|
__config__=_Config,
|
||||||
|
**({"configurable": (configurable, None)} if configurable else {}),
|
||||||
**{
|
**{
|
||||||
field_name: (field_type, None)
|
field_name: (field_type, None)
|
||||||
for field_name, field_type in RunnableConfig.__annotations__.items()
|
for field_name, field_type in RunnableConfig.__annotations__.items()
|
||||||
@ -836,7 +861,32 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
|
|
||||||
|
|
||||||
class RunnableSerializable(Serializable, Runnable[Input, Output]):
|
class RunnableSerializable(Serializable, Runnable[Input, Output]):
|
||||||
pass
|
def configurable_fields(
|
||||||
|
self, **kwargs: ConfigurableField
|
||||||
|
) -> RunnableSerializable[Input, Output]:
|
||||||
|
from langchain.schema.runnable.configurable import RunnableConfigurableFields
|
||||||
|
|
||||||
|
for key in kwargs:
|
||||||
|
if key not in self.__fields__:
|
||||||
|
raise ValueError(
|
||||||
|
f"Configuration key {key} not found in {self}: "
|
||||||
|
"available keys are {self.__fields__.keys()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return RunnableConfigurableFields(bound=self, fields=kwargs)
|
||||||
|
|
||||||
|
def configurable_alternatives(
|
||||||
|
self,
|
||||||
|
which: ConfigurableField,
|
||||||
|
**kwargs: Runnable[Input, Output],
|
||||||
|
) -> RunnableSerializable[Input, Output]:
|
||||||
|
from langchain.schema.runnable.configurable import (
|
||||||
|
RunnableConfigurableAlternatives,
|
||||||
|
)
|
||||||
|
|
||||||
|
return RunnableConfigurableAlternatives(
|
||||||
|
which=which, bound=self, alternatives=kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class RunnableSequence(RunnableSerializable[Input, Output]):
|
class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||||
@ -879,6 +929,12 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
|||||||
def output_schema(self) -> Type[BaseModel]:
|
def output_schema(self) -> Type[BaseModel]:
|
||||||
return self.last.output_schema
|
return self.last.output_schema
|
||||||
|
|
||||||
|
@property
|
||||||
|
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
||||||
|
return get_unique_config_specs(
|
||||||
|
spec for step in self.steps for spec in step.config_specs
|
||||||
|
)
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return "\n| ".join(
|
return "\n| ".join(
|
||||||
repr(s) if i == 0 else indent_lines_after_first(repr(s), "| ")
|
repr(s) if i == 0 else indent_lines_after_first(repr(s), "| ")
|
||||||
@ -1388,6 +1444,12 @@ class RunnableMap(RunnableSerializable[Input, Dict[str, Any]]):
|
|||||||
**{k: (v.OutputType, None) for k, v in self.steps.items()},
|
**{k: (v.OutputType, None) for k, v in self.steps.items()},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
||||||
|
return get_unique_config_specs(
|
||||||
|
spec for step in self.steps.values() for spec in step.config_specs
|
||||||
|
)
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
map_for_repr = ",\n ".join(
|
map_for_repr = ",\n ".join(
|
||||||
f"{k}: {indent_lines_after_first(repr(v), ' ' + k + ': ')}"
|
f"{k}: {indent_lines_after_first(repr(v), ' ' + k + ': ')}"
|
||||||
@ -1985,6 +2047,10 @@ class RunnableEach(RunnableSerializable[List[Input], List[Output]]):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
||||||
|
return self.bound.config_specs
|
||||||
|
|
||||||
def config_schema(
|
def config_schema(
|
||||||
self, *, include: Optional[Sequence[str]] = None
|
self, *, include: Optional[Sequence[str]] = None
|
||||||
) -> Type[BaseModel]:
|
) -> Type[BaseModel]:
|
||||||
@ -2062,6 +2128,10 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
|||||||
def output_schema(self) -> Type[BaseModel]:
|
def output_schema(self) -> Type[BaseModel]:
|
||||||
return self.bound.output_schema
|
return self.bound.output_schema
|
||||||
|
|
||||||
|
@property
|
||||||
|
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
||||||
|
return self.bound.config_specs
|
||||||
|
|
||||||
def config_schema(
|
def config_schema(
|
||||||
self, *, include: Optional[Sequence[str]] = None
|
self, *, include: Optional[Sequence[str]] = None
|
||||||
) -> Type[BaseModel]:
|
) -> Type[BaseModel]:
|
||||||
|
@ -26,7 +26,12 @@ from langchain.schema.runnable.config import (
|
|||||||
get_callback_manager_for_config,
|
get_callback_manager_for_config,
|
||||||
patch_config,
|
patch_config,
|
||||||
)
|
)
|
||||||
from langchain.schema.runnable.utils import Input, Output
|
from langchain.schema.runnable.utils import (
|
||||||
|
ConfigurableFieldSpec,
|
||||||
|
Input,
|
||||||
|
Output,
|
||||||
|
get_unique_config_specs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class RunnableBranch(RunnableSerializable[Input, Output]):
|
class RunnableBranch(RunnableSerializable[Input, Output]):
|
||||||
@ -139,6 +144,18 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
|
|||||||
|
|
||||||
return super().input_schema
|
return super().input_schema
|
||||||
|
|
||||||
|
@property
|
||||||
|
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
||||||
|
return get_unique_config_specs(
|
||||||
|
spec
|
||||||
|
for step in (
|
||||||
|
[self.default]
|
||||||
|
+ [r for _, r in self.branches]
|
||||||
|
+ [r for r, _ in self.branches]
|
||||||
|
)
|
||||||
|
for spec in step.config_specs
|
||||||
|
)
|
||||||
|
|
||||||
def invoke(
|
def invoke(
|
||||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||||
) -> Output:
|
) -> Output:
|
||||||
|
@ -34,6 +34,10 @@ if TYPE_CHECKING:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class EmptyDict(TypedDict, total=False):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class RunnableConfig(TypedDict, total=False):
|
class RunnableConfig(TypedDict, total=False):
|
||||||
"""Configuration for a Runnable."""
|
"""Configuration for a Runnable."""
|
||||||
|
|
||||||
@ -78,6 +82,13 @@ class RunnableConfig(TypedDict, total=False):
|
|||||||
Maximum number of times a call can recurse. If not provided, defaults to 10.
|
Maximum number of times a call can recurse. If not provided, defaults to 10.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
configurable: Dict[str, Any]
|
||||||
|
"""
|
||||||
|
Runtime values for attributes previously made configurable by this Runnable,
|
||||||
|
or sub-Runnables, through .make_configurable(). Check .output_schema for
|
||||||
|
a description of the attributes that have been made configurable.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
|
def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
|
||||||
empty = RunnableConfig(
|
empty = RunnableConfig(
|
||||||
|
276
libs/langchain/langchain/schema/runnable/configurable.py
Normal file
276
libs/langchain/langchain/schema/runnable/configurable.py
Normal file
@ -0,0 +1,276 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from abc import abstractmethod
|
||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
AsyncIterator,
|
||||||
|
Dict,
|
||||||
|
Iterator,
|
||||||
|
List,
|
||||||
|
Literal,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Type,
|
||||||
|
Union,
|
||||||
|
cast,
|
||||||
|
)
|
||||||
|
|
||||||
|
from langchain.pydantic_v1 import BaseModel
|
||||||
|
from langchain.schema.runnable.base import Runnable, RunnableSerializable
|
||||||
|
from langchain.schema.runnable.config import (
|
||||||
|
RunnableConfig,
|
||||||
|
get_config_list,
|
||||||
|
get_executor_for_config,
|
||||||
|
)
|
||||||
|
from langchain.schema.runnable.utils import (
|
||||||
|
ConfigurableField,
|
||||||
|
ConfigurableFieldSpec,
|
||||||
|
Input,
|
||||||
|
Output,
|
||||||
|
gather_with_concurrency,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DynamicRunnable(RunnableSerializable[Input, Output]):
|
||||||
|
bound: RunnableSerializable[Input, Output]
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_lc_serializable(cls) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_lc_namespace(cls) -> List[str]:
|
||||||
|
return cls.__module__.split(".")[:-1]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def InputType(self) -> Type[Input]:
|
||||||
|
return self.bound.InputType
|
||||||
|
|
||||||
|
@property
|
||||||
|
def OutputType(self) -> Type[Output]:
|
||||||
|
return self.bound.OutputType
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_schema(self) -> Type[BaseModel]:
|
||||||
|
return self.bound.input_schema
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_schema(self) -> Type[BaseModel]:
|
||||||
|
return self.bound.output_schema
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _prepare(
|
||||||
|
self, config: Optional[RunnableConfig] = None
|
||||||
|
) -> Runnable[Input, Output]:
|
||||||
|
...
|
||||||
|
|
||||||
|
def invoke(
|
||||||
|
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||||
|
) -> Output:
|
||||||
|
return self._prepare(config).invoke(input, config, **kwargs)
|
||||||
|
|
||||||
|
async def ainvoke(
|
||||||
|
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||||
|
) -> Output:
|
||||||
|
return await self._prepare(config).ainvoke(input, config, **kwargs)
|
||||||
|
|
||||||
|
def batch(
|
||||||
|
self,
|
||||||
|
inputs: List[Input],
|
||||||
|
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||||
|
*,
|
||||||
|
return_exceptions: bool = False,
|
||||||
|
**kwargs: Optional[Any],
|
||||||
|
) -> List[Output]:
|
||||||
|
configs = get_config_list(config, len(inputs))
|
||||||
|
prepared = [self._prepare(c) for c in configs]
|
||||||
|
|
||||||
|
if all(p is self.bound for p in prepared):
|
||||||
|
return self.bound.batch(
|
||||||
|
inputs, config, return_exceptions=return_exceptions, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
if not inputs:
|
||||||
|
return []
|
||||||
|
|
||||||
|
configs = get_config_list(config, len(inputs))
|
||||||
|
|
||||||
|
def invoke(
|
||||||
|
bound: Runnable[Input, Output],
|
||||||
|
input: Input,
|
||||||
|
config: RunnableConfig,
|
||||||
|
) -> Union[Output, Exception]:
|
||||||
|
if return_exceptions:
|
||||||
|
try:
|
||||||
|
return bound.invoke(input, config, **kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
return e
|
||||||
|
else:
|
||||||
|
return bound.invoke(input, config, **kwargs)
|
||||||
|
|
||||||
|
# If there's only one input, don't bother with the executor
|
||||||
|
if len(inputs) == 1:
|
||||||
|
return cast(List[Output], [invoke(prepared[0], inputs[0], configs[0])])
|
||||||
|
|
||||||
|
with get_executor_for_config(configs[0]) as executor:
|
||||||
|
return cast(
|
||||||
|
List[Output], list(executor.map(invoke, prepared, inputs, configs))
|
||||||
|
)
|
||||||
|
|
||||||
|
async def abatch(
|
||||||
|
self,
|
||||||
|
inputs: List[Input],
|
||||||
|
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||||
|
*,
|
||||||
|
return_exceptions: bool = False,
|
||||||
|
**kwargs: Optional[Any],
|
||||||
|
) -> List[Output]:
|
||||||
|
configs = get_config_list(config, len(inputs))
|
||||||
|
prepared = [self._prepare(c) for c in configs]
|
||||||
|
|
||||||
|
if all(p is self.bound for p in prepared):
|
||||||
|
return await self.bound.abatch(
|
||||||
|
inputs, config, return_exceptions=return_exceptions, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
if not inputs:
|
||||||
|
return []
|
||||||
|
|
||||||
|
configs = get_config_list(config, len(inputs))
|
||||||
|
|
||||||
|
async def ainvoke(
|
||||||
|
bound: Runnable[Input, Output],
|
||||||
|
input: Input,
|
||||||
|
config: RunnableConfig,
|
||||||
|
) -> Union[Output, Exception]:
|
||||||
|
if return_exceptions:
|
||||||
|
try:
|
||||||
|
return await bound.ainvoke(input, config, **kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
return e
|
||||||
|
else:
|
||||||
|
return await bound.ainvoke(input, config, **kwargs)
|
||||||
|
|
||||||
|
coros = map(ainvoke, prepared, inputs, configs)
|
||||||
|
return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros)
|
||||||
|
|
||||||
|
def stream(
|
||||||
|
self,
|
||||||
|
input: Input,
|
||||||
|
config: Optional[RunnableConfig] = None,
|
||||||
|
**kwargs: Optional[Any],
|
||||||
|
) -> Iterator[Output]:
|
||||||
|
return self._prepare(config).stream(input, config, **kwargs)
|
||||||
|
|
||||||
|
async def astream(
|
||||||
|
self,
|
||||||
|
input: Input,
|
||||||
|
config: Optional[RunnableConfig] = None,
|
||||||
|
**kwargs: Optional[Any],
|
||||||
|
) -> AsyncIterator[Output]:
|
||||||
|
async for chunk in self._prepare(config).astream(input, config, **kwargs):
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
def transform(
|
||||||
|
self,
|
||||||
|
input: Iterator[Input],
|
||||||
|
config: Optional[RunnableConfig] = None,
|
||||||
|
**kwargs: Optional[Any],
|
||||||
|
) -> Iterator[Output]:
|
||||||
|
return self._prepare(config).transform(input, config, **kwargs)
|
||||||
|
|
||||||
|
async def atransform(
|
||||||
|
self,
|
||||||
|
input: AsyncIterator[Input],
|
||||||
|
config: Optional[RunnableConfig] = None,
|
||||||
|
**kwargs: Optional[Any],
|
||||||
|
) -> AsyncIterator[Output]:
|
||||||
|
async for chunk in self._prepare(config).atransform(input, config, **kwargs):
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
|
||||||
|
class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
|
||||||
|
fields: Dict[str, ConfigurableField]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
||||||
|
return [
|
||||||
|
ConfigurableFieldSpec(
|
||||||
|
id=spec.id,
|
||||||
|
name=spec.name,
|
||||||
|
description=spec.description
|
||||||
|
or self.bound.__fields__[field_name].field_info.description,
|
||||||
|
annotation=spec.annotation
|
||||||
|
or self.bound.__fields__[field_name].annotation,
|
||||||
|
default=getattr(self.bound, field_name),
|
||||||
|
)
|
||||||
|
for field_name, spec in self.fields.items()
|
||||||
|
]
|
||||||
|
|
||||||
|
def configurable_fields(
|
||||||
|
self, **kwargs: ConfigurableField
|
||||||
|
) -> RunnableSerializable[Input, Output]:
|
||||||
|
return self.bound.configurable_fields(**{**self.fields, **kwargs})
|
||||||
|
|
||||||
|
def _prepare(
|
||||||
|
self, config: Optional[RunnableConfig] = None
|
||||||
|
) -> Runnable[Input, Output]:
|
||||||
|
config = config or {}
|
||||||
|
specs_by_id = {spec.id: (key, spec) for key, spec in self.fields.items()}
|
||||||
|
configurable = {
|
||||||
|
specs_by_id[k][0]: v
|
||||||
|
for k, v in config.get("configurable", {}).items()
|
||||||
|
if k in specs_by_id
|
||||||
|
}
|
||||||
|
|
||||||
|
if configurable:
|
||||||
|
return self.bound.__class__(**{**self.bound.dict(), **configurable})
|
||||||
|
else:
|
||||||
|
return self.bound
|
||||||
|
|
||||||
|
|
||||||
|
class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
|
||||||
|
which: ConfigurableField
|
||||||
|
|
||||||
|
alternatives: Dict[str, RunnableSerializable[Input, Output]]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
||||||
|
alt_keys = self.alternatives.keys()
|
||||||
|
which_keys = tuple(Literal[k] for k in alt_keys) + ( # type: ignore
|
||||||
|
Literal["default"],
|
||||||
|
)
|
||||||
|
return [
|
||||||
|
ConfigurableFieldSpec(
|
||||||
|
id=self.which.id,
|
||||||
|
name=self.which.name,
|
||||||
|
description=self.which.description,
|
||||||
|
annotation=Union[which_keys], # type: ignore
|
||||||
|
default="default",
|
||||||
|
),
|
||||||
|
*self.bound.config_specs,
|
||||||
|
] + [s for alt in self.alternatives.values() for s in alt.config_specs]
|
||||||
|
|
||||||
|
def configurable_fields(
|
||||||
|
self, **kwargs: ConfigurableField
|
||||||
|
) -> RunnableSerializable[Input, Output]:
|
||||||
|
return self.__class__(
|
||||||
|
which=self.which,
|
||||||
|
bound=self.bound.configurable_fields(**kwargs),
|
||||||
|
alternatives=self.alternatives,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _prepare(
|
||||||
|
self, config: Optional[RunnableConfig] = None
|
||||||
|
) -> Runnable[Input, Output]:
|
||||||
|
config = config or {}
|
||||||
|
which = config.get("configurable", {}).get(self.which.id)
|
||||||
|
if not which:
|
||||||
|
return self.bound
|
||||||
|
elif which in self.alternatives:
|
||||||
|
return self.alternatives[which]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown alternative: {which}")
|
@ -22,7 +22,12 @@ from langchain.schema.runnable.config import (
|
|||||||
get_config_list,
|
get_config_list,
|
||||||
patch_config,
|
patch_config,
|
||||||
)
|
)
|
||||||
from langchain.schema.runnable.utils import Input, Output
|
from langchain.schema.runnable.utils import (
|
||||||
|
ConfigurableFieldSpec,
|
||||||
|
Input,
|
||||||
|
Output,
|
||||||
|
get_unique_config_specs,
|
||||||
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from langchain.callbacks.manager import AsyncCallbackManagerForChainRun
|
from langchain.callbacks.manager import AsyncCallbackManagerForChainRun
|
||||||
@ -56,6 +61,14 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
|||||||
def output_schema(self) -> Type[BaseModel]:
|
def output_schema(self) -> Type[BaseModel]:
|
||||||
return self.runnable.output_schema
|
return self.runnable.output_schema
|
||||||
|
|
||||||
|
@property
|
||||||
|
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
||||||
|
return get_unique_config_specs(
|
||||||
|
spec
|
||||||
|
for step in [self.runnable, *self.fallbacks]
|
||||||
|
for spec in step.config_specs
|
||||||
|
)
|
||||||
|
|
||||||
def config_schema(
|
def config_schema(
|
||||||
self, *, include: Optional[Sequence[str]] = None
|
self, *, include: Optional[Sequence[str]] = None
|
||||||
) -> Type[BaseModel]:
|
) -> Type[BaseModel]:
|
||||||
|
@ -11,6 +11,7 @@ from typing import (
|
|||||||
List,
|
List,
|
||||||
Mapping,
|
Mapping,
|
||||||
Optional,
|
Optional,
|
||||||
|
Sequence,
|
||||||
Type,
|
Type,
|
||||||
Union,
|
Union,
|
||||||
cast,
|
cast,
|
||||||
@ -24,7 +25,7 @@ from langchain.schema.runnable.base import (
|
|||||||
RunnableSerializable,
|
RunnableSerializable,
|
||||||
)
|
)
|
||||||
from langchain.schema.runnable.config import RunnableConfig, get_executor_for_config
|
from langchain.schema.runnable.config import RunnableConfig, get_executor_for_config
|
||||||
from langchain.schema.runnable.utils import AddableDict
|
from langchain.schema.runnable.utils import AddableDict, ConfigurableFieldSpec
|
||||||
from langchain.utils.aiter import atee, py_anext
|
from langchain.utils.aiter import atee, py_anext
|
||||||
from langchain.utils.iter import safetee
|
from langchain.utils.iter import safetee
|
||||||
|
|
||||||
@ -160,6 +161,10 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
|||||||
|
|
||||||
return super().output_schema
|
return super().output_schema
|
||||||
|
|
||||||
|
@property
|
||||||
|
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
||||||
|
return self.mapper.config_specs
|
||||||
|
|
||||||
def invoke(
|
def invoke(
|
||||||
self,
|
self,
|
||||||
input: Dict[str, Any],
|
input: Dict[str, Any],
|
||||||
|
@ -8,6 +8,7 @@ from typing import (
|
|||||||
List,
|
List,
|
||||||
Mapping,
|
Mapping,
|
||||||
Optional,
|
Optional,
|
||||||
|
Sequence,
|
||||||
Union,
|
Union,
|
||||||
cast,
|
cast,
|
||||||
)
|
)
|
||||||
@ -26,7 +27,11 @@ from langchain.schema.runnable.config import (
|
|||||||
get_config_list,
|
get_config_list,
|
||||||
get_executor_for_config,
|
get_executor_for_config,
|
||||||
)
|
)
|
||||||
from langchain.schema.runnable.utils import gather_with_concurrency
|
from langchain.schema.runnable.utils import (
|
||||||
|
ConfigurableFieldSpec,
|
||||||
|
gather_with_concurrency,
|
||||||
|
get_unique_config_specs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class RouterInput(TypedDict):
|
class RouterInput(TypedDict):
|
||||||
@ -49,6 +54,12 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]):
|
|||||||
|
|
||||||
runnables: Mapping[str, Runnable[Any, Output]]
|
runnables: Mapping[str, Runnable[Any, Output]]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
||||||
|
return get_unique_config_specs(
|
||||||
|
spec for step in self.runnables.values() for spec in step.config_specs
|
||||||
|
)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
runnables: Mapping[str, Union[Runnable[Any, Output], Callable[[Any], Output]]],
|
runnables: Mapping[str, Union[Runnable[Any, Output], Callable[[Any], Output]]],
|
||||||
|
@ -5,6 +5,7 @@ import asyncio
|
|||||||
import inspect
|
import inspect
|
||||||
import textwrap
|
import textwrap
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
|
from itertools import groupby
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
AsyncIterable,
|
AsyncIterable,
|
||||||
@ -13,8 +14,10 @@ from typing import (
|
|||||||
Dict,
|
Dict,
|
||||||
Iterable,
|
Iterable,
|
||||||
List,
|
List,
|
||||||
|
NamedTuple,
|
||||||
Optional,
|
Optional,
|
||||||
Protocol,
|
Protocol,
|
||||||
|
Sequence,
|
||||||
Set,
|
Set,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
@ -211,3 +214,39 @@ async def aadd(addables: AsyncIterable[Addable]) -> Optional[Addable]:
|
|||||||
else:
|
else:
|
||||||
final = final + chunk
|
final = final + chunk
|
||||||
return final
|
return final
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigurableField(NamedTuple):
|
||||||
|
id: str
|
||||||
|
name: Optional[str] = None
|
||||||
|
description: Optional[str] = None
|
||||||
|
annotation: Optional[Any] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigurableFieldSpec(NamedTuple):
|
||||||
|
id: str
|
||||||
|
name: Optional[str]
|
||||||
|
description: Optional[str]
|
||||||
|
|
||||||
|
default: Any
|
||||||
|
annotation: Any
|
||||||
|
|
||||||
|
|
||||||
|
def get_unique_config_specs(
|
||||||
|
specs: Iterable[ConfigurableFieldSpec],
|
||||||
|
) -> Sequence[ConfigurableFieldSpec]:
|
||||||
|
grouped = groupby(sorted(specs, key=lambda s: s.id), lambda s: s.id)
|
||||||
|
unique: List[ConfigurableFieldSpec] = []
|
||||||
|
for id, dupes in grouped:
|
||||||
|
first = next(dupes)
|
||||||
|
others = list(dupes)
|
||||||
|
if len(others) == 0:
|
||||||
|
unique.append(first)
|
||||||
|
elif all(o == first for o in others):
|
||||||
|
unique.append(first)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"RunnableSequence contains conflicting config specs"
|
||||||
|
f"for {id}: {[first] + others}"
|
||||||
|
)
|
||||||
|
return unique
|
||||||
|
@ -30,6 +30,7 @@ from langchain.llms.fake import FakeListLLM, FakeStreamingListLLM
|
|||||||
from langchain.load.dump import dumpd, dumps
|
from langchain.load.dump import dumpd, dumps
|
||||||
from langchain.output_parsers.list import CommaSeparatedListOutputParser
|
from langchain.output_parsers.list import CommaSeparatedListOutputParser
|
||||||
from langchain.prompts import PromptTemplate
|
from langchain.prompts import PromptTemplate
|
||||||
|
from langchain.prompts.base import StringPromptValue
|
||||||
from langchain.prompts.chat import (
|
from langchain.prompts.chat import (
|
||||||
ChatPromptTemplate,
|
ChatPromptTemplate,
|
||||||
ChatPromptValue,
|
ChatPromptValue,
|
||||||
@ -56,7 +57,7 @@ from langchain.schema.runnable import (
|
|||||||
RunnableSequence,
|
RunnableSequence,
|
||||||
RunnableWithFallbacks,
|
RunnableWithFallbacks,
|
||||||
)
|
)
|
||||||
from langchain.schema.runnable.base import RunnableGenerator
|
from langchain.schema.runnable.base import ConfigurableField, RunnableGenerator
|
||||||
from langchain.schema.runnable.utils import add
|
from langchain.schema.runnable.utils import add
|
||||||
from langchain.tools.base import BaseTool, tool
|
from langchain.tools.base import BaseTool, tool
|
||||||
from langchain.tools.json.tool import JsonListKeysTool, JsonSpec
|
from langchain.tools.json.tool import JsonListKeysTool, JsonSpec
|
||||||
@ -143,6 +144,15 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
|
|||||||
"title": "FakeRunnableOutput",
|
"title": "FakeRunnableOutput",
|
||||||
"type": "integer",
|
"type": "integer",
|
||||||
}
|
}
|
||||||
|
assert fake.config_schema(include=["tags", "metadata", "run_name"]).schema() == {
|
||||||
|
"title": "FakeRunnableConfig",
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"metadata": {"title": "Metadata", "type": "object"},
|
||||||
|
"run_name": {"title": "Run Name", "type": "string"},
|
||||||
|
"tags": {"items": {"type": "string"}, "title": "Tags", "type": "array"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
fake_bound = FakeRunnable().bind(a="b") # str -> int
|
fake_bound = FakeRunnable().bind(a="b") # str -> int
|
||||||
|
|
||||||
@ -538,6 +548,261 @@ def test_schema_chains() -> None:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_configurable_fields() -> None:
|
||||||
|
fake_llm = FakeListLLM(responses=["a"]) # str -> List[List[str]]
|
||||||
|
|
||||||
|
assert fake_llm.invoke("...") == "a"
|
||||||
|
|
||||||
|
fake_llm_configurable = fake_llm.configurable_fields(
|
||||||
|
responses=ConfigurableField(
|
||||||
|
id="llm_responses",
|
||||||
|
name="LLM Responses",
|
||||||
|
description="A list of fake responses for this LLM",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert fake_llm_configurable.invoke("...") == "a"
|
||||||
|
|
||||||
|
assert fake_llm_configurable.config_schema().schema() == {
|
||||||
|
"title": "RunnableConfigurableFieldsConfig",
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
|
||||||
|
"definitions": {
|
||||||
|
"Configurable": {
|
||||||
|
"title": "Configurable",
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"llm_responses": {
|
||||||
|
"title": "LLM Responses",
|
||||||
|
"description": "A list of fake responses for this LLM",
|
||||||
|
"default": ["a"],
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
fake_llm_configured = fake_llm_configurable.with_config(
|
||||||
|
configurable={"llm_responses": ["b"]}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert fake_llm_configured.invoke("...") == "b"
|
||||||
|
|
||||||
|
prompt = PromptTemplate.from_template("Hello, {name}!")
|
||||||
|
|
||||||
|
assert prompt.invoke({"name": "John"}) == StringPromptValue(text="Hello, John!")
|
||||||
|
|
||||||
|
prompt_configurable = prompt.configurable_fields(
|
||||||
|
template=ConfigurableField(
|
||||||
|
id="prompt_template",
|
||||||
|
name="Prompt Template",
|
||||||
|
description="The prompt template for this chain",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert prompt_configurable.invoke({"name": "John"}) == StringPromptValue(
|
||||||
|
text="Hello, John!"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert prompt_configurable.config_schema().schema() == {
|
||||||
|
"title": "RunnableConfigurableFieldsConfig",
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
|
||||||
|
"definitions": {
|
||||||
|
"Configurable": {
|
||||||
|
"title": "Configurable",
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"prompt_template": {
|
||||||
|
"title": "Prompt Template",
|
||||||
|
"description": "The prompt template for this chain",
|
||||||
|
"default": "Hello, {name}!",
|
||||||
|
"type": "string",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt_configured = prompt_configurable.with_config(
|
||||||
|
configurable={"prompt_template": "Hello, {name}! {name}!"}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert prompt_configured.invoke({"name": "John"}) == StringPromptValue(
|
||||||
|
text="Hello, John! John!"
|
||||||
|
)
|
||||||
|
|
||||||
|
chain_configurable = prompt_configurable | fake_llm_configurable | StrOutputParser()
|
||||||
|
|
||||||
|
assert chain_configurable.invoke({"name": "John"}) == "a"
|
||||||
|
|
||||||
|
assert chain_configurable.config_schema().schema() == {
|
||||||
|
"title": "RunnableSequenceConfig",
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
|
||||||
|
"definitions": {
|
||||||
|
"Configurable": {
|
||||||
|
"title": "Configurable",
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"llm_responses": {
|
||||||
|
"title": "LLM Responses",
|
||||||
|
"description": "A list of fake responses for this LLM",
|
||||||
|
"default": ["a"],
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
},
|
||||||
|
"prompt_template": {
|
||||||
|
"title": "Prompt Template",
|
||||||
|
"description": "The prompt template for this chain",
|
||||||
|
"default": "Hello, {name}!",
|
||||||
|
"type": "string",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
assert (
|
||||||
|
chain_configurable.with_config(
|
||||||
|
configurable={
|
||||||
|
"prompt_template": "A very good morning to you, {name}!",
|
||||||
|
"llm_responses": ["c"],
|
||||||
|
}
|
||||||
|
).invoke({"name": "John"})
|
||||||
|
== "c"
|
||||||
|
)
|
||||||
|
|
||||||
|
chain_with_map_configurable: Runnable = prompt_configurable | {
|
||||||
|
"llm1": fake_llm_configurable | StrOutputParser(),
|
||||||
|
"llm2": fake_llm_configurable | StrOutputParser(),
|
||||||
|
"llm3": fake_llm.configurable_fields(
|
||||||
|
responses=ConfigurableField("other_responses")
|
||||||
|
)
|
||||||
|
| StrOutputParser(),
|
||||||
|
}
|
||||||
|
|
||||||
|
assert chain_with_map_configurable.invoke({"name": "John"}) == {
|
||||||
|
"llm1": "a",
|
||||||
|
"llm2": "a",
|
||||||
|
"llm3": "a",
|
||||||
|
}
|
||||||
|
|
||||||
|
assert chain_with_map_configurable.config_schema().schema() == {
|
||||||
|
"title": "RunnableSequenceConfig",
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
|
||||||
|
"definitions": {
|
||||||
|
"Configurable": {
|
||||||
|
"title": "Configurable",
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"llm_responses": {
|
||||||
|
"title": "LLM Responses",
|
||||||
|
"description": "A list of fake responses for this LLM",
|
||||||
|
"default": ["a"],
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
},
|
||||||
|
"other_responses": {
|
||||||
|
"title": "Other Responses",
|
||||||
|
"default": ["a"],
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
},
|
||||||
|
"prompt_template": {
|
||||||
|
"title": "Prompt Template",
|
||||||
|
"description": "The prompt template for this chain",
|
||||||
|
"default": "Hello, {name}!",
|
||||||
|
"type": "string",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
assert chain_with_map_configurable.with_config(
|
||||||
|
configurable={
|
||||||
|
"prompt_template": "A very good morning to you, {name}!",
|
||||||
|
"llm_responses": ["c"],
|
||||||
|
"other_responses": ["d"],
|
||||||
|
}
|
||||||
|
).invoke({"name": "John"}) == {"llm1": "c", "llm2": "c", "llm3": "d"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_configurable_fields_example() -> None:
|
||||||
|
fake_llm = (
|
||||||
|
FakeListLLM(responses=["a"])
|
||||||
|
.configurable_fields(
|
||||||
|
responses=ConfigurableField(
|
||||||
|
id="llm_responses",
|
||||||
|
name="LLM Responses",
|
||||||
|
description="A list of fake responses for this LLM",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.configurable_alternatives(
|
||||||
|
ConfigurableField(id="llm", name="LLM"),
|
||||||
|
chat=FakeListChatModel(responses=["b"]) | StrOutputParser(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt = PromptTemplate.from_template("Hello, {name}!").configurable_fields(
|
||||||
|
template=ConfigurableField(
|
||||||
|
id="prompt_template",
|
||||||
|
name="Prompt Template",
|
||||||
|
description="The prompt template for this chain",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
chain_configurable = prompt | fake_llm
|
||||||
|
|
||||||
|
assert chain_configurable.invoke({"name": "John"}) == "a"
|
||||||
|
|
||||||
|
assert chain_configurable.config_schema().schema() == {
|
||||||
|
"title": "RunnableSequenceConfig",
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
|
||||||
|
"definitions": {
|
||||||
|
"Configurable": {
|
||||||
|
"title": "Configurable",
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"llm": {
|
||||||
|
"title": "LLM",
|
||||||
|
"default": "default",
|
||||||
|
"anyOf": [
|
||||||
|
{"enum": ["chat"], "type": "string"},
|
||||||
|
{"enum": ["default"], "type": "string"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"llm_responses": {
|
||||||
|
"title": "LLM Responses",
|
||||||
|
"description": "A list of fake responses for this LLM",
|
||||||
|
"default": ["a"],
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
},
|
||||||
|
"prompt_template": {
|
||||||
|
"title": "Prompt Template",
|
||||||
|
"description": "The prompt template for this chain",
|
||||||
|
"default": "Hello, {name}!",
|
||||||
|
"type": "string",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
assert (
|
||||||
|
chain_configurable.with_config(configurable={"llm": "chat"}).invoke(
|
||||||
|
{"name": "John"}
|
||||||
|
)
|
||||||
|
== "b"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_with_config(mocker: MockerFixture) -> None:
|
async def test_with_config(mocker: MockerFixture) -> None:
|
||||||
fake = FakeRunnable()
|
fake = FakeRunnable()
|
||||||
|
Loading…
Reference in New Issue
Block a user