Mirror of https://github.com/hwchase17/langchain.git
Synced 2026-01-22 20:59:05 +00:00

Compare commits: open-swe/5 ... wfh/memory (5 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 60313c3f62 | |
| | b9bd33ff30 | |
| | 574d752c7e | |
| | 914489a0a1 | |
| | 506da41b97 | |
@@ -1,5 +1,6 @@
 from typing import Any, Dict, List
 
+from langchain.callbacks.manager import Callbacks
 from langchain.memory.chat_memory import BaseChatMemory, get_prompt_input_key
 from langchain.vectorstores.base import VectorStoreRetriever
 from pydantic import Field
@@ -19,10 +20,13 @@ class AutoGPTMemory(BaseChatMemory):
             return get_prompt_input_key(inputs, self.memory_variables)
         return self.input_key
 
-    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+    def load_memory_variables(
+        self, inputs: Dict[str, Any], callbacks: Callbacks = None
+    ) -> Dict[str, Any]:
+        """Return key-value pairs given the text input to the chain."""
         input_key = self._get_prompt_input_key(inputs)
         query = inputs[input_key]
-        docs = self.retriever.get_relevant_documents(query)
+        docs = self.retriever.get_relevant_documents(query, callbacks=callbacks)
         return {
             "chat_history": self.chat_memory.messages[-10:],
             "relevant_context": docs,
@@ -3,6 +3,7 @@ import re
 from datetime import datetime
 from typing import Any, Dict, List, Optional
 
+from langchain.callbacks.manager import Callbacks
 from langchain.chains import LLMChain
 from langchain.prompts import PromptTemplate
 from langchain.retrievers import TimeWeightedVectorStoreRetriever
@@ -222,14 +223,21 @@ class GenerativeAgentMemory(BaseMemory):
         return result
 
     def fetch_memories(
-        self, observation: str, now: Optional[datetime] = None
+        self,
+        observation: str,
+        now: Optional[datetime] = None,
+        callbacks: Callbacks = None,
     ) -> List[Document]:
         """Fetch related memories."""
         if now is not None:
             with mock_now(now):
-                return self.memory_retriever.get_relevant_documents(observation)
+                return self.memory_retriever.get_relevant_documents(
+                    observation, callbacks=callbacks
+                )
         else:
-            return self.memory_retriever.get_relevant_documents(observation)
+            return self.memory_retriever.get_relevant_documents(
+                observation, callbacks=callbacks
+            )
 
     def format_memories_detail(self, relevant_memories: List[Document]) -> str:
         content = []
@@ -260,7 +268,9 @@ class GenerativeAgentMemory(BaseMemory):
         """Input keys this memory class will load dynamically."""
         return []
 
-    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
+    def load_memory_variables(
+        self, inputs: Dict[str, Any], callbacks: Callbacks = None
+    ) -> Dict[str, Any]:
         """Return key-value pairs given the text input to the chain."""
         queries = inputs.get(self.queries_key)
         now = inputs.get(self.now_key)
@@ -95,7 +95,7 @@ class AgentExecutorIterator(BaseAgentExecutorIterator):
 
     @inputs.setter
     def inputs(self, inputs: Any) -> None:
-        self._inputs = self.agent_executor.prep_inputs(inputs)
+        self._inputs = self.agent_executor.prep_inputs(inputs, callbacks=None)
 
     @property
    def callbacks(self) -> Callbacks:
@@ -217,7 +217,6 @@ class Chain(Serializable, ABC):
             A dict of named outputs. Should contain all outputs specified in
                 `Chain.output_keys`.
         """
-        inputs = self.prep_inputs(inputs)
         callback_manager = CallbackManager.configure(
             callbacks,
             self.callbacks,
@@ -232,6 +231,7 @@ class Chain(Serializable, ABC):
             dumpd(self),
             inputs,
         )
+        inputs = self.prep_inputs(inputs, callbacks=run_manager.get_child())
         try:
             outputs = (
                 self._call(inputs, run_manager=run_manager)
@@ -284,7 +284,6 @@ class Chain(Serializable, ABC):
             A dict of named outputs. Should contain all outputs specified in
                 `Chain.output_keys`.
         """
-        inputs = self.prep_inputs(inputs)
         callback_manager = AsyncCallbackManager.configure(
             callbacks,
             self.callbacks,
@@ -299,6 +298,7 @@ class Chain(Serializable, ABC):
             dumpd(self),
             inputs,
         )
+        inputs = self.prep_inputs(inputs, callbacks=run_manager.get_child())
         try:
             outputs = (
                 await self._acall(inputs, run_manager=run_manager)
@@ -342,7 +342,9 @@ class Chain(Serializable, ABC):
         else:
             return {**inputs, **outputs}
 
-    def prep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]:
+    def prep_inputs(
+        self, inputs: Union[Dict[str, Any], Any], *, callbacks: Callbacks
+    ) -> Dict[str, str]:
         """Validate and prepare chain inputs, including adding inputs from memory.
 
         Args:
@@ -369,7 +371,9 @@ class Chain(Serializable, ABC):
             )
             inputs = {list(_input_keys)[0]: inputs}
         if self.memory is not None:
-            external_context = self.memory.load_memory_variables(inputs)
+            external_context = self.memory.load_memory_variables(
+                inputs, callbacks=callbacks
+            )
             inputs = dict(inputs, **external_context)
         self._validate_inputs(inputs)
         return inputs
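Note: the Chain hunks above change where memory is consulted. `prep_inputs` now runs after the run manager is created and receives `callbacks=run_manager.get_child()`, which it forwards to `memory.load_memory_variables`. Below is a minimal standalone sketch of that flow in plain Python; the `ToyMemory`/`ToyChain` names and the callback type are invented for illustration and are not langchain APIs.

```python
from typing import Any, Callable, Dict, List, Optional

# Hypothetical stand-ins for illustration only; these are not langchain classes.
Callbacks = Optional[List[Callable[[str], None]]]


class ToyMemory:
    """Mimics the new load_memory_variables(inputs, callbacks=None) contract."""

    def __init__(self, retrieve: Callable[[str, Callbacks], List[str]]) -> None:
        self.retrieve = retrieve

    def load_memory_variables(
        self, inputs: Dict[str, Any], callbacks: Callbacks = None
    ) -> Dict[str, Any]:
        # Forward the chain's callbacks so the lookup is observed by the same handlers.
        docs = self.retrieve(inputs["input"], callbacks)
        return {"relevant_context": docs}


class ToyChain:
    def __init__(self, memory: ToyMemory) -> None:
        self.memory = memory

    def prep_inputs(
        self, inputs: Dict[str, Any], *, callbacks: Callbacks
    ) -> Dict[str, Any]:
        # Mirrors prep_inputs(inputs, callbacks=run_manager.get_child()) in the hunk above.
        external = self.memory.load_memory_variables(inputs, callbacks=callbacks)
        return {**inputs, **external}


def retrieve(query: str, callbacks: Callbacks) -> List[str]:
    for cb in callbacks or []:
        cb(f"retriever ran for: {query}")
    return [f"doc about {query}"]


if __name__ == "__main__":
    events: List[str] = []
    chain = ToyChain(ToyMemory(retrieve))
    print(chain.prep_inputs({"input": "callbacks"}, callbacks=[events.append]))
    print(events)  # the memory's retrieval shows up in the caller's callback trace
```

The intended effect, as far as these hunks show, is that memory lookups (and any retriever calls they trigger) are reported to the same callback handlers and run tree as the rest of the chain execution.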
@@ -262,11 +262,13 @@ The following is the expected answer. Use this to measure correctness:
             return ["score", "reasoning"]
         return ["score"]
 
-    def prep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]:
+    def prep_inputs(
+        self, inputs: Union[Dict[str, Any], Any], *, callbacks: Callbacks
+    ) -> Dict[str, str]:
         """Validate and prep inputs."""
         if "reference" not in inputs:
             inputs["reference"] = self._format_reference(inputs.get("reference"))
-        return super().prep_inputs(inputs)
+        return super().prep_inputs(inputs, callbacks=callbacks)
 
     def _call(
         self,
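The evaluation-chain hunk above shows the knock-on effect for subclasses that override `prep_inputs`: because the base signature gained a keyword-only `callbacks` parameter, overrides have to accept it and forward it to `super().prep_inputs(...)`. A minimal sketch of the pattern in plain Python; the class names are made up, and only the signature shape is taken from the diff.

```python
from typing import Any, Dict, List, Optional

Callbacks = Optional[List[Any]]  # illustrative alias, not the langchain type


class BaseChainLike:
    def prep_inputs(
        self, inputs: Dict[str, Any], *, callbacks: Callbacks
    ) -> Dict[str, Any]:
        return dict(inputs)


class EvalChainLike(BaseChainLike):
    def prep_inputs(
        self, inputs: Dict[str, Any], *, callbacks: Callbacks
    ) -> Dict[str, Any]:
        # Subclass-specific preprocessing, then forward the keyword-only callbacks,
        # as in `return super().prep_inputs(inputs, callbacks=callbacks)` above.
        inputs.setdefault("reference", "")
        return super().prep_inputs(inputs, callbacks=callbacks)


if __name__ == "__main__":
    print(EvalChainLike().prep_inputs({"query": "q"}, callbacks=None))
```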
@@ -2,6 +2,7 @@ from typing import Any, Dict, List
 
 from pydantic import Field
 
+from langchain.callbacks.manager import Callbacks
 from langchain.memory.chat_memory import BaseChatMemory, get_prompt_input_key
 from langchain.vectorstores.base import VectorStoreRetriever
 
@@ -20,7 +21,9 @@ class AutoGPTMemory(BaseChatMemory):
             return get_prompt_input_key(inputs, self.memory_variables)
         return self.input_key
 
-    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+    def load_memory_variables(
+        self, inputs: Dict[str, Any], callbacks: Callbacks = None
+    ) -> Dict[str, Any]:
         input_key = self._get_prompt_input_key(inputs)
         query = inputs[input_key]
         docs = self.retriever.get_relevant_documents(query)
@@ -4,6 +4,7 @@ from datetime import datetime
 from typing import Any, Dict, List, Optional
 
 from langchain import LLMChain
+from langchain.callbacks.manager import Callbacks
 from langchain.prompts import PromptTemplate
 from langchain.retrievers import TimeWeightedVectorStoreRetriever
 from langchain.schema import BaseMemory, Document
@@ -260,7 +261,9 @@ class GenerativeAgentMemory(BaseMemory):
         """Input keys this memory class will load dynamically."""
         return []
 
-    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
+    def load_memory_variables(
+        self, inputs: Dict[str, Any], callbacks: Callbacks = None
+    ) -> Dict[str, str]:
         """Return key-value pairs given the text input to the chain."""
         queries = inputs.get(self.queries_key)
         now = inputs.get(self.now_key)
@@ -2,6 +2,7 @@ from typing import Any, Dict, List, Optional
 
 from pydantic import root_validator
 
+from langchain.callbacks.manager import Callbacks
 from langchain.memory.chat_memory import BaseChatMemory, BaseMemory
 from langchain.memory.utils import get_prompt_input_key
 from langchain.schema.messages import get_buffer_string
@@ -34,7 +35,9 @@ class ConversationBufferMemory(BaseChatMemory):
         """
         return [self.memory_key]
 
-    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+    def load_memory_variables(
+        self, inputs: Dict[str, Any], callbacks: Callbacks = None
+    ) -> Dict[str, Any]:
         """Return history buffer."""
         return {self.memory_key: self.buffer}
 
@@ -66,7 +69,9 @@ class ConversationStringBufferMemory(BaseMemory):
         """
         return [self.memory_key]
 
-    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
+    def load_memory_variables(
+        self, inputs: Dict[str, Any], callbacks: Callbacks = None
+    ) -> Dict[str, str]:
         """Return history buffer."""
         return {self.memory_key: self.buffer}
@@ -1,5 +1,6 @@
 from typing import Any, Dict, List
 
+from langchain.callbacks.manager import Callbacks
 from langchain.memory.chat_memory import BaseChatMemory
 from langchain.schema.messages import BaseMessage, get_buffer_string
 
@@ -26,7 +27,9 @@ class ConversationBufferWindowMemory(BaseChatMemory):
         """
         return [self.memory_key]
 
-    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
+    def load_memory_variables(
+        self, inputs: Dict[str, Any], callbacks: Callbacks = None
+    ) -> Dict[str, str]:
         """Return history buffer."""
 
         buffer: Any = self.buffer[-self.k * 2 :] if self.k > 0 else []
@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Set
 
 from pydantic import validator
 
+from langchain.callbacks.manager import Callbacks
 from langchain.memory.chat_memory import BaseChatMemory
 from langchain.schema import BaseMemory
 
@@ -54,13 +55,15 @@ class CombinedMemory(BaseMemory):
 
         return memory_variables
 
-    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
+    def load_memory_variables(
+        self, inputs: Dict[str, Any], callbacks: Callbacks = None
+    ) -> Dict[str, str]:
         """Load all vars from sub-memories."""
         memory_data: Dict[str, Any] = {}
 
         # Collect vars from all sub-memories
         for memory in self.memories:
-            data = memory.load_memory_variables(inputs)
+            data = memory.load_memory_variables(inputs, callbacks=callbacks)
             memory_data = {
                 **memory_data,
                 **data,
@@ -5,6 +5,7 @@ from typing import Any, Dict, Iterable, List, Optional
 
 from pydantic import BaseModel, Field
 
+from langchain.callbacks.manager import Callbacks
 from langchain.chains.llm import LLMChain
 from langchain.memory.chat_memory import BaseChatMemory
 from langchain.memory.prompt import (
@@ -285,7 +286,9 @@ class ConversationEntityMemory(BaseChatMemory):
         """
         return ["entities", self.chat_history_key]
 
-    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+    def load_memory_variables(
+        self, inputs: Dict[str, Any], callbacks: Callbacks = None
+    ) -> Dict[str, Any]:
         """
         Returns chat history and all generated entities with summaries if available,
         and updates or clears the recent entity cache.
@@ -2,6 +2,7 @@ from typing import Any, Dict, List, Type, Union
 
 from pydantic import Field
 
+from langchain.callbacks.manager import Callbacks
 from langchain.chains.llm import LLMChain
 from langchain.graphs import NetworkxEntityGraph
 from langchain.graphs.networkx_graph import KnowledgeTriple, get_entities, parse_triples
@@ -34,7 +35,9 @@ class ConversationKGMemory(BaseChatMemory):
     """Number of previous utterances to include in the context."""
     memory_key: str = "history" #: :meta private:
 
-    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+    def load_memory_variables(
+        self, inputs: Dict[str, Any], callbacks: Callbacks = None
+    ) -> Dict[str, Any]:
         """Return history buffer."""
         entities = self._get_current_entities(inputs)
@@ -2,6 +2,7 @@ from typing import Any, Dict, List, Optional
 
 import requests
 
+from langchain.callbacks.manager import Callbacks
 from langchain.memory.chat_memory import BaseChatMemory
 from langchain.schema.messages import get_buffer_string
 
@@ -64,7 +65,9 @@ class MotorheadMemory(BaseChatMemory):
         if context and context != "NONE":
             self.context = context
 
-    def load_memory_variables(self, values: Dict[str, Any]) -> Dict[str, Any]:
+    def load_memory_variables(
+        self, inputs: Dict[str, Any], callbacks: Callbacks = None
+    ) -> Dict[str, Any]:
         if self.return_messages:
             return {self.memory_key: self.chat_memory.messages}
         else:
@@ -1,5 +1,6 @@
 from typing import Any, Dict, List
 
+from langchain.callbacks.manager import Callbacks
 from langchain.schema import BaseMemory
 
 
@@ -13,9 +14,11 @@ class ReadOnlySharedMemory(BaseMemory):
         """Return memory variables."""
         return self.memory.memory_variables
 
-    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
+    def load_memory_variables(
+        self, inputs: Dict[str, Any], callbacks: Callbacks = None
+    ) -> Dict[str, str]:
         """Load memory variables from memory."""
-        return self.memory.load_memory_variables(inputs)
+        return self.memory.load_memory_variables(inputs, callbacks=callbacks)
 
     def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
         """Nothing should be saved or changed"""
@@ -1,5 +1,6 @@
 from typing import Any, Dict, List
 
+from langchain.callbacks.manager import Callbacks
 from langchain.schema import BaseMemory
 
 
@@ -14,7 +15,9 @@ class SimpleMemory(BaseMemory):
     def memory_variables(self) -> List[str]:
         return list(self.memories.keys())
 
-    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
+    def load_memory_variables(
+        self, inputs: Dict[str, Any], callbacks: Callbacks = None
+    ) -> Dict[str, str]:
         return self.memories
 
     def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Type
 
 from pydantic import BaseModel, root_validator
 
+from langchain.callbacks.manager import Callbacks
 from langchain.chains.llm import LLMChain
 from langchain.memory.chat_memory import BaseChatMemory
 from langchain.memory.prompt import SUMMARY_PROMPT
@@ -67,7 +68,9 @@ class ConversationSummaryMemory(BaseChatMemory, SummarizerMixin):
         """
         return [self.memory_key]
 
-    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+    def load_memory_variables(
+        self, inputs: Dict[str, Any], callbacks: Callbacks = None
+    ) -> Dict[str, Any]:
         """Return history buffer."""
         if self.return_messages:
             buffer: Any = [self.summary_message_cls(content=self.buffer)]
@@ -2,6 +2,7 @@ from typing import Any, Dict, List
 
 from pydantic import root_validator
 
+from langchain.callbacks.manager import Callbacks
 from langchain.memory.chat_memory import BaseChatMemory
 from langchain.memory.summary import SummarizerMixin
 from langchain.schema.messages import BaseMessage, get_buffer_string
@@ -26,7 +27,9 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin):
         """
         return [self.memory_key]
 
-    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+    def load_memory_variables(
+        self, inputs: Dict[str, Any], callbacks: Callbacks = None
+    ) -> Dict[str, Any]:
         """Return history buffer."""
         buffer = self.buffer
         if self.moving_summary_buffer != "":
@@ -1,5 +1,6 @@
 from typing import Any, Dict, List
 
+from langchain.callbacks.manager import Callbacks
 from langchain.memory.chat_memory import BaseChatMemory
 from langchain.schema.language_model import BaseLanguageModel
 from langchain.schema.messages import BaseMessage, get_buffer_string
@@ -27,7 +28,9 @@ class ConversationTokenBufferMemory(BaseChatMemory):
         """
         return [self.memory_key]
 
-    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+    def load_memory_variables(
+        self, inputs: Dict[str, Any], callbacks: Callbacks = None
+    ) -> Dict[str, Any]:
         """Return history buffer."""
         buffer: Any = self.buffer
         if self.return_messages:
@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional, Sequence, Union
 
 from pydantic import Field
 
+from langchain.callbacks.manager import Callbacks
 from langchain.memory.chat_memory import BaseMemory
 from langchain.memory.utils import get_prompt_input_key
 from langchain.schema import Document
@@ -40,12 +41,12 @@ class VectorStoreRetrieverMemory(BaseMemory):
         return self.input_key
 
     def load_memory_variables(
-        self, inputs: Dict[str, Any]
+        self, inputs: Dict[str, Any], callbacks: Callbacks = None
     ) -> Dict[str, Union[List[Document], str]]:
         """Return history buffer."""
         input_key = self._get_prompt_input_key(inputs)
         query = inputs[input_key]
-        docs = self.retriever.get_relevant_documents(query)
+        docs = self.retriever.get_relevant_documents(query, callbacks=callbacks)
         result: Union[List[Document], str]
         if not self.return_docs:
             result = "\n".join([doc.page_content for doc in docs])
@@ -94,7 +94,9 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
         Returns:
             List of relevant documents
         """
-        inputs = self.llm_chain.prep_inputs({"query": query})
+        inputs = self.llm_chain.prep_inputs(
+            {"query": query}, callbacks=run_manager.get_child()
+        )
         structured_query = cast(
             StructuredQuery,
             self.llm_chain.predict_and_parse(
@@ -3,6 +3,7 @@ from __future__ import annotations
 from abc import ABC, abstractmethod
 from typing import Any, Dict, List
 
+import langchain
 from langchain.load.serializable import Serializable
 from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage
 
@@ -27,7 +28,7 @@ class BaseMemory(Serializable, ABC):
     def memory_variables(self) -> List[str]:
         return list(self.memories.keys())
 
-    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
+    def load_memory_variables(self, inputs: Dict[str, Any], callbacks: Callbacks = None) -> Dict[str, Any]:
         return self.memories
 
     def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
@@ -48,7 +49,11 @@
         """The string keys this memory class will add to chain inputs."""
 
     @abstractmethod
-    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+    def load_memory_variables(
+        self,
+        inputs: Dict[str, Any],
+        callbacks: "langchain.callbacks.manager.Callbacks" = None,
+    ) -> Dict[str, Any]:
         """Return key-value pairs given the text input to the chain."""
 
     @abstractmethod
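With the abstract `BaseMemory.load_memory_variables` now declaring a `callbacks` parameter, custom memory implementations are expected to accept it even when they ignore it; the `FakeMemory` in the test hunk below does exactly that. The following is a hedged sketch of a custom memory written against this branch: the class itself is invented for illustration, the import paths are the ones visible in these diffs, and the set of other abstract methods (`memory_variables`, `save_context`, `clear`) is an assumption about the base class.

```python
from typing import Any, Dict, List

from langchain.callbacks.manager import Callbacks
from langchain.schema import BaseMemory


class StaticContextMemory(BaseMemory):
    """Illustrative memory that always returns a fixed context string."""

    context: str = "hello"

    @property
    def memory_variables(self) -> List[str]:
        return ["context"]

    def load_memory_variables(
        self, inputs: Dict[str, Any], callbacks: Callbacks = None
    ) -> Dict[str, Any]:
        # No retrieval happens here, so the callbacks are accepted but unused.
        return {"context": self.context}

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        pass

    def clear(self) -> None:
        pass
```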
@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional
 
 import pytest
 
-from langchain.callbacks.manager import CallbackManagerForChainRun
+from langchain.callbacks.manager import CallbackManagerForChainRun, Callbacks
 from langchain.chains.base import Chain
 from langchain.schema import RUN_KEY, BaseMemory
 from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
@@ -18,7 +18,7 @@ class FakeMemory(BaseMemory):
         return ["baz"]
 
     def load_memory_variables(
-        self, inputs: Optional[Dict[str, Any]] = None
+        self, inputs: Optional[Dict[str, Any]] = None, callbacks: Callbacks = None
     ) -> Dict[str, str]:
         """Return baz variable."""
         return {"baz": "foo"}