Add validation for configurable keys passed to .with_config()

- Fix some typing issues found while doing that
This commit is contained in:
Nuno Campos 2023-10-17 08:50:31 +01:00
parent 754aca794f
commit 12596b9a9b
6 changed files with 57 additions and 35 deletions

View File

@ -4,7 +4,7 @@ from typing import Any, Callable, List, Mapping, Optional, Union
from typing_extensions import TypedDict
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain.schema.output import ChatGeneration
from langchain.schema.messages import BaseMessage
from langchain.schema.runnable import RouterRunnable, Runnable, RunnableBinding
@ -19,7 +19,7 @@ class OpenAIFunction(TypedDict):
"""The parameters to the function."""
class OpenAIFunctionsRouter(RunnableBinding[ChatGeneration, Any]):
class OpenAIFunctionsRouter(RunnableBinding[BaseMessage, Any]):
"""A runnable that routes to the selected function."""
functions: Optional[List[OpenAIFunction]]

View File

@ -11,7 +11,7 @@ from typing import (
Union,
)
from langchain.schema.runnable.base import Input, Output, RunnableSerializable
from langchain.schema.runnable.base import Input, Other, Output, RunnableSerializable
from langchain.schema.runnable.config import RunnableConfig
from langchain.schema.runnable.passthrough import RunnablePassthrough
@ -36,7 +36,7 @@ class PutLocalVar(RunnablePassthrough):
def _concat_put(
self,
input: Input,
input: Other,
*,
config: Optional[RunnableConfig] = None,
replace: bool = False,
@ -68,35 +68,35 @@ class PutLocalVar(RunnablePassthrough):
f"{(type(self.key))}."
)
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input:
def invoke(self, input: Other, config: Optional[RunnableConfig] = None) -> Other:
self._concat_put(input, config=config, replace=True)
return super().invoke(input, config=config)
async def ainvoke(
self,
input: Input,
input: Other,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Input:
) -> Other:
self._concat_put(input, config=config, replace=True)
return await super().ainvoke(input, config=config)
def transform(
self,
input: Iterator[Input],
input: Iterator[Other],
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Iterator[Input]:
) -> Iterator[Other]:
for chunk in super().transform(input, config=config):
self._concat_put(chunk, config=config)
yield chunk
async def atransform(
self,
input: AsyncIterator[Input],
input: AsyncIterator[Other],
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> AsyncIterator[Input]:
) -> AsyncIterator[Other]:
async for chunk in super().atransform(input, config=config):
self._concat_put(chunk, config=config)
yield chunk

View File

@ -2296,6 +2296,25 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
class Config:
arbitrary_types_allowed = True
def __init__(
self,
*,
bound: Runnable[Input, Output],
kwargs: Mapping[str, Any],
config: Optional[Mapping[str, Any]] = None,
**other_kwargs: Any,
) -> None:
config = config or {}
if configurable := config.get("configurable", None):
allowed_keys = set(s.id for s in bound.config_specs)
for key in configurable:
if key not in allowed_keys:
raise ValueError(
f"Configurable key '{key}' not found in runnable with"
f" config keys: {allowed_keys}"
)
super().__init__(bound=bound, kwargs=kwargs, config=config, **other_kwargs)
@property
def InputType(self) -> Type[Input]:
return self.bound.InputType

View File

@ -22,7 +22,7 @@ from typing import (
from langchain.pydantic_v1 import BaseModel, create_model
from langchain.schema.runnable.base import (
Input,
Other,
Runnable,
RunnableParallel,
RunnableSerializable,
@ -33,17 +33,17 @@ from langchain.utils.aiter import atee, py_anext
from langchain.utils.iter import safetee
def identity(x: Input) -> Input:
def identity(x: Other) -> Other:
"""An identity function"""
return x
async def aidentity(x: Input) -> Input:
async def aidentity(x: Other) -> Other:
"""An async identity function"""
return x
class RunnablePassthrough(RunnableSerializable[Input, Input]):
class RunnablePassthrough(RunnableSerializable[Other, Other]):
"""A runnable to passthrough inputs unchanged or with additional keys.
This runnable behaves almost like the identity function, except that it
@ -100,20 +100,20 @@ class RunnablePassthrough(RunnableSerializable[Input, Input]):
# {'llm1': 'completion', 'llm2': 'completion', 'total_chars': 20}
"""
input_type: Optional[Type[Input]] = None
input_type: Optional[Type[Other]] = None
func: Optional[Callable[[Input], None]] = None
func: Optional[Callable[[Other], None]] = None
afunc: Optional[Callable[[Input], Awaitable[None]]] = None
afunc: Optional[Callable[[Other], Awaitable[None]]] = None
def __init__(
self,
func: Optional[
Union[Callable[[Input], None], Callable[[Input], Awaitable[None]]]
Union[Callable[[Other], None], Callable[[Other], Awaitable[None]]]
] = None,
afunc: Optional[Callable[[Input], Awaitable[None]]] = None,
afunc: Optional[Callable[[Other], Awaitable[None]]] = None,
*,
input_type: Optional[Type[Input]] = None,
input_type: Optional[Type[Other]] = None,
**kwargs: Any,
) -> None:
if inspect.iscoroutinefunction(func):
@ -161,17 +161,17 @@ class RunnablePassthrough(RunnableSerializable[Input, Input]):
"""
return RunnableAssign(RunnableParallel(kwargs))
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input:
def invoke(self, input: Other, config: Optional[RunnableConfig] = None) -> Other:
if self.func is not None:
self.func(input)
return self._call_with_config(identity, input, config)
async def ainvoke(
self,
input: Input,
input: Other,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Input:
) -> Other:
if self.afunc is not None:
await self.afunc(input, **kwargs)
elif self.func is not None:
@ -180,10 +180,10 @@ class RunnablePassthrough(RunnableSerializable[Input, Input]):
def transform(
self,
input: Iterator[Input],
input: Iterator[Other],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Iterator[Input]:
) -> Iterator[Other]:
if self.func is None:
for chunk in self._transform_stream_with_config(input, identity, config):
yield chunk
@ -202,10 +202,10 @@ class RunnablePassthrough(RunnableSerializable[Input, Input]):
async def atransform(
self,
input: AsyncIterator[Input],
input: AsyncIterator[Other],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> AsyncIterator[Input]:
) -> AsyncIterator[Other]:
if self.afunc is None and self.func is None:
async for chunk in self._atransform_stream_with_config(
input, identity, config
@ -231,19 +231,19 @@ class RunnablePassthrough(RunnableSerializable[Input, Input]):
def stream(
self,
input: Input,
input: Other,
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Iterator[Input]:
) -> Iterator[Other]:
return self.transform(iter([input]), config, **kwargs)
async def astream(
self,
input: Input,
input: Other,
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> AsyncIterator[Input]:
async def input_aiter() -> AsyncIterator[Input]:
) -> AsyncIterator[Other]:
async def input_aiter() -> AsyncIterator[Other]:
yield input
async for chunk in self.atransform(input_aiter(), config, **kwargs):

View File

@ -24,9 +24,9 @@ from typing import (
Union,
)
Input = TypeVar("Input")
Input = TypeVar("Input", contravariant=True)
# Output type should implement __concat__, as eg str, list, dict do
Output = TypeVar("Output")
Output = TypeVar("Output", covariant=True)
async def gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any:

View File

@ -1002,6 +1002,9 @@ def test_configurable_fields_example() -> None:
},
}
with pytest.raises(ValueError):
chain_configurable.with_config(configurable={"llm123": "chat"})
assert (
chain_configurable.with_config(configurable={"llm": "chat"}).invoke(
{"name": "John"}