mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-06 05:08:20 +00:00
langchain[patch]: Use async memory in Chain when needed (#19429)
This commit is contained in:
parent
db7403d667
commit
63898dbda0
@ -181,7 +181,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
|
||||
include_run_info = kwargs.get("include_run_info", False)
|
||||
return_only_outputs = kwargs.get("return_only_outputs", False)
|
||||
|
||||
inputs = self.prep_inputs(input)
|
||||
inputs = await self.aprep_inputs(input)
|
||||
callback_manager = AsyncCallbackManager.configure(
|
||||
callbacks,
|
||||
self.callbacks,
|
||||
@ -482,6 +482,30 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
|
||||
inputs = dict(inputs, **external_context)
|
||||
return inputs
|
||||
|
||||
async def aprep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]:
|
||||
"""Prepare chain inputs, including adding inputs from memory.
|
||||
|
||||
Args:
|
||||
inputs: Dictionary of raw inputs, or single input if chain expects
|
||||
only one param. Should contain all inputs specified in
|
||||
`Chain.input_keys` except for inputs that will be set by the chain's
|
||||
memory.
|
||||
|
||||
Returns:
|
||||
A dictionary of all inputs, including those added by the chain's memory.
|
||||
"""
|
||||
if not isinstance(inputs, dict):
|
||||
_input_keys = set(self.input_keys)
|
||||
if self.memory is not None:
|
||||
# If there are multiple input keys, but some get set by memory so that
|
||||
# only one is not set, we can still figure out which key it is.
|
||||
_input_keys = _input_keys.difference(self.memory.memory_variables)
|
||||
inputs = {list(_input_keys)[0]: inputs}
|
||||
if self.memory is not None:
|
||||
external_context = await self.memory.aload_memory_variables(inputs)
|
||||
inputs = dict(inputs, **external_context)
|
||||
return inputs
|
||||
|
||||
@property
|
||||
def _run_output_key(self) -> str:
|
||||
if len(self.output_keys) != 1:
|
||||
|
@ -1,5 +1,9 @@
|
||||
"""Test conversation chain and memory."""
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import pytest
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models import LLM
|
||||
from langchain_core.memory import BaseMemory
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
|
||||
@ -10,6 +14,27 @@ from langchain.memory.summary import ConversationSummaryMemory
|
||||
from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||
|
||||
|
||||
class DummyLLM(LLM):
|
||||
last_prompt = ""
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "dummy"
|
||||
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
self.last_prompt = prompt
|
||||
return "dummy"
|
||||
|
||||
|
||||
def test_memory_ai_prefix() -> None:
|
||||
"""Test that ai_prefix in the memory component works."""
|
||||
memory = ConversationBufferMemory(memory_key="foo", ai_prefix="Assistant")
|
||||
@ -32,13 +57,18 @@ async def test_memory_async() -> None:
|
||||
}
|
||||
|
||||
|
||||
def test_conversation_chain_works() -> None:
|
||||
async def test_conversation_chain_works() -> None:
|
||||
"""Test that conversation chain works in basic setting."""
|
||||
llm = FakeLLM()
|
||||
llm = DummyLLM()
|
||||
prompt = PromptTemplate(input_variables=["foo", "bar"], template="{foo} {bar}")
|
||||
memory = ConversationBufferMemory(memory_key="foo")
|
||||
chain = ConversationChain(llm=llm, prompt=prompt, memory=memory, input_key="bar")
|
||||
chain.run("foo")
|
||||
chain.run("aaa")
|
||||
assert llm.last_prompt == " aaa"
|
||||
chain.run("bbb")
|
||||
assert llm.last_prompt == "Human: aaa\nAI: dummy bbb"
|
||||
await chain.arun("ccc")
|
||||
assert llm.last_prompt == "Human: aaa\nAI: dummy\nHuman: bbb\nAI: dummy ccc"
|
||||
|
||||
|
||||
def test_conversation_chain_errors_bad_prompt() -> None:
|
||||
|
Loading…
Reference in New Issue
Block a user