mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-13 22:59:05 +00:00
Lint
This commit is contained in:
parent
24a197f96a
commit
46f3850794
@ -52,7 +52,6 @@ from langchain.schema.runnable.utils import (
|
|||||||
from langchain.utils.aiter import atee, py_anext
|
from langchain.utils.aiter import atee, py_anext
|
||||||
from langchain.utils.iter import safetee
|
from langchain.utils.iter import safetee
|
||||||
|
|
||||||
|
|
||||||
Input = TypeVar("Input")
|
Input = TypeVar("Input")
|
||||||
# Output type should implement __concat__, as eg str, list, dict do
|
# Output type should implement __concat__, as eg str, list, dict do
|
||||||
Output = TypeVar("Output")
|
Output = TypeVar("Output")
|
||||||
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
from typing import Any, Dict, List, Optional, TypedDict
|
from typing import Any, Dict, List, Optional, TypedDict
|
||||||
|
|
||||||
from langchain.callbacks.base import Callbacks
|
from langchain.callbacks.base import Callbacks
|
||||||
from langchain.callbacks.manager import CallbackManager, AsyncCallbackManager
|
from langchain.callbacks.manager import AsyncCallbackManager, CallbackManager
|
||||||
|
|
||||||
|
|
||||||
class RunnableConfig(TypedDict, total=False):
|
class RunnableConfig(TypedDict, total=False):
|
||||||
|
@ -2,6 +2,10 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from typing import Any, AsyncIterator, Dict, Iterator, Mapping, Optional, Union
|
from typing import Any, AsyncIterator, Dict, Iterator, Mapping, Optional, Union
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import (
|
||||||
|
AsyncCallbackManagerForChainRun,
|
||||||
|
CallbackManagerForChainRun,
|
||||||
|
)
|
||||||
from langchain.load.serializable import Serializable
|
from langchain.load.serializable import Serializable
|
||||||
from langchain.schema.runnable.base import Input, Output, Runnable
|
from langchain.schema.runnable.base import Input, Output, Runnable
|
||||||
from langchain.schema.runnable.config import RunnableConfig
|
from langchain.schema.runnable.config import RunnableConfig
|
||||||
@ -20,30 +24,12 @@ class PutLocalVar(RunnablePassthrough):
|
|||||||
def __init__(self, key: Union[str, Mapping[str, str]], **kwargs: Any) -> None:
|
def __init__(self, key: Union[str, Mapping[str, str]], **kwargs: Any) -> None:
|
||||||
super().__init__(key=key, **kwargs)
|
super().__init__(key=key, **kwargs)
|
||||||
|
|
||||||
def _put(self, input: Input, *, config: Optional[RunnableConfig] = None) -> None:
|
|
||||||
if config is None:
|
|
||||||
raise ValueError(
|
|
||||||
"PutLocalVar should only be used in a RunnableSequence, and should "
|
|
||||||
"therefore always receive a non-null config."
|
|
||||||
)
|
|
||||||
if isinstance(self.key, str):
|
|
||||||
config["_locals"][self.key] = input
|
|
||||||
elif isinstance(self.key, Mapping):
|
|
||||||
if not isinstance(input, Mapping):
|
|
||||||
raise TypeError(
|
|
||||||
f"Received key of type Mapping but input of type {type(input)}. "
|
|
||||||
f"input is expected to be of type Mapping when key is Mapping."
|
|
||||||
)
|
|
||||||
for input_key, put_key in self.key.items():
|
|
||||||
config["_locals"][put_key] = input[input_key]
|
|
||||||
else:
|
|
||||||
raise TypeError(
|
|
||||||
f"`key` should be a string or Mapping[str, str], received type "
|
|
||||||
f"{(type(self.key))}."
|
|
||||||
)
|
|
||||||
|
|
||||||
def _concat_put(
|
def _concat_put(
|
||||||
self, input: Input, *, config: Optional[RunnableConfig] = None
|
self,
|
||||||
|
input: Input,
|
||||||
|
*,
|
||||||
|
config: Optional[RunnableConfig] = None,
|
||||||
|
replace: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
if config is None:
|
if config is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -51,7 +37,7 @@ class PutLocalVar(RunnablePassthrough):
|
|||||||
"therefore always receive a non-null config."
|
"therefore always receive a non-null config."
|
||||||
)
|
)
|
||||||
if isinstance(self.key, str):
|
if isinstance(self.key, str):
|
||||||
if self.key not in config["_locals"]:
|
if self.key not in config["_locals"] or replace:
|
||||||
config["_locals"][self.key] = input
|
config["_locals"][self.key] = input
|
||||||
else:
|
else:
|
||||||
config["_locals"][self.key] += input
|
config["_locals"][self.key] += input
|
||||||
@ -62,7 +48,7 @@ class PutLocalVar(RunnablePassthrough):
|
|||||||
f"input is expected to be of type Mapping when key is Mapping."
|
f"input is expected to be of type Mapping when key is Mapping."
|
||||||
)
|
)
|
||||||
for input_key, put_key in self.key.items():
|
for input_key, put_key in self.key.items():
|
||||||
if put_key not in config["_locals"]:
|
if put_key not in config["_locals"] or replace:
|
||||||
config["_locals"][put_key] = input[input_key]
|
config["_locals"][put_key] = input[input_key]
|
||||||
else:
|
else:
|
||||||
config["_locals"][put_key] += input[input_key]
|
config["_locals"][put_key] += input[input_key]
|
||||||
@ -73,24 +59,30 @@ class PutLocalVar(RunnablePassthrough):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input:
|
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input:
|
||||||
self._put(input, config=config)
|
self._concat_put(input, config=config, replace=True)
|
||||||
return super().invoke(input, config=config)
|
return super().invoke(input, config=config)
|
||||||
|
|
||||||
async def ainvoke(
|
async def ainvoke(
|
||||||
self, input: Input, config: RunnableConfig | None = None
|
self, input: Input, config: Optional[RunnableConfig] = None
|
||||||
) -> Input:
|
) -> Input:
|
||||||
self._put(input, config=config)
|
self._concat_put(input, config=config, replace=True)
|
||||||
return await super().ainvoke(input, config=config)
|
return await super().ainvoke(input, config=config)
|
||||||
|
|
||||||
def transform(
|
def transform(
|
||||||
self, input: Iterator[Input], config: RunnableConfig | None = None
|
self,
|
||||||
|
input: Iterator[Input],
|
||||||
|
config: Optional[RunnableConfig] = None,
|
||||||
|
**kwargs: Optional[Any],
|
||||||
) -> Iterator[Input]:
|
) -> Iterator[Input]:
|
||||||
for chunk in super().transform(input, config=config):
|
for chunk in super().transform(input, config=config):
|
||||||
self._concat_put(chunk, config=config)
|
self._concat_put(chunk, config=config)
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
async def atransform(
|
async def atransform(
|
||||||
self, input: AsyncIterator[Input], config: RunnableConfig | None = None
|
self,
|
||||||
|
input: AsyncIterator[Input],
|
||||||
|
config: Optional[RunnableConfig] = None,
|
||||||
|
**kwargs: Optional[Any],
|
||||||
) -> AsyncIterator[Input]:
|
) -> AsyncIterator[Input]:
|
||||||
async for chunk in super().atransform(input, config=config):
|
async for chunk in super().atransform(input, config=config):
|
||||||
self._concat_put(chunk, config=config)
|
self._concat_put(chunk, config=config)
|
||||||
@ -113,19 +105,27 @@ class GetLocalVar(
|
|||||||
def __init__(self, key: str, **kwargs: Any) -> None:
|
def __init__(self, key: str, **kwargs: Any) -> None:
|
||||||
super().__init__(key=key, **kwargs)
|
super().__init__(key=key, **kwargs)
|
||||||
|
|
||||||
def _get(self, full_input: Dict) -> Union[Output, Dict[str, Union[Input, Output]]]:
|
def _get(
|
||||||
|
self,
|
||||||
|
input: Input,
|
||||||
|
run_manager: Union[CallbackManagerForChainRun, Any],
|
||||||
|
config: RunnableConfig,
|
||||||
|
) -> Union[Output, Dict[str, Union[Input, Output]]]:
|
||||||
if self.passthrough_key:
|
if self.passthrough_key:
|
||||||
return {
|
return {
|
||||||
self.key: full_input["locals"][self.key],
|
self.key: config["_locals"][self.key],
|
||||||
self.passthrough_key: full_input["input"],
|
self.passthrough_key: input,
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
return full_input["locals"][self.key]
|
return config["_locals"][self.key]
|
||||||
|
|
||||||
async def _aget(
|
async def _aget(
|
||||||
self, full_input: Dict
|
self,
|
||||||
|
input: Input,
|
||||||
|
run_manager: AsyncCallbackManagerForChainRun,
|
||||||
|
config: RunnableConfig,
|
||||||
) -> Union[Output, Dict[str, Union[Input, Output]]]:
|
) -> Union[Output, Dict[str, Union[Input, Output]]]:
|
||||||
return self._get(full_input)
|
return self._get(input, run_manager, config)
|
||||||
|
|
||||||
def invoke(
|
def invoke(
|
||||||
self, input: Input, config: Optional[RunnableConfig] = None
|
self, input: Input, config: Optional[RunnableConfig] = None
|
||||||
@ -136,8 +136,7 @@ class GetLocalVar(
|
|||||||
"therefore always receive a non-null config."
|
"therefore always receive a non-null config."
|
||||||
)
|
)
|
||||||
|
|
||||||
log_input = {"input": input, "locals": config["_locals"]}
|
return self._call_with_config(self._get, input, config)
|
||||||
return self._call_with_config(self._get, log_input, config)
|
|
||||||
|
|
||||||
async def ainvoke(
|
async def ainvoke(
|
||||||
self, input: Input, config: Optional[RunnableConfig] = None
|
self, input: Input, config: Optional[RunnableConfig] = None
|
||||||
@ -148,5 +147,4 @@ class GetLocalVar(
|
|||||||
"therefore always receive a non-null config."
|
"therefore always receive a non-null config."
|
||||||
)
|
)
|
||||||
|
|
||||||
log_input = {"input": input, "locals": config["_locals"]}
|
return await self._acall_with_config(self._aget, input, config)
|
||||||
return await self._acall_with_config(self._aget, log_input, config)
|
|
||||||
|
Loading…
Reference in New Issue
Block a user