From eb0134fbb3c728fb5c9180384276315f6318497b Mon Sep 17 00:00:00 2001
From: Bagatur
Date: Wed, 9 Aug 2023 14:13:06 -0700
Subject: [PATCH] rfc

---
 libs/langchain/langchain/chat_models/base.py |  8 ++-
 libs/langchain/langchain/llms/base.py        | 15 +++---
 libs/langchain/langchain/schema/retriever.py |  8 ++-
 libs/langchain/langchain/schema/runnable.py  | 57 +++++++++++++++++---
 4 files changed, 69 insertions(+), 19 deletions(-)

diff --git a/libs/langchain/langchain/chat_models/base.py b/libs/langchain/langchain/chat_models/base.py
index b06b99f99d7..3d343274f5a 100644
--- a/libs/langchain/langchain/chat_models/base.py
+++ b/libs/langchain/langchain/chat_models/base.py
@@ -103,12 +103,14 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
         stop: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> BaseMessageChunk:
+        _config: Dict[str, Any] = dict(config or {})
+        _config.pop("_locals", None)
         return cast(
             BaseMessageChunk,
             cast(
                 ChatGeneration,
                 self.generate_prompt(
-                    [self._convert_input(input)], stop=stop, **(config or {}), **kwargs
+                    [self._convert_input(input)], stop=stop, **_config, **kwargs
                 ).generations[0][0],
             ).message,
         )
@@ -127,8 +129,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
                 None, partial(self.invoke, input, config, stop=stop, **kwargs)
             )
 
+        _config: Dict[str, Any] = dict(config or {})
+        _config.pop("_locals", None)
         llm_result = await self.agenerate_prompt(
-            [self._convert_input(input)], stop=stop, **(config or {}), **kwargs
+            [self._convert_input(input)], stop=stop, **_config, **kwargs
         )
         return cast(
             BaseMessageChunk, cast(ChatGeneration, llm_result.generations[0][0]).message
diff --git a/libs/langchain/langchain/llms/base.py b/libs/langchain/langchain/llms/base.py
index 7da494de78b..04422124743 100644
--- a/libs/langchain/langchain/llms/base.py
+++ b/libs/langchain/langchain/llms/base.py
@@ -219,13 +219,12 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         stop: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> str:
-        return (
-            self.generate_prompt(
-                [self._convert_input(input)], stop=stop, **(config or {}), **kwargs
-            )
-            .generations[0][0]
-            .text
+        _config: Dict[str, Any] = dict(config or {})
+        _config.pop("_locals", None)
+        result = self.generate_prompt(
+            [self._convert_input(input)], stop=stop, **_config, **kwargs
         )
+        return result.generations[0][0].text
 
     async def ainvoke(
         self,
@@ -241,8 +240,10 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                 None, partial(self.invoke, input, config, stop=stop, **kwargs)
             )
 
+        _config: Dict[str, Any] = dict(config or {})
+        _config.pop("_locals", None)
         llm_result = await self.agenerate_prompt(
-            [self._convert_input(input)], stop=stop, **(config or {}), **kwargs
+            [self._convert_input(input)], stop=stop, **_config, **kwargs
         )
         return llm_result.generations[0][0].text
 
diff --git a/libs/langchain/langchain/schema/retriever.py b/libs/langchain/langchain/schema/retriever.py
index 9df3e7a1389..538ae1ed1c1 100644
--- a/libs/langchain/langchain/schema/retriever.py
+++ b/libs/langchain/langchain/schema/retriever.py
@@ -107,7 +107,9 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
     def invoke(
         self, input: str, config: Optional[RunnableConfig] = None
     ) -> List[Document]:
-        return self.get_relevant_documents(input, **(config or {}))
+        _config: Dict[str, Any] = dict(config or {})
+        _config.pop("_locals", None)
+        return self.get_relevant_documents(input, **_config)
 
     async def ainvoke(
         self, input: str, config: Optional[RunnableConfig] = None
@@ -116,7 +118,9 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
             # If the retriever doesn't implement async, use default implementation
             return await super().ainvoke(input, config)
 
-        return await self.aget_relevant_documents(input, **(config or {}))
+        _config: Dict[str, Any] = dict(config or {})
+        _config.pop("_locals", None)
+        return await self.aget_relevant_documents(input, **_config)
 
     @abstractmethod
     def _get_relevant_documents(
diff --git a/libs/langchain/langchain/schema/runnable.py b/libs/langchain/langchain/schema/runnable.py
index eebd5a96aa7..47679c8883c 100644
--- a/libs/langchain/langchain/schema/runnable.py
+++ b/libs/langchain/langchain/schema/runnable.py
@@ -674,6 +674,46 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
         raise first_error
 
 
+class PutLocalVar(Serializable, Runnable[Input, Input]):
+    key: Union[str, Dict[str, str]]
+
+    def __init__(self, key: Union[str, Dict[str, str]], **kwargs: Any) -> None:
+        super().__init__(key=key, **kwargs)
+
+    def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input:
+        if config is None:
+            raise ValueError(
+                "PutLocalVar should only be used in a RunnableSequence, and should "
+                "therefore always receive a non-null config."
+            )
+        if isinstance(self.key, str):
+            config["_locals"][self.key] = input
+        else:
+            if not isinstance(input, Mapping):
+                raise ValueError("Expected a Mapping input when key is a dict.")
+            for get_key, put_key in self.key.items():
+                config["_locals"][put_key] = input[get_key]
+        return self._call_with_config(lambda x: x, input, config)
+
+
+class GetLocalVar(Serializable, Runnable[str, Any]):
+    key: str
+    passthrough_key: Optional[str] = None
+
+    def __init__(self, key: str, **kwargs: Any) -> None:
+        super().__init__(key=key, **kwargs)
+
+    def invoke(self, input: str, config: Optional[RunnableConfig] = None) -> Any:
+        if config is None:
+            raise ValueError(
+                "GetLocalVar should only be used in a RunnableSequence, and should "
+                "therefore always receive a non-null config."
+            )
+        if self.passthrough_key is not None:
+            return {self.key: config["_locals"][self.key], self.passthrough_key: input}
+        return config["_locals"][self.key]
+
+
 class RunnableSequence(Serializable, Runnable[Input, Output]):
     """
     A sequence of runnables, where the output of each is the input of the next.
@@ -749,11 +789,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
         try:
             callbacks = run_manager.get_child()
             for step in self.steps:
-                input = step.invoke(
-                    input,
-                    # mark each step as a child run
-                    _patch_config(config, callbacks),
-                )
+                # mark each step as a child run
+                step_config = _patch_config(config, callbacks)
+                input = step.invoke(input, step_config)
         # finish the root run
         except (KeyboardInterrupt, Exception) as e:
             run_manager.on_chain_error(e)
@@ -1401,11 +1439,14 @@ class RouterRunnable(
 
 
 def _patch_config(
-    config: RunnableConfig, callback_manager: BaseCallbackManager, _locals: Optional[Dict[str, Any]] = None
+    config: RunnableConfig,
+    callback_manager: BaseCallbackManager,
+    _locals: Optional[Dict[str, Any]] = None,
 ) -> RunnableConfig:
-    config = deepcopy(config)
+    config = config.copy()
     config["callbacks"] = callback_manager
-    config["_locals"] = _locals or {}
+    if _locals is not None:
+        config["_locals"] = _locals
     return config
 
 
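For reviewers, a minimal usage sketch of the proposed `_locals` mechanism. This is an illustration only, not part of the patch: `fake_retriever` is a stand-in runnable, and the example seeds `config["_locals"]` explicitly, since with the `_patch_config` change above an empty locals dict is no longer injected automatically.

```python
from langchain.schema.runnable import GetLocalVar, PutLocalVar, RunnableLambda

# Stand-in for a real retriever: any Runnable that transforms its input.
fake_retriever = RunnableLambda(lambda q: [f"doc about {q}"])

chain = (
    # Stash the raw input under config["_locals"]["question"];
    # PutLocalVar passes the input through unchanged.
    PutLocalVar("question")
    | fake_retriever
    # Recombine the stored question with this step's input:
    # {"question": <stored input>, "docs": <retriever output>}.
    | GetLocalVar("question", passthrough_key="docs")
)

# Seed _locals at the top level; _patch_config's shallow copy means every
# step sees (and can mutate) this same dict.
result = chain.invoke("task decomposition", {"_locals": {}})
# Expected: {"question": "task decomposition",
#            "docs": ["doc about task decomposition"]}
```

Note the design consequence exercised here: because `_patch_config` now does a shallow `config.copy()` rather than `deepcopy`, the `_locals` mapping is shared by reference across all steps of the sequence, which is what lets a `PutLocalVar` early in the chain be visible to a `GetLocalVar` later on. The caller (or some seeding step not shown in this patch) must supply the initial `_locals` dict.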