Mirror of https://github.com/hwchase17/langchain.git, synced 2025-07-06 21:20:33 +00:00
langchain[patch]: Use async memory in Chain when needed (#19429)
Commit 63898dbda0 (parent db7403d667)
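In short: on the async execution path, Chain previously prepared its inputs with the synchronous prep_inputs, so memory.load_memory_variables ran as a blocking call even under an event loop. This patch adds an async aprep_inputs that awaits memory.aload_memory_variables instead, and wires it into the async path. A rough sketch of the difference, simplified from the first hunk below (not the literal method bodies):

# Before: memory loaded synchronously, even inside the async code path.
inputs = self.prep_inputs(input)           # calls memory.load_memory_variables

# After (this commit): memory loaded cooperatively with await.
inputs = await self.aprep_inputs(input)    # awaits memory.aload_memory_variables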
libs/langchain/langchain/chains/base.py

@@ -181,7 +181,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
         include_run_info = kwargs.get("include_run_info", False)
         return_only_outputs = kwargs.get("return_only_outputs", False)
 
-        inputs = self.prep_inputs(input)
+        inputs = await self.aprep_inputs(input)
         callback_manager = AsyncCallbackManager.configure(
             callbacks,
             self.callbacks,
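The hunk above is the behavioral core of the patch: on the async path, memory loading is now awaited. That matters for memories backed by async I/O. Below is a minimal sketch of such a memory against the BaseMemory interface; the class name and its dict-backed storage are invented for illustration, not part of this commit:

from typing import Any, Dict, List

from langchain_core.memory import BaseMemory


class AsyncDictMemory(BaseMemory):  # hypothetical example class
    """Keeps history in a dict; serves it via the async interface."""

    store: Dict[str, str] = {}

    @property
    def memory_variables(self) -> List[str]:
        return ["history"]

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        return {"history": self.store.get("history", "")}

    async def aload_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        # With this patch, Chain.aprep_inputs awaits this method; a real
        # implementation could await a database or HTTP call here.
        return {"history": self.store.get("history", "")}

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        self.store["history"] = f"{inputs} -> {outputs}"

    def clear(self) -> None:
        self.store.clear()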
@@ -482,6 +482,30 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
             inputs = dict(inputs, **external_context)
         return inputs
 
+    async def aprep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]:
+        """Prepare chain inputs, including adding inputs from memory.
+
+        Args:
+            inputs: Dictionary of raw inputs, or single input if chain expects
+                only one param. Should contain all inputs specified in
+                `Chain.input_keys` except for inputs that will be set by the chain's
+                memory.
+
+        Returns:
+            A dictionary of all inputs, including those added by the chain's memory.
+        """
+        if not isinstance(inputs, dict):
+            _input_keys = set(self.input_keys)
+            if self.memory is not None:
+                # If there are multiple input keys, but some get set by memory so that
+                # only one is not set, we can still figure out which key it is.
+                _input_keys = _input_keys.difference(self.memory.memory_variables)
+            inputs = {list(_input_keys)[0]: inputs}
+        if self.memory is not None:
+            external_context = await self.memory.aload_memory_variables(inputs)
+            inputs = dict(inputs, **external_context)
+        return inputs
+
     @property
     def _run_output_key(self) -> str:
         if len(self.output_keys) != 1:
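A hedged illustration of the non-dict branch in the new aprep_inputs: when memory covers all but one input key, a bare value can be passed and is mapped onto the remaining key before memory variables are merged in. The chain here is a placeholder:

# Assume chain.input_keys == ["foo", "bar"] and the attached memory declares
# memory_variables == ["foo"], so "bar" is the only key memory cannot supply.
async def demo(chain: Chain) -> None:
    prepared = await chain.aprep_inputs("hello")  # bare input, not a dict
    # The bare value lands on the uncovered key and memory fills the rest:
    # prepared == {"bar": "hello", "foo": <result of aload_memory_variables>}
    print(prepared)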
libs/langchain/tests/unit_tests/chains/test_conversation.py

@@ -1,5 +1,9 @@
 """Test conversation chain and memory."""
+from typing import Any, List, Optional
+
 import pytest
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models import LLM
 from langchain_core.memory import BaseMemory
 from langchain_core.prompts.prompt import PromptTemplate
 
@@ -10,6 +14,27 @@ from langchain.memory.summary import ConversationSummaryMemory
 from tests.unit_tests.llms.fake_llm import FakeLLM
 
 
+class DummyLLM(LLM):
+    last_prompt = ""
+
+    def __init__(self, **kwargs: Any):
+        super().__init__(**kwargs)
+
+    @property
+    def _llm_type(self) -> str:
+        return "dummy"
+
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        self.last_prompt = prompt
+        return "dummy"
+
+
 def test_memory_ai_prefix() -> None:
     """Test that ai_prefix in the memory component works."""
     memory = ConversationBufferMemory(memory_key="foo", ai_prefix="Assistant")
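The point of the DummyLLM test double above is last_prompt: it captures the exact prompt the chain rendered, memory included, so the test can assert on it. A quick sanity check of the double (a sketch, not part of the diff):

llm = DummyLLM()
assert llm.invoke("hi") == "dummy"  # every call answers "dummy"
assert llm.last_prompt == "hi"      # ...and records the rendered prompt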
@@ -32,13 +57,18 @@ async def test_memory_async() -> None:
     }
 
 
-def test_conversation_chain_works() -> None:
+async def test_conversation_chain_works() -> None:
     """Test that conversation chain works in basic setting."""
-    llm = FakeLLM()
+    llm = DummyLLM()
     prompt = PromptTemplate(input_variables=["foo", "bar"], template="{foo} {bar}")
     memory = ConversationBufferMemory(memory_key="foo")
     chain = ConversationChain(llm=llm, prompt=prompt, memory=memory, input_key="bar")
-    chain.run("foo")
+    chain.run("aaa")
+    assert llm.last_prompt == " aaa"
+    chain.run("bbb")
+    assert llm.last_prompt == "Human: aaa\nAI: dummy bbb"
+    await chain.arun("ccc")
+    assert llm.last_prompt == "Human: aaa\nAI: dummy\nHuman: bbb\nAI: dummy ccc"
 
 
 def test_conversation_chain_errors_bad_prompt() -> None:
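Outside pytest, the same async path the updated test exercises with chain.arun can be driven with asyncio.run. A self-contained sketch reusing the test's setup; the import locations follow the package layout and should be treated as assumptions:

import asyncio

from langchain.chains.conversation.base import ConversationChain
from langchain.memory import ConversationBufferMemory
from langchain_core.prompts.prompt import PromptTemplate


async def main() -> None:
    llm = DummyLLM()  # the test double defined above
    prompt = PromptTemplate(input_variables=["foo", "bar"], template="{foo} {bar}")
    memory = ConversationBufferMemory(memory_key="foo")
    chain = ConversationChain(llm=llm, prompt=prompt, memory=memory, input_key="bar")
    # ainvoke awaits aprep_inputs, which awaits memory.aload_memory_variables
    result = await chain.ainvoke({"bar": "hello"})
    print(result)


asyncio.run(main())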