mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-21 14:43:07 +00:00
RFC: unpack .assign() arg
This commit is contained in:
@@ -452,6 +452,10 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
|
||||
def assign(
|
||||
self,
|
||||
*args: Union[
|
||||
Runnable[Dict[str, Any], Dict[str, Any]],
|
||||
Callable[[Dict[str, Any]], Dict[str, Any]],
|
||||
],
|
||||
**kwargs: Union[
|
||||
Runnable[Dict[str, Any], Any],
|
||||
Callable[[Dict[str, Any]], Any],
|
||||
@@ -465,7 +469,14 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
Returns a new runnable."""
|
||||
from langchain_core.runnables.passthrough import RunnableAssign
|
||||
|
||||
return self | RunnableAssign(RunnableParallel(kwargs))
|
||||
if args and kwargs:
|
||||
raise ValueError
|
||||
elif args and len(args) > 1:
|
||||
raise ValueError
|
||||
elif args:
|
||||
return self | RunnableAssign(coerce_to_runnable(args[0]))
|
||||
else:
|
||||
return self | RunnableAssign(RunnableParallel(kwargs))
|
||||
|
||||
""" --- Public API --- """
|
||||
|
||||
@@ -1689,7 +1700,12 @@ def _seq_input_schema(
|
||||
**{
|
||||
k: (v.annotation, v.default)
|
||||
for k, v in next_input_schema.__fields__.items()
|
||||
if k not in first.mapper.steps
|
||||
if k
|
||||
not in (
|
||||
first.mapper.steps
|
||||
if isinstance(first.mapper, RunnableParallel)
|
||||
else {}
|
||||
)
|
||||
},
|
||||
__config__=_SchemaConfig,
|
||||
)
|
||||
|
||||
@@ -26,6 +26,7 @@ from langchain_core.runnables.base import (
|
||||
Runnable,
|
||||
RunnableParallel,
|
||||
RunnableSerializable,
|
||||
coerce_to_runnable,
|
||||
)
|
||||
from langchain_core.runnables.config import (
|
||||
RunnableConfig,
|
||||
@@ -183,6 +184,10 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
|
||||
@classmethod
|
||||
def assign(
|
||||
cls,
|
||||
*args: Union[
|
||||
Runnable[Dict[str, Any], Dict[str, Any]],
|
||||
Callable[[Dict[str, Any]], Dict[str, Any]],
|
||||
],
|
||||
**kwargs: Union[
|
||||
Runnable[Dict[str, Any], Any],
|
||||
Callable[[Dict[str, Any]], Any],
|
||||
@@ -201,7 +206,14 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
|
||||
A runnable that merges the Dict input with the output produced by the
|
||||
mapping argument.
|
||||
"""
|
||||
return RunnableAssign(RunnableParallel(kwargs))
|
||||
if args and kwargs:
|
||||
raise ValueError
|
||||
elif args and len(args) > 1:
|
||||
raise ValueError
|
||||
elif args:
|
||||
return RunnableAssign(coerce_to_runnable(args[0]))
|
||||
else:
|
||||
return RunnableAssign(RunnableParallel(kwargs))
|
||||
|
||||
def invoke(
|
||||
self, input: Other, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
@@ -313,9 +325,11 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||
A runnable that assigns key-value pairs to Dict[str, Any] inputs.
|
||||
"""
|
||||
|
||||
mapper: RunnableParallel[Dict[str, Any]]
|
||||
mapper: Runnable[Dict[str, Any], Dict[str, Any]]
|
||||
|
||||
def __init__(self, mapper: RunnableParallel[Dict[str, Any]], **kwargs: Any) -> None:
|
||||
def __init__(
|
||||
self, mapper: Runnable[Dict[str, Any], Dict[str, Any]], **kwargs: Any
|
||||
) -> None:
|
||||
super().__init__(mapper=mapper, **kwargs)
|
||||
|
||||
@classmethod
|
||||
@@ -330,9 +344,11 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||
def get_name(
|
||||
self, suffix: Optional[str] = None, *, name: Optional[str] = None
|
||||
) -> str:
|
||||
name = (
|
||||
name or self.name or f"RunnableAssign<{','.join(self.mapper.steps.keys())}>"
|
||||
)
|
||||
if isinstance(self.mapper, RunnableParallel):
|
||||
default_name = f"RunnableAssign<{','.join(self.mapper.steps.keys())}>"
|
||||
else:
|
||||
default_name = "RunnableAssign"
|
||||
name = name or self.name or default_name
|
||||
return super().get_name(suffix, name=name)
|
||||
|
||||
def get_input_schema(
|
||||
@@ -450,7 +466,11 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||
**kwargs: Any,
|
||||
) -> Iterator[Dict[str, Any]]:
|
||||
# collect mapper keys
|
||||
mapper_keys = set(self.mapper.steps.keys())
|
||||
mapper_keys = (
|
||||
set(self.mapper.steps.keys())
|
||||
if isinstance(self.mapper, RunnableParallel)
|
||||
else {}
|
||||
)
|
||||
# create two streams, one for the map and one for the passthrough
|
||||
for_passthrough, for_map = safetee(input, 2, lock=threading.Lock())
|
||||
|
||||
@@ -506,7 +526,11 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[Dict[str, Any]]:
|
||||
# collect mapper keys
|
||||
mapper_keys = set(self.mapper.steps.keys())
|
||||
mapper_keys = (
|
||||
set(self.mapper.steps.keys())
|
||||
if isinstance(self.mapper, RunnableParallel)
|
||||
else {}
|
||||
)
|
||||
# create two streams, one for the map and one for the passthrough
|
||||
for_passthrough, for_map = atee(input, 2, lock=asyncio.Lock())
|
||||
# create map output stream
|
||||
|
||||
Reference in New Issue
Block a user