langchain[patch]: Use async memory in Chain when needed (#19429)

This commit is contained in:
Christophe Bornet 2024-03-25 07:49:00 +01:00 committed by GitHub
parent db7403d667
commit 63898dbda0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 58 additions and 4 deletions

View File

@ -181,7 +181,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
include_run_info = kwargs.get("include_run_info", False) include_run_info = kwargs.get("include_run_info", False)
return_only_outputs = kwargs.get("return_only_outputs", False) return_only_outputs = kwargs.get("return_only_outputs", False)
inputs = self.prep_inputs(input) inputs = await self.aprep_inputs(input)
callback_manager = AsyncCallbackManager.configure( callback_manager = AsyncCallbackManager.configure(
callbacks, callbacks,
self.callbacks, self.callbacks,
@ -482,6 +482,30 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
inputs = dict(inputs, **external_context) inputs = dict(inputs, **external_context)
return inputs return inputs
async def aprep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]:
"""Prepare chain inputs, including adding inputs from memory.
Args:
inputs: Dictionary of raw inputs, or single input if chain expects
only one param. Should contain all inputs specified in
`Chain.input_keys` except for inputs that will be set by the chain's
memory.
Returns:
A dictionary of all inputs, including those added by the chain's memory.
"""
if not isinstance(inputs, dict):
_input_keys = set(self.input_keys)
if self.memory is not None:
# If there are multiple input keys, but some get set by memory so that
# only one is not set, we can still figure out which key it is.
_input_keys = _input_keys.difference(self.memory.memory_variables)
inputs = {list(_input_keys)[0]: inputs}
if self.memory is not None:
external_context = await self.memory.aload_memory_variables(inputs)
inputs = dict(inputs, **external_context)
return inputs
@property @property
def _run_output_key(self) -> str: def _run_output_key(self) -> str:
if len(self.output_keys) != 1: if len(self.output_keys) != 1:

View File

@ -1,5 +1,9 @@
"""Test conversation chain and memory.""" """Test conversation chain and memory."""
from typing import Any, List, Optional
import pytest import pytest
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import LLM
from langchain_core.memory import BaseMemory from langchain_core.memory import BaseMemory
from langchain_core.prompts.prompt import PromptTemplate from langchain_core.prompts.prompt import PromptTemplate
@ -10,6 +14,27 @@ from langchain.memory.summary import ConversationSummaryMemory
from tests.unit_tests.llms.fake_llm import FakeLLM from tests.unit_tests.llms.fake_llm import FakeLLM
class DummyLLM(LLM):
last_prompt = ""
def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
@property
def _llm_type(self) -> str:
return "dummy"
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
self.last_prompt = prompt
return "dummy"
def test_memory_ai_prefix() -> None: def test_memory_ai_prefix() -> None:
"""Test that ai_prefix in the memory component works.""" """Test that ai_prefix in the memory component works."""
memory = ConversationBufferMemory(memory_key="foo", ai_prefix="Assistant") memory = ConversationBufferMemory(memory_key="foo", ai_prefix="Assistant")
@ -32,13 +57,18 @@ async def test_memory_async() -> None:
} }
def test_conversation_chain_works() -> None: async def test_conversation_chain_works() -> None:
"""Test that conversation chain works in basic setting.""" """Test that conversation chain works in basic setting."""
llm = FakeLLM() llm = DummyLLM()
prompt = PromptTemplate(input_variables=["foo", "bar"], template="{foo} {bar}") prompt = PromptTemplate(input_variables=["foo", "bar"], template="{foo} {bar}")
memory = ConversationBufferMemory(memory_key="foo") memory = ConversationBufferMemory(memory_key="foo")
chain = ConversationChain(llm=llm, prompt=prompt, memory=memory, input_key="bar") chain = ConversationChain(llm=llm, prompt=prompt, memory=memory, input_key="bar")
chain.run("foo") chain.run("aaa")
assert llm.last_prompt == " aaa"
chain.run("bbb")
assert llm.last_prompt == "Human: aaa\nAI: dummy bbb"
await chain.arun("ccc")
assert llm.last_prompt == "Human: aaa\nAI: dummy\nHuman: bbb\nAI: dummy ccc"
def test_conversation_chain_errors_bad_prompt() -> None: def test_conversation_chain_errors_bad_prompt() -> None: