From 46f3850794f5fc14477d5545c6d1edd6bbfeca1a Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Fri, 18 Aug 2023 10:25:41 +0100 Subject: [PATCH] Lint --- .../langchain/schema/runnable/base.py | 1 - .../langchain/schema/runnable/config.py | 2 +- .../langchain/schema/runnable/locals.py | 78 +++++++++---------- 3 files changed, 39 insertions(+), 42 deletions(-) diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index c0caa6d9a20..1ca853174a2 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -52,7 +52,6 @@ from langchain.schema.runnable.utils import ( from langchain.utils.aiter import atee, py_anext from langchain.utils.iter import safetee - Input = TypeVar("Input") # Output type should implement __concat__, as eg str, list, dict do Output = TypeVar("Output") diff --git a/libs/langchain/langchain/schema/runnable/config.py b/libs/langchain/langchain/schema/runnable/config.py index cd620077e1f..716fc361161 100644 --- a/libs/langchain/langchain/schema/runnable/config.py +++ b/libs/langchain/langchain/schema/runnable/config.py @@ -3,7 +3,7 @@ from __future__ import annotations from typing import Any, Dict, List, Optional, TypedDict from langchain.callbacks.base import Callbacks -from langchain.callbacks.manager import CallbackManager, AsyncCallbackManager +from langchain.callbacks.manager import AsyncCallbackManager, CallbackManager class RunnableConfig(TypedDict, total=False): diff --git a/libs/langchain/langchain/schema/runnable/locals.py b/libs/langchain/langchain/schema/runnable/locals.py index 5061dbf38c1..6d668059edf 100644 --- a/libs/langchain/langchain/schema/runnable/locals.py +++ b/libs/langchain/langchain/schema/runnable/locals.py @@ -2,6 +2,10 @@ from __future__ import annotations from typing import Any, AsyncIterator, Dict, Iterator, Mapping, Optional, Union +from langchain.callbacks.manager import ( + AsyncCallbackManagerForChainRun, + CallbackManagerForChainRun, +) from langchain.load.serializable import Serializable from langchain.schema.runnable.base import Input, Output, Runnable from langchain.schema.runnable.config import RunnableConfig @@ -20,30 +24,12 @@ class PutLocalVar(RunnablePassthrough): def __init__(self, key: Union[str, Mapping[str, str]], **kwargs: Any) -> None: super().__init__(key=key, **kwargs) - def _put(self, input: Input, *, config: Optional[RunnableConfig] = None) -> None: - if config is None: - raise ValueError( - "PutLocalVar should only be used in a RunnableSequence, and should " - "therefore always receive a non-null config." - ) - if isinstance(self.key, str): - config["_locals"][self.key] = input - elif isinstance(self.key, Mapping): - if not isinstance(input, Mapping): - raise TypeError( - f"Received key of type Mapping but input of type {type(input)}. " - f"input is expected to be of type Mapping when key is Mapping." - ) - for input_key, put_key in self.key.items(): - config["_locals"][put_key] = input[input_key] - else: - raise TypeError( - f"`key` should be a string or Mapping[str, str], received type " - f"{(type(self.key))}." - ) - def _concat_put( - self, input: Input, *, config: Optional[RunnableConfig] = None + self, + input: Input, + *, + config: Optional[RunnableConfig] = None, + replace: bool = False, ) -> None: if config is None: raise ValueError( @@ -51,7 +37,7 @@ class PutLocalVar(RunnablePassthrough): "therefore always receive a non-null config." ) if isinstance(self.key, str): - if self.key not in config["_locals"]: + if self.key not in config["_locals"] or replace: config["_locals"][self.key] = input else: config["_locals"][self.key] += input @@ -62,7 +48,7 @@ class PutLocalVar(RunnablePassthrough): f"input is expected to be of type Mapping when key is Mapping." ) for input_key, put_key in self.key.items(): - if put_key not in config["_locals"]: + if put_key not in config["_locals"] or replace: config["_locals"][put_key] = input[input_key] else: config["_locals"][put_key] += input[input_key] @@ -73,24 +59,30 @@ class PutLocalVar(RunnablePassthrough): ) def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input: - self._put(input, config=config) + self._concat_put(input, config=config, replace=True) return super().invoke(input, config=config) async def ainvoke( - self, input: Input, config: RunnableConfig | None = None + self, input: Input, config: Optional[RunnableConfig] = None ) -> Input: - self._put(input, config=config) + self._concat_put(input, config=config, replace=True) return await super().ainvoke(input, config=config) def transform( - self, input: Iterator[Input], config: RunnableConfig | None = None + self, + input: Iterator[Input], + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], ) -> Iterator[Input]: for chunk in super().transform(input, config=config): self._concat_put(chunk, config=config) yield chunk async def atransform( - self, input: AsyncIterator[Input], config: RunnableConfig | None = None + self, + input: AsyncIterator[Input], + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], ) -> AsyncIterator[Input]: async for chunk in super().atransform(input, config=config): self._concat_put(chunk, config=config) @@ -113,19 +105,27 @@ class GetLocalVar( def __init__(self, key: str, **kwargs: Any) -> None: super().__init__(key=key, **kwargs) - def _get(self, full_input: Dict) -> Union[Output, Dict[str, Union[Input, Output]]]: + def _get( + self, + input: Input, + run_manager: Union[CallbackManagerForChainRun, Any], + config: RunnableConfig, + ) -> Union[Output, Dict[str, Union[Input, Output]]]: if self.passthrough_key: return { - self.key: full_input["locals"][self.key], - self.passthrough_key: full_input["input"], + self.key: config["_locals"][self.key], + self.passthrough_key: input, } else: - return full_input["locals"][self.key] + return config["_locals"][self.key] async def _aget( - self, full_input: Dict + self, + input: Input, + run_manager: AsyncCallbackManagerForChainRun, + config: RunnableConfig, ) -> Union[Output, Dict[str, Union[Input, Output]]]: - return self._get(full_input) + return self._get(input, run_manager, config) def invoke( self, input: Input, config: Optional[RunnableConfig] = None @@ -136,8 +136,7 @@ class GetLocalVar( "therefore always receive a non-null config." ) - log_input = {"input": input, "locals": config["_locals"]} - return self._call_with_config(self._get, log_input, config) + return self._call_with_config(self._get, input, config) async def ainvoke( self, input: Input, config: Optional[RunnableConfig] = None @@ -148,5 +147,4 @@ class GetLocalVar( "therefore always receive a non-null config." ) - log_input = {"input": input, "locals": config["_locals"]} - return await self._acall_with_config(self._aget, log_input, config) + return await self._acall_with_config(self._aget, input, config)