Add validation for configurable keys passed to .with_config()

- Fix some typing issues found while doing that
This commit is contained in:
Nuno Campos 2023-10-17 08:50:31 +01:00
parent 754aca794f
commit 12596b9a9b
6 changed files with 57 additions and 35 deletions

View File

@ -4,7 +4,7 @@ from typing import Any, Callable, List, Mapping, Optional, Union
from typing_extensions import TypedDict from typing_extensions import TypedDict
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain.schema.output import ChatGeneration from langchain.schema.messages import BaseMessage
from langchain.schema.runnable import RouterRunnable, Runnable, RunnableBinding from langchain.schema.runnable import RouterRunnable, Runnable, RunnableBinding
@ -19,7 +19,7 @@ class OpenAIFunction(TypedDict):
"""The parameters to the function.""" """The parameters to the function."""
class OpenAIFunctionsRouter(RunnableBinding[ChatGeneration, Any]): class OpenAIFunctionsRouter(RunnableBinding[BaseMessage, Any]):
"""A runnable that routes to the selected function.""" """A runnable that routes to the selected function."""
functions: Optional[List[OpenAIFunction]] functions: Optional[List[OpenAIFunction]]

View File

@ -11,7 +11,7 @@ from typing import (
Union, Union,
) )
from langchain.schema.runnable.base import Input, Output, RunnableSerializable from langchain.schema.runnable.base import Input, Other, Output, RunnableSerializable
from langchain.schema.runnable.config import RunnableConfig from langchain.schema.runnable.config import RunnableConfig
from langchain.schema.runnable.passthrough import RunnablePassthrough from langchain.schema.runnable.passthrough import RunnablePassthrough
@ -36,7 +36,7 @@ class PutLocalVar(RunnablePassthrough):
def _concat_put( def _concat_put(
self, self,
input: Input, input: Other,
*, *,
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
replace: bool = False, replace: bool = False,
@ -68,35 +68,35 @@ class PutLocalVar(RunnablePassthrough):
f"{(type(self.key))}." f"{(type(self.key))}."
) )
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input: def invoke(self, input: Other, config: Optional[RunnableConfig] = None) -> Other:
self._concat_put(input, config=config, replace=True) self._concat_put(input, config=config, replace=True)
return super().invoke(input, config=config) return super().invoke(input, config=config)
async def ainvoke( async def ainvoke(
self, self,
input: Input, input: Other,
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> Input: ) -> Other:
self._concat_put(input, config=config, replace=True) self._concat_put(input, config=config, replace=True)
return await super().ainvoke(input, config=config) return await super().ainvoke(input, config=config)
def transform( def transform(
self, self,
input: Iterator[Input], input: Iterator[Other],
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> Iterator[Input]: ) -> Iterator[Other]:
for chunk in super().transform(input, config=config): for chunk in super().transform(input, config=config):
self._concat_put(chunk, config=config) self._concat_put(chunk, config=config)
yield chunk yield chunk
async def atransform( async def atransform(
self, self,
input: AsyncIterator[Input], input: AsyncIterator[Other],
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> AsyncIterator[Input]: ) -> AsyncIterator[Other]:
async for chunk in super().atransform(input, config=config): async for chunk in super().atransform(input, config=config):
self._concat_put(chunk, config=config) self._concat_put(chunk, config=config)
yield chunk yield chunk

View File

@ -2296,6 +2296,25 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
class Config: class Config:
arbitrary_types_allowed = True arbitrary_types_allowed = True
def __init__(
self,
*,
bound: Runnable[Input, Output],
kwargs: Mapping[str, Any],
config: Optional[Mapping[str, Any]] = None,
**other_kwargs: Any,
) -> None:
config = config or {}
if configurable := config.get("configurable", None):
allowed_keys = set(s.id for s in bound.config_specs)
for key in configurable:
if key not in allowed_keys:
raise ValueError(
f"Configurable key '{key}' not found in runnable with"
f" config keys: {allowed_keys}"
)
super().__init__(bound=bound, kwargs=kwargs, config=config, **other_kwargs)
@property @property
def InputType(self) -> Type[Input]: def InputType(self) -> Type[Input]:
return self.bound.InputType return self.bound.InputType

View File

@ -22,7 +22,7 @@ from typing import (
from langchain.pydantic_v1 import BaseModel, create_model from langchain.pydantic_v1 import BaseModel, create_model
from langchain.schema.runnable.base import ( from langchain.schema.runnable.base import (
Input, Other,
Runnable, Runnable,
RunnableParallel, RunnableParallel,
RunnableSerializable, RunnableSerializable,
@ -33,17 +33,17 @@ from langchain.utils.aiter import atee, py_anext
from langchain.utils.iter import safetee from langchain.utils.iter import safetee
def identity(x: Input) -> Input: def identity(x: Other) -> Other:
"""An identity function""" """An identity function"""
return x return x
async def aidentity(x: Input) -> Input: async def aidentity(x: Other) -> Other:
"""An async identity function""" """An async identity function"""
return x return x
class RunnablePassthrough(RunnableSerializable[Input, Input]): class RunnablePassthrough(RunnableSerializable[Other, Other]):
"""A runnable to passthrough inputs unchanged or with additional keys. """A runnable to passthrough inputs unchanged or with additional keys.
This runnable behaves almost like the identity function, except that it This runnable behaves almost like the identity function, except that it
@ -100,20 +100,20 @@ class RunnablePassthrough(RunnableSerializable[Input, Input]):
# {'llm1': 'completion', 'llm2': 'completion', 'total_chars': 20} # {'llm1': 'completion', 'llm2': 'completion', 'total_chars': 20}
""" """
input_type: Optional[Type[Input]] = None input_type: Optional[Type[Other]] = None
func: Optional[Callable[[Input], None]] = None func: Optional[Callable[[Other], None]] = None
afunc: Optional[Callable[[Input], Awaitable[None]]] = None afunc: Optional[Callable[[Other], Awaitable[None]]] = None
def __init__( def __init__(
self, self,
func: Optional[ func: Optional[
Union[Callable[[Input], None], Callable[[Input], Awaitable[None]]] Union[Callable[[Other], None], Callable[[Other], Awaitable[None]]]
] = None, ] = None,
afunc: Optional[Callable[[Input], Awaitable[None]]] = None, afunc: Optional[Callable[[Other], Awaitable[None]]] = None,
*, *,
input_type: Optional[Type[Input]] = None, input_type: Optional[Type[Other]] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
if inspect.iscoroutinefunction(func): if inspect.iscoroutinefunction(func):
@ -161,17 +161,17 @@ class RunnablePassthrough(RunnableSerializable[Input, Input]):
""" """
return RunnableAssign(RunnableParallel(kwargs)) return RunnableAssign(RunnableParallel(kwargs))
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input: def invoke(self, input: Other, config: Optional[RunnableConfig] = None) -> Other:
if self.func is not None: if self.func is not None:
self.func(input) self.func(input)
return self._call_with_config(identity, input, config) return self._call_with_config(identity, input, config)
async def ainvoke( async def ainvoke(
self, self,
input: Input, input: Other,
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> Input: ) -> Other:
if self.afunc is not None: if self.afunc is not None:
await self.afunc(input, **kwargs) await self.afunc(input, **kwargs)
elif self.func is not None: elif self.func is not None:
@ -180,10 +180,10 @@ class RunnablePassthrough(RunnableSerializable[Input, Input]):
def transform( def transform(
self, self,
input: Iterator[Input], input: Iterator[Other],
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Any, **kwargs: Any,
) -> Iterator[Input]: ) -> Iterator[Other]:
if self.func is None: if self.func is None:
for chunk in self._transform_stream_with_config(input, identity, config): for chunk in self._transform_stream_with_config(input, identity, config):
yield chunk yield chunk
@ -202,10 +202,10 @@ class RunnablePassthrough(RunnableSerializable[Input, Input]):
async def atransform( async def atransform(
self, self,
input: AsyncIterator[Input], input: AsyncIterator[Other],
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Any, **kwargs: Any,
) -> AsyncIterator[Input]: ) -> AsyncIterator[Other]:
if self.afunc is None and self.func is None: if self.afunc is None and self.func is None:
async for chunk in self._atransform_stream_with_config( async for chunk in self._atransform_stream_with_config(
input, identity, config input, identity, config
@ -231,19 +231,19 @@ class RunnablePassthrough(RunnableSerializable[Input, Input]):
def stream( def stream(
self, self,
input: Input, input: Other,
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Any, **kwargs: Any,
) -> Iterator[Input]: ) -> Iterator[Other]:
return self.transform(iter([input]), config, **kwargs) return self.transform(iter([input]), config, **kwargs)
async def astream( async def astream(
self, self,
input: Input, input: Other,
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Any, **kwargs: Any,
) -> AsyncIterator[Input]: ) -> AsyncIterator[Other]:
async def input_aiter() -> AsyncIterator[Input]: async def input_aiter() -> AsyncIterator[Other]:
yield input yield input
async for chunk in self.atransform(input_aiter(), config, **kwargs): async for chunk in self.atransform(input_aiter(), config, **kwargs):

View File

@ -24,9 +24,9 @@ from typing import (
Union, Union,
) )
Input = TypeVar("Input") Input = TypeVar("Input", contravariant=True)
# Output type should implement __concat__, as eg str, list, dict do # Output type should implement __concat__, as eg str, list, dict do
Output = TypeVar("Output") Output = TypeVar("Output", covariant=True)
async def gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any: async def gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any:

View File

@ -1002,6 +1002,9 @@ def test_configurable_fields_example() -> None:
}, },
} }
with pytest.raises(ValueError):
chain_configurable.with_config(configurable={"llm123": "chat"})
assert ( assert (
chain_configurable.with_config(configurable={"llm": "chat"}).invoke( chain_configurable.with_config(configurable={"llm": "chat"}).invoke(
{"name": "John"} {"name": "John"}