This commit is contained in:
Nuno Campos 2023-08-18 10:25:41 +01:00
parent 24a197f96a
commit 46f3850794
3 changed files with 39 additions and 42 deletions

View File

@ -52,7 +52,6 @@ from langchain.schema.runnable.utils import (
from langchain.utils.aiter import atee, py_anext
from langchain.utils.iter import safetee
Input = TypeVar("Input")
# Output type should implement __concat__, as eg str, list, dict do
Output = TypeVar("Output")

View File

@ -3,7 +3,7 @@ from __future__ import annotations
from typing import Any, Dict, List, Optional, TypedDict
from langchain.callbacks.base import Callbacks
from langchain.callbacks.manager import CallbackManager, AsyncCallbackManager
from langchain.callbacks.manager import AsyncCallbackManager, CallbackManager
class RunnableConfig(TypedDict, total=False):

View File

@ -2,6 +2,10 @@ from __future__ import annotations
from typing import Any, AsyncIterator, Dict, Iterator, Mapping, Optional, Union
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.load.serializable import Serializable
from langchain.schema.runnable.base import Input, Output, Runnable
from langchain.schema.runnable.config import RunnableConfig
@ -20,30 +24,12 @@ class PutLocalVar(RunnablePassthrough):
def __init__(self, key: Union[str, Mapping[str, str]], **kwargs: Any) -> None:
super().__init__(key=key, **kwargs)
def _put(self, input: Input, *, config: Optional[RunnableConfig] = None) -> None:
if config is None:
raise ValueError(
"PutLocalVar should only be used in a RunnableSequence, and should "
"therefore always receive a non-null config."
)
if isinstance(self.key, str):
config["_locals"][self.key] = input
elif isinstance(self.key, Mapping):
if not isinstance(input, Mapping):
raise TypeError(
f"Received key of type Mapping but input of type {type(input)}. "
f"input is expected to be of type Mapping when key is Mapping."
)
for input_key, put_key in self.key.items():
config["_locals"][put_key] = input[input_key]
else:
raise TypeError(
f"`key` should be a string or Mapping[str, str], received type "
f"{(type(self.key))}."
)
def _concat_put(
self, input: Input, *, config: Optional[RunnableConfig] = None
self,
input: Input,
*,
config: Optional[RunnableConfig] = None,
replace: bool = False,
) -> None:
if config is None:
raise ValueError(
@ -51,7 +37,7 @@ class PutLocalVar(RunnablePassthrough):
"therefore always receive a non-null config."
)
if isinstance(self.key, str):
if self.key not in config["_locals"]:
if self.key not in config["_locals"] or replace:
config["_locals"][self.key] = input
else:
config["_locals"][self.key] += input
@ -62,7 +48,7 @@ class PutLocalVar(RunnablePassthrough):
f"input is expected to be of type Mapping when key is Mapping."
)
for input_key, put_key in self.key.items():
if put_key not in config["_locals"]:
if put_key not in config["_locals"] or replace:
config["_locals"][put_key] = input[input_key]
else:
config["_locals"][put_key] += input[input_key]
@ -73,24 +59,30 @@ class PutLocalVar(RunnablePassthrough):
)
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input:
self._put(input, config=config)
self._concat_put(input, config=config, replace=True)
return super().invoke(input, config=config)
async def ainvoke(
self, input: Input, config: RunnableConfig | None = None
self, input: Input, config: Optional[RunnableConfig] = None
) -> Input:
self._put(input, config=config)
self._concat_put(input, config=config, replace=True)
return await super().ainvoke(input, config=config)
def transform(
self, input: Iterator[Input], config: RunnableConfig | None = None
self,
input: Iterator[Input],
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Iterator[Input]:
for chunk in super().transform(input, config=config):
self._concat_put(chunk, config=config)
yield chunk
async def atransform(
self, input: AsyncIterator[Input], config: RunnableConfig | None = None
self,
input: AsyncIterator[Input],
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> AsyncIterator[Input]:
async for chunk in super().atransform(input, config=config):
self._concat_put(chunk, config=config)
@ -113,19 +105,27 @@ class GetLocalVar(
def __init__(self, key: str, **kwargs: Any) -> None:
super().__init__(key=key, **kwargs)
def _get(self, full_input: Dict) -> Union[Output, Dict[str, Union[Input, Output]]]:
def _get(
self,
input: Input,
run_manager: Union[CallbackManagerForChainRun, Any],
config: RunnableConfig,
) -> Union[Output, Dict[str, Union[Input, Output]]]:
if self.passthrough_key:
return {
self.key: full_input["locals"][self.key],
self.passthrough_key: full_input["input"],
self.key: config["_locals"][self.key],
self.passthrough_key: input,
}
else:
return full_input["locals"][self.key]
return config["_locals"][self.key]
async def _aget(
self, full_input: Dict
self,
input: Input,
run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig,
) -> Union[Output, Dict[str, Union[Input, Output]]]:
return self._get(full_input)
return self._get(input, run_manager, config)
def invoke(
self, input: Input, config: Optional[RunnableConfig] = None
@ -136,8 +136,7 @@ class GetLocalVar(
"therefore always receive a non-null config."
)
log_input = {"input": input, "locals": config["_locals"]}
return self._call_with_config(self._get, log_input, config)
return self._call_with_config(self._get, input, config)
async def ainvoke(
self, input: Input, config: Optional[RunnableConfig] = None
@ -148,5 +147,4 @@ class GetLocalVar(
"therefore always receive a non-null config."
)
log_input = {"input": input, "locals": config["_locals"]}
return await self._acall_with_config(self._aget, log_input, config)
return await self._acall_with_config(self._aget, input, config)