This commit is contained in:
Nuno Campos 2023-08-18 10:25:41 +01:00
parent 24a197f96a
commit 46f3850794
3 changed files with 39 additions and 42 deletions

View File

@ -52,7 +52,6 @@ from langchain.schema.runnable.utils import (
from langchain.utils.aiter import atee, py_anext from langchain.utils.aiter import atee, py_anext
from langchain.utils.iter import safetee from langchain.utils.iter import safetee
Input = TypeVar("Input") Input = TypeVar("Input")
# Output type should implement __concat__, as eg str, list, dict do # Output type should implement __concat__, as eg str, list, dict do
Output = TypeVar("Output") Output = TypeVar("Output")

View File

@ -3,7 +3,7 @@ from __future__ import annotations
from typing import Any, Dict, List, Optional, TypedDict from typing import Any, Dict, List, Optional, TypedDict
from langchain.callbacks.base import Callbacks from langchain.callbacks.base import Callbacks
from langchain.callbacks.manager import CallbackManager, AsyncCallbackManager from langchain.callbacks.manager import AsyncCallbackManager, CallbackManager
class RunnableConfig(TypedDict, total=False): class RunnableConfig(TypedDict, total=False):

View File

@ -2,6 +2,10 @@ from __future__ import annotations
from typing import Any, AsyncIterator, Dict, Iterator, Mapping, Optional, Union from typing import Any, AsyncIterator, Dict, Iterator, Mapping, Optional, Union
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.load.serializable import Serializable from langchain.load.serializable import Serializable
from langchain.schema.runnable.base import Input, Output, Runnable from langchain.schema.runnable.base import Input, Output, Runnable
from langchain.schema.runnable.config import RunnableConfig from langchain.schema.runnable.config import RunnableConfig
@ -20,30 +24,12 @@ class PutLocalVar(RunnablePassthrough):
def __init__(self, key: Union[str, Mapping[str, str]], **kwargs: Any) -> None: def __init__(self, key: Union[str, Mapping[str, str]], **kwargs: Any) -> None:
super().__init__(key=key, **kwargs) super().__init__(key=key, **kwargs)
def _put(self, input: Input, *, config: Optional[RunnableConfig] = None) -> None:
if config is None:
raise ValueError(
"PutLocalVar should only be used in a RunnableSequence, and should "
"therefore always receive a non-null config."
)
if isinstance(self.key, str):
config["_locals"][self.key] = input
elif isinstance(self.key, Mapping):
if not isinstance(input, Mapping):
raise TypeError(
f"Received key of type Mapping but input of type {type(input)}. "
f"input is expected to be of type Mapping when key is Mapping."
)
for input_key, put_key in self.key.items():
config["_locals"][put_key] = input[input_key]
else:
raise TypeError(
f"`key` should be a string or Mapping[str, str], received type "
f"{(type(self.key))}."
)
def _concat_put( def _concat_put(
self, input: Input, *, config: Optional[RunnableConfig] = None self,
input: Input,
*,
config: Optional[RunnableConfig] = None,
replace: bool = False,
) -> None: ) -> None:
if config is None: if config is None:
raise ValueError( raise ValueError(
@ -51,7 +37,7 @@ class PutLocalVar(RunnablePassthrough):
"therefore always receive a non-null config." "therefore always receive a non-null config."
) )
if isinstance(self.key, str): if isinstance(self.key, str):
if self.key not in config["_locals"]: if self.key not in config["_locals"] or replace:
config["_locals"][self.key] = input config["_locals"][self.key] = input
else: else:
config["_locals"][self.key] += input config["_locals"][self.key] += input
@ -62,7 +48,7 @@ class PutLocalVar(RunnablePassthrough):
f"input is expected to be of type Mapping when key is Mapping." f"input is expected to be of type Mapping when key is Mapping."
) )
for input_key, put_key in self.key.items(): for input_key, put_key in self.key.items():
if put_key not in config["_locals"]: if put_key not in config["_locals"] or replace:
config["_locals"][put_key] = input[input_key] config["_locals"][put_key] = input[input_key]
else: else:
config["_locals"][put_key] += input[input_key] config["_locals"][put_key] += input[input_key]
@ -73,24 +59,30 @@ class PutLocalVar(RunnablePassthrough):
) )
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input: def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input:
self._put(input, config=config) self._concat_put(input, config=config, replace=True)
return super().invoke(input, config=config) return super().invoke(input, config=config)
async def ainvoke( async def ainvoke(
self, input: Input, config: RunnableConfig | None = None self, input: Input, config: Optional[RunnableConfig] = None
) -> Input: ) -> Input:
self._put(input, config=config) self._concat_put(input, config=config, replace=True)
return await super().ainvoke(input, config=config) return await super().ainvoke(input, config=config)
def transform( def transform(
self, input: Iterator[Input], config: RunnableConfig | None = None self,
input: Iterator[Input],
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Iterator[Input]: ) -> Iterator[Input]:
for chunk in super().transform(input, config=config): for chunk in super().transform(input, config=config):
self._concat_put(chunk, config=config) self._concat_put(chunk, config=config)
yield chunk yield chunk
async def atransform( async def atransform(
self, input: AsyncIterator[Input], config: RunnableConfig | None = None self,
input: AsyncIterator[Input],
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> AsyncIterator[Input]: ) -> AsyncIterator[Input]:
async for chunk in super().atransform(input, config=config): async for chunk in super().atransform(input, config=config):
self._concat_put(chunk, config=config) self._concat_put(chunk, config=config)
@ -113,19 +105,27 @@ class GetLocalVar(
def __init__(self, key: str, **kwargs: Any) -> None: def __init__(self, key: str, **kwargs: Any) -> None:
super().__init__(key=key, **kwargs) super().__init__(key=key, **kwargs)
def _get(self, full_input: Dict) -> Union[Output, Dict[str, Union[Input, Output]]]: def _get(
self,
input: Input,
run_manager: Union[CallbackManagerForChainRun, Any],
config: RunnableConfig,
) -> Union[Output, Dict[str, Union[Input, Output]]]:
if self.passthrough_key: if self.passthrough_key:
return { return {
self.key: full_input["locals"][self.key], self.key: config["_locals"][self.key],
self.passthrough_key: full_input["input"], self.passthrough_key: input,
} }
else: else:
return full_input["locals"][self.key] return config["_locals"][self.key]
async def _aget( async def _aget(
self, full_input: Dict self,
input: Input,
run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig,
) -> Union[Output, Dict[str, Union[Input, Output]]]: ) -> Union[Output, Dict[str, Union[Input, Output]]]:
return self._get(full_input) return self._get(input, run_manager, config)
def invoke( def invoke(
self, input: Input, config: Optional[RunnableConfig] = None self, input: Input, config: Optional[RunnableConfig] = None
@ -136,8 +136,7 @@ class GetLocalVar(
"therefore always receive a non-null config." "therefore always receive a non-null config."
) )
log_input = {"input": input, "locals": config["_locals"]} return self._call_with_config(self._get, input, config)
return self._call_with_config(self._get, log_input, config)
async def ainvoke( async def ainvoke(
self, input: Input, config: Optional[RunnableConfig] = None self, input: Input, config: Optional[RunnableConfig] = None
@ -148,5 +147,4 @@ class GetLocalVar(
"therefore always receive a non-null config." "therefore always receive a non-null config."
) )
log_input = {"input": input, "locals": config["_locals"]} return await self._acall_with_config(self._aget, input, config)
return await self._acall_with_config(self._aget, log_input, config)