Compare commits

...

1 Commit

Author SHA1 Message Date
Bagatur
fbeef3fc75 core[minor]: Add Runnable.drop() 2024-01-19 16:57:54 -08:00
2 changed files with 151 additions and 3 deletions

View File

@@ -450,6 +450,13 @@ class Runnable(Generic[Input, Output], ABC):
return self | RunnablePick(keys)
def drop(self, keys: Union[str, List[str]]) -> RunnableSerializable[Any, Any]:
    """Compose this runnable with a step that removes *keys* from its dict output.

    Accepts a single key or a list of keys. The original runnable is left
    untouched; a new composed runnable is returned.
    """
    # Imported locally to avoid a circular import at module load time.
    from langchain_core.runnables.passthrough import RunnableDrop

    return self | RunnableDrop(keys)
def assign(
self,
**kwargs: Union[
@@ -1667,7 +1674,11 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]):
def _seq_input_schema(
steps: List[Runnable[Any, Any]], config: Optional[RunnableConfig]
) -> Type[BaseModel]:
from langchain_core.runnables.passthrough import RunnableAssign, RunnablePick
from langchain_core.runnables.passthrough import (
RunnableAssign,
RunnableDrop,
RunnablePick,
)
first = steps[0]
if len(steps) == 1:
@@ -1685,7 +1696,7 @@ def _seq_input_schema(
},
__config__=_SchemaConfig,
)
elif isinstance(first, RunnablePick):
elif isinstance(first, (RunnablePick, RunnableDrop)):
return _seq_input_schema(steps[1:], config)
return first.get_input_schema(config)
@@ -1694,7 +1705,11 @@ def _seq_input_schema(
def _seq_output_schema(
steps: List[Runnable[Any, Any]], config: Optional[RunnableConfig]
) -> Type[BaseModel]:
from langchain_core.runnables.passthrough import RunnableAssign, RunnablePick
from langchain_core.runnables.passthrough import (
RunnableAssign,
RunnableDrop,
RunnablePick,
)
last = steps[-1]
if len(steps) == 1:
@@ -1739,6 +1754,19 @@ def _seq_output_schema(
__root__=(field.annotation, field.default),
__config__=_SchemaConfig,
)
elif isinstance(last, RunnableDrop):
prev_output_schema = _seq_output_schema(steps[:-1], config)
if not prev_output_schema.__custom_root_type__:
# it's a dict as expected
return create_model( # type: ignore[call-overload]
"RunnableSequenceOutput",
**{
k: (v.annotation, v.default)
for k, v in prev_output_schema.__fields__.items()
if k not in last.keys
},
__config__=_SchemaConfig,
)
return last.get_output_schema(config)

View File

@@ -15,6 +15,7 @@ from typing import (
List,
Mapping,
Optional,
Sequence,
Type,
Union,
cast,
@@ -699,3 +700,122 @@ class RunnablePick(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
async for chunk in self.atransform(input_aiter(), config, **kwargs):
yield chunk
class RunnableDrop(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
    """Runnable that removes keys from ``Dict[str, Any]`` inputs.

    The complement of ``RunnablePick``: every key *except* ``self.keys`` is
    forwarded. When dropping the keys leaves an empty dict, ``None`` is
    returned by ``invoke``/``ainvoke``, and the streaming methods skip the
    empty chunk instead of yielding it.
    """

    # Keys to remove from each incoming dict.
    keys: List[str]

    def __init__(self, keys: Union[str, Sequence[str]], **kwargs: Any) -> None:
        """Normalize a single key or a sequence of keys to ``List[str]``."""
        keys = [keys] if isinstance(keys, str) else list(keys)
        super().__init__(keys=keys, **kwargs)

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """This object can be serialized by the langchain serializer."""
        return True

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "schema", "runnable"]

    def get_name(
        self, suffix: Optional[str] = None, *, name: Optional[str] = None
    ) -> str:
        """Default the name to ``RunnableDrop<k1,k2,...>`` when unset."""
        # Fix: removed stray space before the closing paren of join().
        name = name or self.name or f"RunnableDrop<{','.join(self.keys)}>"
        return super().get_name(suffix, name=name)

    def _drop(self, input: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Return *input* without ``self.keys``, or ``None`` if nothing remains."""
        # NOTE(review): ``assert`` is stripped under ``python -O``; kept here
        # for parity with the existing input checks in this module.
        assert isinstance(input, dict), "The input to RunnableDrop must be a dict."
        remaining = {k: v for k, v in input.items() if k not in self.keys}
        return AddableDict(remaining) if remaining else None

    def _invoke(
        self,
        input: Dict[str, Any],
    ) -> Dict[str, Any]:
        return self._drop(input)

    def invoke(
        self,
        input: Dict[str, Any],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Drop ``self.keys`` from *input* synchronously."""
        return self._call_with_config(self._invoke, input, config, **kwargs)

    async def _ainvoke(
        self,
        input: Dict[str, Any],
    ) -> Dict[str, Any]:
        return self._drop(input)

    async def ainvoke(
        self,
        input: Dict[str, Any],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Drop ``self.keys`` from *input* asynchronously."""
        return await self._acall_with_config(self._ainvoke, input, config, **kwargs)

    def _transform(
        self,
        input: Iterator[Dict[str, Any]],
    ) -> Iterator[Dict[str, Any]]:
        # Chunks that are empty after dropping are skipped, not yielded as None.
        for chunk in input:
            remaining = self._drop(chunk)
            if remaining is not None:
                yield remaining

    def transform(
        self,
        input: Iterator[Dict[str, Any]],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> Iterator[Dict[str, Any]]:
        """Drop ``self.keys`` from each chunk of a synchronous stream."""
        yield from self._transform_stream_with_config(
            input, self._transform, config, **kwargs
        )

    async def _atransform(
        self,
        input: AsyncIterator[Dict[str, Any]],
    ) -> AsyncIterator[Dict[str, Any]]:
        # Async twin of _transform: empty post-drop chunks are skipped.
        async for chunk in input:
            remaining = self._drop(chunk)
            if remaining is not None:
                yield remaining

    async def atransform(
        self,
        input: AsyncIterator[Dict[str, Any]],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> AsyncIterator[Dict[str, Any]]:
        """Drop ``self.keys`` from each chunk of an async stream."""
        async for chunk in self._atransform_stream_with_config(
            input, self._atransform, config, **kwargs
        ):
            yield chunk

    def stream(
        self,
        input: Dict[str, Any],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> Iterator[Dict[str, Any]]:
        """Stream the dropped dict as a single-chunk stream."""
        return self.transform(iter([input]), config, **kwargs)

    async def astream(
        self,
        input: Dict[str, Any],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> AsyncIterator[Dict[str, Any]]:
        """Async-stream the dropped dict as a single-chunk stream."""

        async def input_aiter() -> AsyncIterator[Dict[str, Any]]:
            yield input

        async for chunk in self.atransform(input_aiter(), config, **kwargs):
            yield chunk