mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-24 15:43:54 +00:00
Add validation for configurable keys passed to .with_config()
- Fix some typing issues found while doing that
This commit is contained in:
parent
754aca794f
commit
12596b9a9b
@ -4,7 +4,7 @@ from typing import Any, Callable, List, Mapping, Optional, Union
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
|
||||
from langchain.schema.output import ChatGeneration
|
||||
from langchain.schema.messages import BaseMessage
|
||||
from langchain.schema.runnable import RouterRunnable, Runnable, RunnableBinding
|
||||
|
||||
|
||||
@ -19,7 +19,7 @@ class OpenAIFunction(TypedDict):
|
||||
"""The parameters to the function."""
|
||||
|
||||
|
||||
class OpenAIFunctionsRouter(RunnableBinding[ChatGeneration, Any]):
|
||||
class OpenAIFunctionsRouter(RunnableBinding[BaseMessage, Any]):
|
||||
"""A runnable that routes to the selected function."""
|
||||
|
||||
functions: Optional[List[OpenAIFunction]]
|
||||
|
@ -11,7 +11,7 @@ from typing import (
|
||||
Union,
|
||||
)
|
||||
|
||||
from langchain.schema.runnable.base import Input, Output, RunnableSerializable
|
||||
from langchain.schema.runnable.base import Input, Other, Output, RunnableSerializable
|
||||
from langchain.schema.runnable.config import RunnableConfig
|
||||
from langchain.schema.runnable.passthrough import RunnablePassthrough
|
||||
|
||||
@ -36,7 +36,7 @@ class PutLocalVar(RunnablePassthrough):
|
||||
|
||||
def _concat_put(
|
||||
self,
|
||||
input: Input,
|
||||
input: Other,
|
||||
*,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
replace: bool = False,
|
||||
@ -68,35 +68,35 @@ class PutLocalVar(RunnablePassthrough):
|
||||
f"{(type(self.key))}."
|
||||
)
|
||||
|
||||
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input:
|
||||
def invoke(self, input: Other, config: Optional[RunnableConfig] = None) -> Other:
|
||||
self._concat_put(input, config=config, replace=True)
|
||||
return super().invoke(input, config=config)
|
||||
|
||||
async def ainvoke(
|
||||
self,
|
||||
input: Input,
|
||||
input: Other,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> Input:
|
||||
) -> Other:
|
||||
self._concat_put(input, config=config, replace=True)
|
||||
return await super().ainvoke(input, config=config)
|
||||
|
||||
def transform(
|
||||
self,
|
||||
input: Iterator[Input],
|
||||
input: Iterator[Other],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> Iterator[Input]:
|
||||
) -> Iterator[Other]:
|
||||
for chunk in super().transform(input, config=config):
|
||||
self._concat_put(chunk, config=config)
|
||||
yield chunk
|
||||
|
||||
async def atransform(
|
||||
self,
|
||||
input: AsyncIterator[Input],
|
||||
input: AsyncIterator[Other],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> AsyncIterator[Input]:
|
||||
) -> AsyncIterator[Other]:
|
||||
async for chunk in super().atransform(input, config=config):
|
||||
self._concat_put(chunk, config=config)
|
||||
yield chunk
|
||||
|
@ -2296,6 +2296,25 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
bound: Runnable[Input, Output],
|
||||
kwargs: Mapping[str, Any],
|
||||
config: Optional[Mapping[str, Any]] = None,
|
||||
**other_kwargs: Any,
|
||||
) -> None:
|
||||
config = config or {}
|
||||
if configurable := config.get("configurable", None):
|
||||
allowed_keys = set(s.id for s in bound.config_specs)
|
||||
for key in configurable:
|
||||
if key not in allowed_keys:
|
||||
raise ValueError(
|
||||
f"Configurable key '{key}' not found in runnable with"
|
||||
f" config keys: {allowed_keys}"
|
||||
)
|
||||
super().__init__(bound=bound, kwargs=kwargs, config=config, **other_kwargs)
|
||||
|
||||
@property
|
||||
def InputType(self) -> Type[Input]:
|
||||
return self.bound.InputType
|
||||
|
@ -22,7 +22,7 @@ from typing import (
|
||||
|
||||
from langchain.pydantic_v1 import BaseModel, create_model
|
||||
from langchain.schema.runnable.base import (
|
||||
Input,
|
||||
Other,
|
||||
Runnable,
|
||||
RunnableParallel,
|
||||
RunnableSerializable,
|
||||
@ -33,17 +33,17 @@ from langchain.utils.aiter import atee, py_anext
|
||||
from langchain.utils.iter import safetee
|
||||
|
||||
|
||||
def identity(x: Input) -> Input:
|
||||
def identity(x: Other) -> Other:
|
||||
"""An identity function"""
|
||||
return x
|
||||
|
||||
|
||||
async def aidentity(x: Input) -> Input:
|
||||
async def aidentity(x: Other) -> Other:
|
||||
"""An async identity function"""
|
||||
return x
|
||||
|
||||
|
||||
class RunnablePassthrough(RunnableSerializable[Input, Input]):
|
||||
class RunnablePassthrough(RunnableSerializable[Other, Other]):
|
||||
"""A runnable to passthrough inputs unchanged or with additional keys.
|
||||
|
||||
This runnable behaves almost like the identity function, except that it
|
||||
@ -100,20 +100,20 @@ class RunnablePassthrough(RunnableSerializable[Input, Input]):
|
||||
# {'llm1': 'completion', 'llm2': 'completion', 'total_chars': 20}
|
||||
"""
|
||||
|
||||
input_type: Optional[Type[Input]] = None
|
||||
input_type: Optional[Type[Other]] = None
|
||||
|
||||
func: Optional[Callable[[Input], None]] = None
|
||||
func: Optional[Callable[[Other], None]] = None
|
||||
|
||||
afunc: Optional[Callable[[Input], Awaitable[None]]] = None
|
||||
afunc: Optional[Callable[[Other], Awaitable[None]]] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
func: Optional[
|
||||
Union[Callable[[Input], None], Callable[[Input], Awaitable[None]]]
|
||||
Union[Callable[[Other], None], Callable[[Other], Awaitable[None]]]
|
||||
] = None,
|
||||
afunc: Optional[Callable[[Input], Awaitable[None]]] = None,
|
||||
afunc: Optional[Callable[[Other], Awaitable[None]]] = None,
|
||||
*,
|
||||
input_type: Optional[Type[Input]] = None,
|
||||
input_type: Optional[Type[Other]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
if inspect.iscoroutinefunction(func):
|
||||
@ -161,17 +161,17 @@ class RunnablePassthrough(RunnableSerializable[Input, Input]):
|
||||
"""
|
||||
return RunnableAssign(RunnableParallel(kwargs))
|
||||
|
||||
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input:
|
||||
def invoke(self, input: Other, config: Optional[RunnableConfig] = None) -> Other:
|
||||
if self.func is not None:
|
||||
self.func(input)
|
||||
return self._call_with_config(identity, input, config)
|
||||
|
||||
async def ainvoke(
|
||||
self,
|
||||
input: Input,
|
||||
input: Other,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> Input:
|
||||
) -> Other:
|
||||
if self.afunc is not None:
|
||||
await self.afunc(input, **kwargs)
|
||||
elif self.func is not None:
|
||||
@ -180,10 +180,10 @@ class RunnablePassthrough(RunnableSerializable[Input, Input]):
|
||||
|
||||
def transform(
|
||||
self,
|
||||
input: Iterator[Input],
|
||||
input: Iterator[Other],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[Input]:
|
||||
) -> Iterator[Other]:
|
||||
if self.func is None:
|
||||
for chunk in self._transform_stream_with_config(input, identity, config):
|
||||
yield chunk
|
||||
@ -202,10 +202,10 @@ class RunnablePassthrough(RunnableSerializable[Input, Input]):
|
||||
|
||||
async def atransform(
|
||||
self,
|
||||
input: AsyncIterator[Input],
|
||||
input: AsyncIterator[Other],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[Input]:
|
||||
) -> AsyncIterator[Other]:
|
||||
if self.afunc is None and self.func is None:
|
||||
async for chunk in self._atransform_stream_with_config(
|
||||
input, identity, config
|
||||
@ -231,19 +231,19 @@ class RunnablePassthrough(RunnableSerializable[Input, Input]):
|
||||
|
||||
def stream(
|
||||
self,
|
||||
input: Input,
|
||||
input: Other,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[Input]:
|
||||
) -> Iterator[Other]:
|
||||
return self.transform(iter([input]), config, **kwargs)
|
||||
|
||||
async def astream(
|
||||
self,
|
||||
input: Input,
|
||||
input: Other,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[Input]:
|
||||
async def input_aiter() -> AsyncIterator[Input]:
|
||||
) -> AsyncIterator[Other]:
|
||||
async def input_aiter() -> AsyncIterator[Other]:
|
||||
yield input
|
||||
|
||||
async for chunk in self.atransform(input_aiter(), config, **kwargs):
|
||||
|
@ -24,9 +24,9 @@ from typing import (
|
||||
Union,
|
||||
)
|
||||
|
||||
Input = TypeVar("Input")
|
||||
Input = TypeVar("Input", contravariant=True)
|
||||
# Output type should implement __concat__, as eg str, list, dict do
|
||||
Output = TypeVar("Output")
|
||||
Output = TypeVar("Output", covariant=True)
|
||||
|
||||
|
||||
async def gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any:
|
||||
|
@ -1002,6 +1002,9 @@ def test_configurable_fields_example() -> None:
|
||||
},
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
chain_configurable.with_config(configurable={"llm123": "chat"})
|
||||
|
||||
assert (
|
||||
chain_configurable.with_config(configurable={"llm": "chat"}).invoke(
|
||||
{"name": "John"}
|
||||
|
Loading…
Reference in New Issue
Block a user