From 63898dbda0f8e8b46703577f5b56a7d002878aad Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Mon, 25 Mar 2024 07:49:00 +0100 Subject: [PATCH] langchain[patch]: Use async memory in Chain when needed (#19429) --- libs/langchain/langchain/chains/base.py | 26 +++++++++++++- .../unit_tests/chains/test_conversation.py | 36 +++++++++++++++++-- 2 files changed, 58 insertions(+), 4 deletions(-) diff --git a/libs/langchain/langchain/chains/base.py b/libs/langchain/langchain/chains/base.py index 2f15d1fcc87..aab2891b67e 100644 --- a/libs/langchain/langchain/chains/base.py +++ b/libs/langchain/langchain/chains/base.py @@ -181,7 +181,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC): include_run_info = kwargs.get("include_run_info", False) return_only_outputs = kwargs.get("return_only_outputs", False) - inputs = self.prep_inputs(input) + inputs = await self.aprep_inputs(input) callback_manager = AsyncCallbackManager.configure( callbacks, self.callbacks, @@ -482,6 +482,30 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC): inputs = dict(inputs, **external_context) return inputs + async def aprep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]: + """Prepare chain inputs, including adding inputs from memory. + + Args: + inputs: Dictionary of raw inputs, or single input if chain expects + only one param. Should contain all inputs specified in + `Chain.input_keys` except for inputs that will be set by the chain's + memory. + + Returns: + A dictionary of all inputs, including those added by the chain's memory. + """ + if not isinstance(inputs, dict): + _input_keys = set(self.input_keys) + if self.memory is not None: + # If there are multiple input keys, but some get set by memory so that + # only one is not set, we can still figure out which key it is. + _input_keys = _input_keys.difference(self.memory.memory_variables) + inputs = {list(_input_keys)[0]: inputs} + if self.memory is not None: + external_context = await self.memory.aload_memory_variables(inputs) + inputs = dict(inputs, **external_context) + return inputs + @property def _run_output_key(self) -> str: if len(self.output_keys) != 1: diff --git a/libs/langchain/tests/unit_tests/chains/test_conversation.py b/libs/langchain/tests/unit_tests/chains/test_conversation.py index d00b1e4bc6e..23f06748f99 100644 --- a/libs/langchain/tests/unit_tests/chains/test_conversation.py +++ b/libs/langchain/tests/unit_tests/chains/test_conversation.py @@ -1,5 +1,9 @@ """Test conversation chain and memory.""" +from typing import Any, List, Optional + import pytest +from langchain_core.callbacks import CallbackManagerForLLMRun +from langchain_core.language_models import LLM from langchain_core.memory import BaseMemory from langchain_core.prompts.prompt import PromptTemplate @@ -10,6 +14,27 @@ from langchain.memory.summary import ConversationSummaryMemory from tests.unit_tests.llms.fake_llm import FakeLLM +class DummyLLM(LLM): + last_prompt = "" + + def __init__(self, **kwargs: Any): + super().__init__(**kwargs) + + @property + def _llm_type(self) -> str: + return "dummy" + + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> str: + self.last_prompt = prompt + return "dummy" + + def test_memory_ai_prefix() -> None: """Test that ai_prefix in the memory component works.""" memory = ConversationBufferMemory(memory_key="foo", ai_prefix="Assistant") @@ -32,13 +57,18 @@ async def test_memory_async() -> None: } -def test_conversation_chain_works() -> None: +async def test_conversation_chain_works() -> None: """Test that conversation chain works in basic setting.""" - llm = FakeLLM() + llm = DummyLLM() prompt = PromptTemplate(input_variables=["foo", "bar"], template="{foo} {bar}") memory = ConversationBufferMemory(memory_key="foo") chain = ConversationChain(llm=llm, prompt=prompt, memory=memory, input_key="bar") - chain.run("foo") + chain.run("aaa") + assert llm.last_prompt == " aaa" + chain.run("bbb") + assert llm.last_prompt == "Human: aaa\nAI: dummy bbb" + await chain.arun("ccc") + assert llm.last_prompt == "Human: aaa\nAI: dummy\nHuman: bbb\nAI: dummy ccc" def test_conversation_chain_errors_bad_prompt() -> None: