mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-24 23:54:14 +00:00
Add validation for configurable keys passed to .with_config()
- Fix some typing issues found while doing that
This commit is contained in:
parent
754aca794f
commit
12596b9a9b
@ -4,7 +4,7 @@ from typing import Any, Callable, List, Mapping, Optional, Union
|
|||||||
from typing_extensions import TypedDict
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
|
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
|
||||||
from langchain.schema.output import ChatGeneration
|
from langchain.schema.messages import BaseMessage
|
||||||
from langchain.schema.runnable import RouterRunnable, Runnable, RunnableBinding
|
from langchain.schema.runnable import RouterRunnable, Runnable, RunnableBinding
|
||||||
|
|
||||||
|
|
||||||
@ -19,7 +19,7 @@ class OpenAIFunction(TypedDict):
|
|||||||
"""The parameters to the function."""
|
"""The parameters to the function."""
|
||||||
|
|
||||||
|
|
||||||
class OpenAIFunctionsRouter(RunnableBinding[ChatGeneration, Any]):
|
class OpenAIFunctionsRouter(RunnableBinding[BaseMessage, Any]):
|
||||||
"""A runnable that routes to the selected function."""
|
"""A runnable that routes to the selected function."""
|
||||||
|
|
||||||
functions: Optional[List[OpenAIFunction]]
|
functions: Optional[List[OpenAIFunction]]
|
||||||
|
@ -11,7 +11,7 @@ from typing import (
|
|||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
|
|
||||||
from langchain.schema.runnable.base import Input, Output, RunnableSerializable
|
from langchain.schema.runnable.base import Input, Other, Output, RunnableSerializable
|
||||||
from langchain.schema.runnable.config import RunnableConfig
|
from langchain.schema.runnable.config import RunnableConfig
|
||||||
from langchain.schema.runnable.passthrough import RunnablePassthrough
|
from langchain.schema.runnable.passthrough import RunnablePassthrough
|
||||||
|
|
||||||
@ -36,7 +36,7 @@ class PutLocalVar(RunnablePassthrough):
|
|||||||
|
|
||||||
def _concat_put(
|
def _concat_put(
|
||||||
self,
|
self,
|
||||||
input: Input,
|
input: Other,
|
||||||
*,
|
*,
|
||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
replace: bool = False,
|
replace: bool = False,
|
||||||
@ -68,35 +68,35 @@ class PutLocalVar(RunnablePassthrough):
|
|||||||
f"{(type(self.key))}."
|
f"{(type(self.key))}."
|
||||||
)
|
)
|
||||||
|
|
||||||
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input:
|
def invoke(self, input: Other, config: Optional[RunnableConfig] = None) -> Other:
|
||||||
self._concat_put(input, config=config, replace=True)
|
self._concat_put(input, config=config, replace=True)
|
||||||
return super().invoke(input, config=config)
|
return super().invoke(input, config=config)
|
||||||
|
|
||||||
async def ainvoke(
|
async def ainvoke(
|
||||||
self,
|
self,
|
||||||
input: Input,
|
input: Other,
|
||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs: Optional[Any],
|
**kwargs: Optional[Any],
|
||||||
) -> Input:
|
) -> Other:
|
||||||
self._concat_put(input, config=config, replace=True)
|
self._concat_put(input, config=config, replace=True)
|
||||||
return await super().ainvoke(input, config=config)
|
return await super().ainvoke(input, config=config)
|
||||||
|
|
||||||
def transform(
|
def transform(
|
||||||
self,
|
self,
|
||||||
input: Iterator[Input],
|
input: Iterator[Other],
|
||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs: Optional[Any],
|
**kwargs: Optional[Any],
|
||||||
) -> Iterator[Input]:
|
) -> Iterator[Other]:
|
||||||
for chunk in super().transform(input, config=config):
|
for chunk in super().transform(input, config=config):
|
||||||
self._concat_put(chunk, config=config)
|
self._concat_put(chunk, config=config)
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
async def atransform(
|
async def atransform(
|
||||||
self,
|
self,
|
||||||
input: AsyncIterator[Input],
|
input: AsyncIterator[Other],
|
||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs: Optional[Any],
|
**kwargs: Optional[Any],
|
||||||
) -> AsyncIterator[Input]:
|
) -> AsyncIterator[Other]:
|
||||||
async for chunk in super().atransform(input, config=config):
|
async for chunk in super().atransform(input, config=config):
|
||||||
self._concat_put(chunk, config=config)
|
self._concat_put(chunk, config=config)
|
||||||
yield chunk
|
yield chunk
|
||||||
|
@ -2296,6 +2296,25 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
|||||||
class Config:
|
class Config:
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
bound: Runnable[Input, Output],
|
||||||
|
kwargs: Mapping[str, Any],
|
||||||
|
config: Optional[Mapping[str, Any]] = None,
|
||||||
|
**other_kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
config = config or {}
|
||||||
|
if configurable := config.get("configurable", None):
|
||||||
|
allowed_keys = set(s.id for s in bound.config_specs)
|
||||||
|
for key in configurable:
|
||||||
|
if key not in allowed_keys:
|
||||||
|
raise ValueError(
|
||||||
|
f"Configurable key '{key}' not found in runnable with"
|
||||||
|
f" config keys: {allowed_keys}"
|
||||||
|
)
|
||||||
|
super().__init__(bound=bound, kwargs=kwargs, config=config, **other_kwargs)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def InputType(self) -> Type[Input]:
|
def InputType(self) -> Type[Input]:
|
||||||
return self.bound.InputType
|
return self.bound.InputType
|
||||||
|
@ -22,7 +22,7 @@ from typing import (
|
|||||||
|
|
||||||
from langchain.pydantic_v1 import BaseModel, create_model
|
from langchain.pydantic_v1 import BaseModel, create_model
|
||||||
from langchain.schema.runnable.base import (
|
from langchain.schema.runnable.base import (
|
||||||
Input,
|
Other,
|
||||||
Runnable,
|
Runnable,
|
||||||
RunnableParallel,
|
RunnableParallel,
|
||||||
RunnableSerializable,
|
RunnableSerializable,
|
||||||
@ -33,17 +33,17 @@ from langchain.utils.aiter import atee, py_anext
|
|||||||
from langchain.utils.iter import safetee
|
from langchain.utils.iter import safetee
|
||||||
|
|
||||||
|
|
||||||
def identity(x: Input) -> Input:
|
def identity(x: Other) -> Other:
|
||||||
"""An identity function"""
|
"""An identity function"""
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
async def aidentity(x: Input) -> Input:
|
async def aidentity(x: Other) -> Other:
|
||||||
"""An async identity function"""
|
"""An async identity function"""
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class RunnablePassthrough(RunnableSerializable[Input, Input]):
|
class RunnablePassthrough(RunnableSerializable[Other, Other]):
|
||||||
"""A runnable to passthrough inputs unchanged or with additional keys.
|
"""A runnable to passthrough inputs unchanged or with additional keys.
|
||||||
|
|
||||||
This runnable behaves almost like the identity function, except that it
|
This runnable behaves almost like the identity function, except that it
|
||||||
@ -100,20 +100,20 @@ class RunnablePassthrough(RunnableSerializable[Input, Input]):
|
|||||||
# {'llm1': 'completion', 'llm2': 'completion', 'total_chars': 20}
|
# {'llm1': 'completion', 'llm2': 'completion', 'total_chars': 20}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
input_type: Optional[Type[Input]] = None
|
input_type: Optional[Type[Other]] = None
|
||||||
|
|
||||||
func: Optional[Callable[[Input], None]] = None
|
func: Optional[Callable[[Other], None]] = None
|
||||||
|
|
||||||
afunc: Optional[Callable[[Input], Awaitable[None]]] = None
|
afunc: Optional[Callable[[Other], Awaitable[None]]] = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
func: Optional[
|
func: Optional[
|
||||||
Union[Callable[[Input], None], Callable[[Input], Awaitable[None]]]
|
Union[Callable[[Other], None], Callable[[Other], Awaitable[None]]]
|
||||||
] = None,
|
] = None,
|
||||||
afunc: Optional[Callable[[Input], Awaitable[None]]] = None,
|
afunc: Optional[Callable[[Other], Awaitable[None]]] = None,
|
||||||
*,
|
*,
|
||||||
input_type: Optional[Type[Input]] = None,
|
input_type: Optional[Type[Other]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
if inspect.iscoroutinefunction(func):
|
if inspect.iscoroutinefunction(func):
|
||||||
@ -161,17 +161,17 @@ class RunnablePassthrough(RunnableSerializable[Input, Input]):
|
|||||||
"""
|
"""
|
||||||
return RunnableAssign(RunnableParallel(kwargs))
|
return RunnableAssign(RunnableParallel(kwargs))
|
||||||
|
|
||||||
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input:
|
def invoke(self, input: Other, config: Optional[RunnableConfig] = None) -> Other:
|
||||||
if self.func is not None:
|
if self.func is not None:
|
||||||
self.func(input)
|
self.func(input)
|
||||||
return self._call_with_config(identity, input, config)
|
return self._call_with_config(identity, input, config)
|
||||||
|
|
||||||
async def ainvoke(
|
async def ainvoke(
|
||||||
self,
|
self,
|
||||||
input: Input,
|
input: Other,
|
||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs: Optional[Any],
|
**kwargs: Optional[Any],
|
||||||
) -> Input:
|
) -> Other:
|
||||||
if self.afunc is not None:
|
if self.afunc is not None:
|
||||||
await self.afunc(input, **kwargs)
|
await self.afunc(input, **kwargs)
|
||||||
elif self.func is not None:
|
elif self.func is not None:
|
||||||
@ -180,10 +180,10 @@ class RunnablePassthrough(RunnableSerializable[Input, Input]):
|
|||||||
|
|
||||||
def transform(
|
def transform(
|
||||||
self,
|
self,
|
||||||
input: Iterator[Input],
|
input: Iterator[Other],
|
||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Iterator[Input]:
|
) -> Iterator[Other]:
|
||||||
if self.func is None:
|
if self.func is None:
|
||||||
for chunk in self._transform_stream_with_config(input, identity, config):
|
for chunk in self._transform_stream_with_config(input, identity, config):
|
||||||
yield chunk
|
yield chunk
|
||||||
@ -202,10 +202,10 @@ class RunnablePassthrough(RunnableSerializable[Input, Input]):
|
|||||||
|
|
||||||
async def atransform(
|
async def atransform(
|
||||||
self,
|
self,
|
||||||
input: AsyncIterator[Input],
|
input: AsyncIterator[Other],
|
||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> AsyncIterator[Input]:
|
) -> AsyncIterator[Other]:
|
||||||
if self.afunc is None and self.func is None:
|
if self.afunc is None and self.func is None:
|
||||||
async for chunk in self._atransform_stream_with_config(
|
async for chunk in self._atransform_stream_with_config(
|
||||||
input, identity, config
|
input, identity, config
|
||||||
@ -231,19 +231,19 @@ class RunnablePassthrough(RunnableSerializable[Input, Input]):
|
|||||||
|
|
||||||
def stream(
|
def stream(
|
||||||
self,
|
self,
|
||||||
input: Input,
|
input: Other,
|
||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Iterator[Input]:
|
) -> Iterator[Other]:
|
||||||
return self.transform(iter([input]), config, **kwargs)
|
return self.transform(iter([input]), config, **kwargs)
|
||||||
|
|
||||||
async def astream(
|
async def astream(
|
||||||
self,
|
self,
|
||||||
input: Input,
|
input: Other,
|
||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> AsyncIterator[Input]:
|
) -> AsyncIterator[Other]:
|
||||||
async def input_aiter() -> AsyncIterator[Input]:
|
async def input_aiter() -> AsyncIterator[Other]:
|
||||||
yield input
|
yield input
|
||||||
|
|
||||||
async for chunk in self.atransform(input_aiter(), config, **kwargs):
|
async for chunk in self.atransform(input_aiter(), config, **kwargs):
|
||||||
|
@ -24,9 +24,9 @@ from typing import (
|
|||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
|
|
||||||
Input = TypeVar("Input")
|
Input = TypeVar("Input", contravariant=True)
|
||||||
# Output type should implement __concat__, as eg str, list, dict do
|
# Output type should implement __concat__, as eg str, list, dict do
|
||||||
Output = TypeVar("Output")
|
Output = TypeVar("Output", covariant=True)
|
||||||
|
|
||||||
|
|
||||||
async def gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any:
|
async def gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any:
|
||||||
|
@ -1002,6 +1002,9 @@ def test_configurable_fields_example() -> None:
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
chain_configurable.with_config(configurable={"llm123": "chat"})
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
chain_configurable.with_config(configurable={"llm": "chat"}).invoke(
|
chain_configurable.with_config(configurable={"llm": "chat"}).invoke(
|
||||||
{"name": "John"}
|
{"name": "John"}
|
||||||
|
Loading…
Reference in New Issue
Block a user