diff --git a/libs/langchain/langchain/runnables/openai_functions.py b/libs/langchain/langchain/runnables/openai_functions.py index 55c9765d20c..1ee9f44b971 100644 --- a/libs/langchain/langchain/runnables/openai_functions.py +++ b/libs/langchain/langchain/runnables/openai_functions.py @@ -4,7 +4,7 @@ from typing import Any, Callable, List, Mapping, Optional, Union from typing_extensions import TypedDict from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser -from langchain.schema.output import ChatGeneration +from langchain.schema.messages import BaseMessage from langchain.schema.runnable import RouterRunnable, Runnable, RunnableBinding @@ -19,7 +19,7 @@ class OpenAIFunction(TypedDict): """The parameters to the function.""" -class OpenAIFunctionsRouter(RunnableBinding[ChatGeneration, Any]): +class OpenAIFunctionsRouter(RunnableBinding[BaseMessage, Any]): """A runnable that routes to the selected function.""" functions: Optional[List[OpenAIFunction]] diff --git a/libs/langchain/langchain/schema/runnable/_locals.py b/libs/langchain/langchain/schema/runnable/_locals.py index e2fe8541148..a794ea6d5c0 100644 --- a/libs/langchain/langchain/schema/runnable/_locals.py +++ b/libs/langchain/langchain/schema/runnable/_locals.py @@ -11,7 +11,7 @@ from typing import ( Union, ) -from langchain.schema.runnable.base import Input, Output, RunnableSerializable +from langchain.schema.runnable.base import Input, Other, Output, RunnableSerializable from langchain.schema.runnable.config import RunnableConfig from langchain.schema.runnable.passthrough import RunnablePassthrough @@ -36,7 +36,7 @@ class PutLocalVar(RunnablePassthrough): def _concat_put( self, - input: Input, + input: Other, *, config: Optional[RunnableConfig] = None, replace: bool = False, @@ -68,35 +68,35 @@ class PutLocalVar(RunnablePassthrough): f"{(type(self.key))}." ) - def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input: + def invoke(self, input: Other, config: Optional[RunnableConfig] = None) -> Other: self._concat_put(input, config=config, replace=True) return super().invoke(input, config=config) async def ainvoke( self, - input: Input, + input: Other, config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], - ) -> Input: + ) -> Other: self._concat_put(input, config=config, replace=True) return await super().ainvoke(input, config=config) def transform( self, - input: Iterator[Input], + input: Iterator[Other], config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], - ) -> Iterator[Input]: + ) -> Iterator[Other]: for chunk in super().transform(input, config=config): self._concat_put(chunk, config=config) yield chunk async def atransform( self, - input: AsyncIterator[Input], + input: AsyncIterator[Other], config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], - ) -> AsyncIterator[Input]: + ) -> AsyncIterator[Other]: async for chunk in super().atransform(input, config=config): self._concat_put(chunk, config=config) yield chunk diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 7e1f51554ad..c053d2f6257 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -2296,6 +2296,25 @@ class RunnableBinding(RunnableSerializable[Input, Output]): class Config: arbitrary_types_allowed = True + def __init__( + self, + *, + bound: Runnable[Input, Output], + kwargs: Mapping[str, Any], + config: Optional[Mapping[str, Any]] = None, + **other_kwargs: Any, + ) -> None: + config = config or {} + if configurable := config.get("configurable", None): + allowed_keys = set(s.id for s in bound.config_specs) + for key in configurable: + if key not in allowed_keys: + raise ValueError( + f"Configurable key '{key}' not found in runnable with" + f" config keys: {allowed_keys}" + ) + super().__init__(bound=bound, kwargs=kwargs, config=config, **other_kwargs) + @property def InputType(self) -> Type[Input]: return self.bound.InputType diff --git a/libs/langchain/langchain/schema/runnable/passthrough.py b/libs/langchain/langchain/schema/runnable/passthrough.py index 5dba5b9c8f4..6fb59c8e1c4 100644 --- a/libs/langchain/langchain/schema/runnable/passthrough.py +++ b/libs/langchain/langchain/schema/runnable/passthrough.py @@ -22,7 +22,7 @@ from typing import ( from langchain.pydantic_v1 import BaseModel, create_model from langchain.schema.runnable.base import ( - Input, + Other, Runnable, RunnableParallel, RunnableSerializable, @@ -33,17 +33,17 @@ from langchain.utils.aiter import atee, py_anext from langchain.utils.iter import safetee -def identity(x: Input) -> Input: +def identity(x: Other) -> Other: """An identity function""" return x -async def aidentity(x: Input) -> Input: +async def aidentity(x: Other) -> Other: """An async identity function""" return x -class RunnablePassthrough(RunnableSerializable[Input, Input]): +class RunnablePassthrough(RunnableSerializable[Other, Other]): """A runnable to passthrough inputs unchanged or with additional keys. This runnable behaves almost like the identity function, except that it @@ -100,20 +100,20 @@ class RunnablePassthrough(RunnableSerializable[Input, Input]): # {'llm1': 'completion', 'llm2': 'completion', 'total_chars': 20} """ - input_type: Optional[Type[Input]] = None + input_type: Optional[Type[Other]] = None - func: Optional[Callable[[Input], None]] = None + func: Optional[Callable[[Other], None]] = None - afunc: Optional[Callable[[Input], Awaitable[None]]] = None + afunc: Optional[Callable[[Other], Awaitable[None]]] = None def __init__( self, func: Optional[ - Union[Callable[[Input], None], Callable[[Input], Awaitable[None]]] + Union[Callable[[Other], None], Callable[[Other], Awaitable[None]]] ] = None, - afunc: Optional[Callable[[Input], Awaitable[None]]] = None, + afunc: Optional[Callable[[Other], Awaitable[None]]] = None, *, - input_type: Optional[Type[Input]] = None, + input_type: Optional[Type[Other]] = None, **kwargs: Any, ) -> None: if inspect.iscoroutinefunction(func): @@ -161,17 +161,17 @@ class RunnablePassthrough(RunnableSerializable[Input, Input]): """ return RunnableAssign(RunnableParallel(kwargs)) - def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input: + def invoke(self, input: Other, config: Optional[RunnableConfig] = None) -> Other: if self.func is not None: self.func(input) return self._call_with_config(identity, input, config) async def ainvoke( self, - input: Input, + input: Other, config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], - ) -> Input: + ) -> Other: if self.afunc is not None: await self.afunc(input, **kwargs) elif self.func is not None: @@ -180,10 +180,10 @@ class RunnablePassthrough(RunnableSerializable[Input, Input]): def transform( self, - input: Iterator[Input], + input: Iterator[Other], config: Optional[RunnableConfig] = None, **kwargs: Any, - ) -> Iterator[Input]: + ) -> Iterator[Other]: if self.func is None: for chunk in self._transform_stream_with_config(input, identity, config): yield chunk @@ -202,10 +202,10 @@ class RunnablePassthrough(RunnableSerializable[Input, Input]): async def atransform( self, - input: AsyncIterator[Input], + input: AsyncIterator[Other], config: Optional[RunnableConfig] = None, **kwargs: Any, - ) -> AsyncIterator[Input]: + ) -> AsyncIterator[Other]: if self.afunc is None and self.func is None: async for chunk in self._atransform_stream_with_config( input, identity, config @@ -231,19 +231,19 @@ class RunnablePassthrough(RunnableSerializable[Input, Input]): def stream( self, - input: Input, + input: Other, config: Optional[RunnableConfig] = None, **kwargs: Any, - ) -> Iterator[Input]: + ) -> Iterator[Other]: return self.transform(iter([input]), config, **kwargs) async def astream( self, - input: Input, + input: Other, config: Optional[RunnableConfig] = None, **kwargs: Any, - ) -> AsyncIterator[Input]: - async def input_aiter() -> AsyncIterator[Input]: + ) -> AsyncIterator[Other]: + async def input_aiter() -> AsyncIterator[Other]: yield input async for chunk in self.atransform(input_aiter(), config, **kwargs): diff --git a/libs/langchain/langchain/schema/runnable/utils.py b/libs/langchain/langchain/schema/runnable/utils.py index cb1be68da4f..05ea6832dfa 100644 --- a/libs/langchain/langchain/schema/runnable/utils.py +++ b/libs/langchain/langchain/schema/runnable/utils.py @@ -24,9 +24,9 @@ from typing import ( Union, ) -Input = TypeVar("Input") +Input = TypeVar("Input", contravariant=True) # Output type should implement __concat__, as eg str, list, dict do -Output = TypeVar("Output") +Output = TypeVar("Output", covariant=True) async def gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any: diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index 3f2189342da..a31e234ec69 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -1002,6 +1002,9 @@ def test_configurable_fields_example() -> None: }, } + with pytest.raises(ValueError): + chain_configurable.with_config(configurable={"llm123": "chat"}) + assert ( chain_configurable.with_config(configurable={"llm": "chat"}).invoke( {"name": "John"}