Compare commits

...

2 Commits

Author SHA1 Message Date
Bagatur
1610d7c3b0 fmt 2024-01-24 12:41:03 -08:00
Bagatur
a9450d88e1 RFC: unpack .assign() arg 2024-01-24 12:36:51 -08:00
2 changed files with 52 additions and 10 deletions

View File

@@ -452,6 +452,10 @@ class Runnable(Generic[Input, Output], ABC):
def assign(
self,
*args: Union[
Runnable[Dict[str, Any], Dict[str, Any]],
Callable[[Dict[str, Any]], Dict[str, Any]],
],
**kwargs: Union[
Runnable[Dict[str, Any], Any],
Callable[[Dict[str, Any]], Any],
@@ -465,7 +469,14 @@ class Runnable(Generic[Input, Output], ABC):
Returns a new runnable."""
from langchain_core.runnables.passthrough import RunnableAssign
return self | RunnableAssign(RunnableParallel(kwargs))
if args and kwargs:
raise ValueError
elif args and len(args) > 1:
raise ValueError
elif args:
return self | RunnableAssign(coerce_to_runnable(args[0]))
else:
return self | RunnableAssign(RunnableParallel(kwargs))
""" --- Public API --- """
@@ -1689,7 +1700,12 @@ def _seq_input_schema(
**{
k: (v.annotation, v.default)
for k, v in next_input_schema.__fields__.items()
if k not in first.mapper.steps
if k
not in (
first.mapper.steps
if isinstance(first.mapper, RunnableParallel)
else {}
)
},
__config__=_SchemaConfig,
)

View File

@@ -26,6 +26,7 @@ from langchain_core.runnables.base import (
Runnable,
RunnableParallel,
RunnableSerializable,
coerce_to_runnable,
)
from langchain_core.runnables.config import (
RunnableConfig,
@@ -183,6 +184,10 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
@classmethod
def assign(
cls,
*args: Union[
Runnable[Dict[str, Any], Dict[str, Any]],
Callable[[Dict[str, Any]], Dict[str, Any]],
],
**kwargs: Union[
Runnable[Dict[str, Any], Any],
Callable[[Dict[str, Any]], Any],
@@ -201,7 +206,14 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
A runnable that merges the Dict input with the output produced by the
mapping argument.
"""
return RunnableAssign(RunnableParallel(kwargs))
if args and kwargs:
raise ValueError
elif args and len(args) > 1:
raise ValueError
elif args:
return RunnableAssign(coerce_to_runnable(args[0]))
else:
return RunnableAssign(RunnableParallel(kwargs))
def invoke(
self, input: Other, config: Optional[RunnableConfig] = None, **kwargs: Any
@@ -313,9 +325,13 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
A runnable that assigns key-value pairs to Dict[str, Any] inputs.
"""
mapper: RunnableParallel[Dict[str, Any]]
mapper: RunnableSerializable[Dict[str, Any], Dict[str, Any]]
def __init__(self, mapper: RunnableParallel[Dict[str, Any]], **kwargs: Any) -> None:
def __init__(
self,
mapper: RunnableSerializable[Dict[str, Any], Dict[str, Any]],
**kwargs: Any,
) -> None:
super().__init__(mapper=mapper, **kwargs)
@classmethod
@@ -330,9 +346,11 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
def get_name(
self, suffix: Optional[str] = None, *, name: Optional[str] = None
) -> str:
name = (
name or self.name or f"RunnableAssign<{','.join(self.mapper.steps.keys())}>"
)
if isinstance(self.mapper, RunnableParallel):
default_name = f"RunnableAssign<{','.join(self.mapper.steps.keys())}>"
else:
default_name = "RunnableAssign"
name = name or self.name or default_name
return super().get_name(suffix, name=name)
def get_input_schema(
@@ -450,7 +468,11 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
**kwargs: Any,
) -> Iterator[Dict[str, Any]]:
# collect mapper keys
mapper_keys = set(self.mapper.steps.keys())
mapper_keys = (
set(self.mapper.steps.keys())
if isinstance(self.mapper, RunnableParallel)
else {}
)
# create two streams, one for the map and one for the passthrough
for_passthrough, for_map = safetee(input, 2, lock=threading.Lock())
@@ -506,7 +528,11 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
**kwargs: Any,
) -> AsyncIterator[Dict[str, Any]]:
# collect mapper keys
mapper_keys = set(self.mapper.steps.keys())
mapper_keys = (
set(self.mapper.steps.keys())
if isinstance(self.mapper, RunnableParallel)
else {}
)
# create two streams, one for the map and one for the passthrough
for_passthrough, for_map = atee(input, 2, lock=asyncio.Lock())
# create map output stream