mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-10 21:35:08 +00:00
Lint
This commit is contained in:
parent
24a197f96a
commit
46f3850794
@ -52,7 +52,6 @@ from langchain.schema.runnable.utils import (
|
||||
from langchain.utils.aiter import atee, py_anext
|
||||
from langchain.utils.iter import safetee
|
||||
|
||||
|
||||
Input = TypeVar("Input")
|
||||
# Output type should implement __concat__, as eg str, list, dict do
|
||||
Output = TypeVar("Output")
|
||||
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
from typing import Any, Dict, List, Optional, TypedDict
|
||||
|
||||
from langchain.callbacks.base import Callbacks
|
||||
from langchain.callbacks.manager import CallbackManager, AsyncCallbackManager
|
||||
from langchain.callbacks.manager import AsyncCallbackManager, CallbackManager
|
||||
|
||||
|
||||
class RunnableConfig(TypedDict, total=False):
|
||||
|
@ -2,6 +2,10 @@ from __future__ import annotations
|
||||
|
||||
from typing import Any, AsyncIterator, Dict, Iterator, Mapping, Optional, Union
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain.load.serializable import Serializable
|
||||
from langchain.schema.runnable.base import Input, Output, Runnable
|
||||
from langchain.schema.runnable.config import RunnableConfig
|
||||
@ -20,30 +24,12 @@ class PutLocalVar(RunnablePassthrough):
|
||||
def __init__(self, key: Union[str, Mapping[str, str]], **kwargs: Any) -> None:
|
||||
super().__init__(key=key, **kwargs)
|
||||
|
||||
def _put(self, input: Input, *, config: Optional[RunnableConfig] = None) -> None:
|
||||
if config is None:
|
||||
raise ValueError(
|
||||
"PutLocalVar should only be used in a RunnableSequence, and should "
|
||||
"therefore always receive a non-null config."
|
||||
)
|
||||
if isinstance(self.key, str):
|
||||
config["_locals"][self.key] = input
|
||||
elif isinstance(self.key, Mapping):
|
||||
if not isinstance(input, Mapping):
|
||||
raise TypeError(
|
||||
f"Received key of type Mapping but input of type {type(input)}. "
|
||||
f"input is expected to be of type Mapping when key is Mapping."
|
||||
)
|
||||
for input_key, put_key in self.key.items():
|
||||
config["_locals"][put_key] = input[input_key]
|
||||
else:
|
||||
raise TypeError(
|
||||
f"`key` should be a string or Mapping[str, str], received type "
|
||||
f"{(type(self.key))}."
|
||||
)
|
||||
|
||||
def _concat_put(
|
||||
self, input: Input, *, config: Optional[RunnableConfig] = None
|
||||
self,
|
||||
input: Input,
|
||||
*,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
replace: bool = False,
|
||||
) -> None:
|
||||
if config is None:
|
||||
raise ValueError(
|
||||
@ -51,7 +37,7 @@ class PutLocalVar(RunnablePassthrough):
|
||||
"therefore always receive a non-null config."
|
||||
)
|
||||
if isinstance(self.key, str):
|
||||
if self.key not in config["_locals"]:
|
||||
if self.key not in config["_locals"] or replace:
|
||||
config["_locals"][self.key] = input
|
||||
else:
|
||||
config["_locals"][self.key] += input
|
||||
@ -62,7 +48,7 @@ class PutLocalVar(RunnablePassthrough):
|
||||
f"input is expected to be of type Mapping when key is Mapping."
|
||||
)
|
||||
for input_key, put_key in self.key.items():
|
||||
if put_key not in config["_locals"]:
|
||||
if put_key not in config["_locals"] or replace:
|
||||
config["_locals"][put_key] = input[input_key]
|
||||
else:
|
||||
config["_locals"][put_key] += input[input_key]
|
||||
@ -73,24 +59,30 @@ class PutLocalVar(RunnablePassthrough):
|
||||
)
|
||||
|
||||
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input:
|
||||
self._put(input, config=config)
|
||||
self._concat_put(input, config=config, replace=True)
|
||||
return super().invoke(input, config=config)
|
||||
|
||||
async def ainvoke(
|
||||
self, input: Input, config: RunnableConfig | None = None
|
||||
self, input: Input, config: Optional[RunnableConfig] = None
|
||||
) -> Input:
|
||||
self._put(input, config=config)
|
||||
self._concat_put(input, config=config, replace=True)
|
||||
return await super().ainvoke(input, config=config)
|
||||
|
||||
def transform(
|
||||
self, input: Iterator[Input], config: RunnableConfig | None = None
|
||||
self,
|
||||
input: Iterator[Input],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> Iterator[Input]:
|
||||
for chunk in super().transform(input, config=config):
|
||||
self._concat_put(chunk, config=config)
|
||||
yield chunk
|
||||
|
||||
async def atransform(
|
||||
self, input: AsyncIterator[Input], config: RunnableConfig | None = None
|
||||
self,
|
||||
input: AsyncIterator[Input],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> AsyncIterator[Input]:
|
||||
async for chunk in super().atransform(input, config=config):
|
||||
self._concat_put(chunk, config=config)
|
||||
@ -113,19 +105,27 @@ class GetLocalVar(
|
||||
def __init__(self, key: str, **kwargs: Any) -> None:
|
||||
super().__init__(key=key, **kwargs)
|
||||
|
||||
def _get(self, full_input: Dict) -> Union[Output, Dict[str, Union[Input, Output]]]:
|
||||
def _get(
|
||||
self,
|
||||
input: Input,
|
||||
run_manager: Union[CallbackManagerForChainRun, Any],
|
||||
config: RunnableConfig,
|
||||
) -> Union[Output, Dict[str, Union[Input, Output]]]:
|
||||
if self.passthrough_key:
|
||||
return {
|
||||
self.key: full_input["locals"][self.key],
|
||||
self.passthrough_key: full_input["input"],
|
||||
self.key: config["_locals"][self.key],
|
||||
self.passthrough_key: input,
|
||||
}
|
||||
else:
|
||||
return full_input["locals"][self.key]
|
||||
return config["_locals"][self.key]
|
||||
|
||||
async def _aget(
|
||||
self, full_input: Dict
|
||||
self,
|
||||
input: Input,
|
||||
run_manager: AsyncCallbackManagerForChainRun,
|
||||
config: RunnableConfig,
|
||||
) -> Union[Output, Dict[str, Union[Input, Output]]]:
|
||||
return self._get(full_input)
|
||||
return self._get(input, run_manager, config)
|
||||
|
||||
def invoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None
|
||||
@ -136,8 +136,7 @@ class GetLocalVar(
|
||||
"therefore always receive a non-null config."
|
||||
)
|
||||
|
||||
log_input = {"input": input, "locals": config["_locals"]}
|
||||
return self._call_with_config(self._get, log_input, config)
|
||||
return self._call_with_config(self._get, input, config)
|
||||
|
||||
async def ainvoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None
|
||||
@ -148,5 +147,4 @@ class GetLocalVar(
|
||||
"therefore always receive a non-null config."
|
||||
)
|
||||
|
||||
log_input = {"input": input, "locals": config["_locals"]}
|
||||
return await self._acall_with_config(self._aget, log_input, config)
|
||||
return await self._acall_with_config(self._aget, input, config)
|
||||
|
Loading…
Reference in New Issue
Block a user