diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index 280ecd25eb3..4846648ffaa 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -452,6 +452,10 @@ class Runnable(Generic[Input, Output], ABC): def assign( self, + *args: Union[ + Runnable[Dict[str, Any], Dict[str, Any]], + Callable[[Dict[str, Any]], Dict[str, Any]], + ], **kwargs: Union[ Runnable[Dict[str, Any], Any], Callable[[Dict[str, Any]], Any], @@ -465,7 +469,14 @@ class Runnable(Generic[Input, Output], ABC): Returns a new runnable.""" from langchain_core.runnables.passthrough import RunnableAssign - return self | RunnableAssign(RunnableParallel(kwargs)) + if args and kwargs: + raise ValueError + elif args and len(args) > 1: + raise ValueError + elif args: + return self | RunnableAssign(coerce_to_runnable(args[0])) + else: + return self | RunnableAssign(RunnableParallel(kwargs)) """ --- Public API --- """ @@ -1689,7 +1700,12 @@ def _seq_input_schema( **{ k: (v.annotation, v.default) for k, v in next_input_schema.__fields__.items() - if k not in first.mapper.steps + if k + not in ( + first.mapper.steps + if isinstance(first.mapper, RunnableParallel) + else {} + ) }, __config__=_SchemaConfig, ) diff --git a/libs/core/langchain_core/runnables/passthrough.py b/libs/core/langchain_core/runnables/passthrough.py index 1f4367d9d5e..6590234731c 100644 --- a/libs/core/langchain_core/runnables/passthrough.py +++ b/libs/core/langchain_core/runnables/passthrough.py @@ -26,6 +26,7 @@ from langchain_core.runnables.base import ( Runnable, RunnableParallel, RunnableSerializable, + coerce_to_runnable, ) from langchain_core.runnables.config import ( RunnableConfig, @@ -183,6 +184,10 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]): @classmethod def assign( cls, + *args: Union[ + Runnable[Dict[str, Any], Dict[str, Any]], + Callable[[Dict[str, Any]], Dict[str, Any]], + ], **kwargs: Union[ Runnable[Dict[str, Any], Any], Callable[[Dict[str, Any]], Any], @@ -201,7 +206,14 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]): A runnable that merges the Dict input with the output produced by the mapping argument. """ - return RunnableAssign(RunnableParallel(kwargs)) + if args and kwargs: + raise ValueError + elif args and len(args) > 1: + raise ValueError + elif args: + return RunnableAssign(coerce_to_runnable(args[0])) + else: + return RunnableAssign(RunnableParallel(kwargs)) def invoke( self, input: Other, config: Optional[RunnableConfig] = None, **kwargs: Any @@ -313,9 +325,11 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): A runnable that assigns key-value pairs to Dict[str, Any] inputs. """ - mapper: RunnableParallel[Dict[str, Any]] + mapper: Runnable[Dict[str, Any], Dict[str, Any]] - def __init__(self, mapper: RunnableParallel[Dict[str, Any]], **kwargs: Any) -> None: + def __init__( + self, mapper: Runnable[Dict[str, Any], Dict[str, Any]], **kwargs: Any + ) -> None: super().__init__(mapper=mapper, **kwargs) @classmethod @@ -330,9 +344,11 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): def get_name( self, suffix: Optional[str] = None, *, name: Optional[str] = None ) -> str: - name = ( - name or self.name or f"RunnableAssign<{','.join(self.mapper.steps.keys())}>" - ) + if isinstance(self.mapper, RunnableParallel): + default_name = f"RunnableAssign<{','.join(self.mapper.steps.keys())}>" + else: + default_name = "RunnableAssign" + name = name or self.name or default_name return super().get_name(suffix, name=name) def get_input_schema( @@ -450,7 +466,11 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): **kwargs: Any, ) -> Iterator[Dict[str, Any]]: # collect mapper keys - mapper_keys = set(self.mapper.steps.keys()) + mapper_keys = ( + set(self.mapper.steps.keys()) + if isinstance(self.mapper, RunnableParallel) + else {} + ) # create two streams, one for the map and one for the passthrough for_passthrough, for_map = safetee(input, 2, lock=threading.Lock()) @@ -506,7 +526,11 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): **kwargs: Any, ) -> AsyncIterator[Dict[str, Any]]: # collect mapper keys - mapper_keys = set(self.mapper.steps.keys()) + mapper_keys = ( + set(self.mapper.steps.keys()) + if isinstance(self.mapper, RunnableParallel) + else {} + ) # create two streams, one for the map and one for the passthrough for_passthrough, for_map = atee(input, 2, lock=asyncio.Lock()) # create map output stream