Compare commits

...

5 Commits

Author SHA1 Message Date
William FH
60313c3f62 Delete Untitled.ipynb 2023-07-27 12:48:50 -07:00
William Fu-Hinthorn
b9bd33ff30 pass child 2023-07-24 12:51:32 -07:00
William Fu-Hinthorn
574d752c7e thread more 2023-07-24 12:11:13 -07:00
William Fu-Hinthorn
914489a0a1 update exp 2023-07-24 11:56:11 -07:00
William Fu-Hinthorn
506da41b97 Add callbacks to memory 2023-07-24 11:49:33 -07:00
22 changed files with 105 additions and 36 deletions

View File

@@ -1,5 +1,6 @@
from typing import Any, Dict, List from typing import Any, Dict, List
from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory, get_prompt_input_key from langchain.memory.chat_memory import BaseChatMemory, get_prompt_input_key
from langchain.vectorstores.base import VectorStoreRetriever from langchain.vectorstores.base import VectorStoreRetriever
from pydantic import Field from pydantic import Field
@@ -19,10 +20,13 @@ class AutoGPTMemory(BaseChatMemory):
return get_prompt_input_key(inputs, self.memory_variables) return get_prompt_input_key(inputs, self.memory_variables)
return self.input_key return self.input_key
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]: def load_memory_variables(
self, inputs: Dict[str, Any], callbacks: Callbacks = None
) -> Dict[str, Any]:
"""Return key-value pairs given the text input to the chain."""
input_key = self._get_prompt_input_key(inputs) input_key = self._get_prompt_input_key(inputs)
query = inputs[input_key] query = inputs[input_key]
docs = self.retriever.get_relevant_documents(query) docs = self.retriever.get_relevant_documents(query, callbacks=callbacks)
return { return {
"chat_history": self.chat_memory.messages[-10:], "chat_history": self.chat_memory.messages[-10:],
"relevant_context": docs, "relevant_context": docs,

View File

@@ -3,6 +3,7 @@ import re
from datetime import datetime from datetime import datetime
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from langchain.callbacks.manager import Callbacks
from langchain.chains import LLMChain from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate from langchain.prompts import PromptTemplate
from langchain.retrievers import TimeWeightedVectorStoreRetriever from langchain.retrievers import TimeWeightedVectorStoreRetriever
@@ -222,14 +223,21 @@ class GenerativeAgentMemory(BaseMemory):
return result return result
def fetch_memories( def fetch_memories(
self, observation: str, now: Optional[datetime] = None self,
observation: str,
now: Optional[datetime] = None,
callbacks: Callbacks = None,
) -> List[Document]: ) -> List[Document]:
"""Fetch related memories.""" """Fetch related memories."""
if now is not None: if now is not None:
with mock_now(now): with mock_now(now):
return self.memory_retriever.get_relevant_documents(observation) return self.memory_retriever.get_relevant_documents(
observation, callbacks=callbacks
)
else: else:
return self.memory_retriever.get_relevant_documents(observation) return self.memory_retriever.get_relevant_documents(
observation, callbacks=callbacks
)
def format_memories_detail(self, relevant_memories: List[Document]) -> str: def format_memories_detail(self, relevant_memories: List[Document]) -> str:
content = [] content = []
@@ -260,7 +268,9 @@ class GenerativeAgentMemory(BaseMemory):
"""Input keys this memory class will load dynamically.""" """Input keys this memory class will load dynamically."""
return [] return []
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]: def load_memory_variables(
self, inputs: Dict[str, Any], callbacks: Callbacks = None
) -> Dict[str, Any]:
"""Return key-value pairs given the text input to the chain.""" """Return key-value pairs given the text input to the chain."""
queries = inputs.get(self.queries_key) queries = inputs.get(self.queries_key)
now = inputs.get(self.now_key) now = inputs.get(self.now_key)

View File

@@ -95,7 +95,7 @@ class AgentExecutorIterator(BaseAgentExecutorIterator):
@inputs.setter @inputs.setter
def inputs(self, inputs: Any) -> None: def inputs(self, inputs: Any) -> None:
self._inputs = self.agent_executor.prep_inputs(inputs) self._inputs = self.agent_executor.prep_inputs(inputs, callbacks=None)
@property @property
def callbacks(self) -> Callbacks: def callbacks(self) -> Callbacks:

View File

@@ -217,7 +217,6 @@ class Chain(Serializable, ABC):
A dict of named outputs. Should contain all outputs specified in A dict of named outputs. Should contain all outputs specified in
`Chain.output_keys`. `Chain.output_keys`.
""" """
inputs = self.prep_inputs(inputs)
callback_manager = CallbackManager.configure( callback_manager = CallbackManager.configure(
callbacks, callbacks,
self.callbacks, self.callbacks,
@@ -232,6 +231,7 @@ class Chain(Serializable, ABC):
dumpd(self), dumpd(self),
inputs, inputs,
) )
inputs = self.prep_inputs(inputs, callbacks=run_manager.get_child())
try: try:
outputs = ( outputs = (
self._call(inputs, run_manager=run_manager) self._call(inputs, run_manager=run_manager)
@@ -284,7 +284,6 @@ class Chain(Serializable, ABC):
A dict of named outputs. Should contain all outputs specified in A dict of named outputs. Should contain all outputs specified in
`Chain.output_keys`. `Chain.output_keys`.
""" """
inputs = self.prep_inputs(inputs)
callback_manager = AsyncCallbackManager.configure( callback_manager = AsyncCallbackManager.configure(
callbacks, callbacks,
self.callbacks, self.callbacks,
@@ -299,6 +298,7 @@ class Chain(Serializable, ABC):
dumpd(self), dumpd(self),
inputs, inputs,
) )
inputs = self.prep_inputs(inputs, callbacks=run_manager.get_child())
try: try:
outputs = ( outputs = (
await self._acall(inputs, run_manager=run_manager) await self._acall(inputs, run_manager=run_manager)
@@ -342,7 +342,9 @@ class Chain(Serializable, ABC):
else: else:
return {**inputs, **outputs} return {**inputs, **outputs}
def prep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]: def prep_inputs(
self, inputs: Union[Dict[str, Any], Any], *, callbacks: Callbacks
) -> Dict[str, str]:
"""Validate and prepare chain inputs, including adding inputs from memory. """Validate and prepare chain inputs, including adding inputs from memory.
Args: Args:
@@ -369,7 +371,9 @@ class Chain(Serializable, ABC):
) )
inputs = {list(_input_keys)[0]: inputs} inputs = {list(_input_keys)[0]: inputs}
if self.memory is not None: if self.memory is not None:
external_context = self.memory.load_memory_variables(inputs) external_context = self.memory.load_memory_variables(
inputs, callbacks=callbacks
)
inputs = dict(inputs, **external_context) inputs = dict(inputs, **external_context)
self._validate_inputs(inputs) self._validate_inputs(inputs)
return inputs return inputs

View File

@@ -262,11 +262,13 @@ The following is the expected answer. Use this to measure correctness:
return ["score", "reasoning"] return ["score", "reasoning"]
return ["score"] return ["score"]
def prep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]: def prep_inputs(
self, inputs: Union[Dict[str, Any], Any], *, callbacks: Callbacks
) -> Dict[str, str]:
"""Validate and prep inputs.""" """Validate and prep inputs."""
if "reference" not in inputs: if "reference" not in inputs:
inputs["reference"] = self._format_reference(inputs.get("reference")) inputs["reference"] = self._format_reference(inputs.get("reference"))
return super().prep_inputs(inputs) return super().prep_inputs(inputs, callbacks=callbacks)
def _call( def _call(
self, self,

View File

@@ -2,6 +2,7 @@ from typing import Any, Dict, List
from pydantic import Field from pydantic import Field
from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory, get_prompt_input_key from langchain.memory.chat_memory import BaseChatMemory, get_prompt_input_key
from langchain.vectorstores.base import VectorStoreRetriever from langchain.vectorstores.base import VectorStoreRetriever
@@ -20,7 +21,9 @@ class AutoGPTMemory(BaseChatMemory):
return get_prompt_input_key(inputs, self.memory_variables) return get_prompt_input_key(inputs, self.memory_variables)
return self.input_key return self.input_key
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]: def load_memory_variables(
self, inputs: Dict[str, Any], callbacks: Callbacks = None
) -> Dict[str, Any]:
input_key = self._get_prompt_input_key(inputs) input_key = self._get_prompt_input_key(inputs)
query = inputs[input_key] query = inputs[input_key]
docs = self.retriever.get_relevant_documents(query) docs = self.retriever.get_relevant_documents(query)

View File

@@ -4,6 +4,7 @@ from datetime import datetime
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from langchain import LLMChain from langchain import LLMChain
from langchain.callbacks.manager import Callbacks
from langchain.prompts import PromptTemplate from langchain.prompts import PromptTemplate
from langchain.retrievers import TimeWeightedVectorStoreRetriever from langchain.retrievers import TimeWeightedVectorStoreRetriever
from langchain.schema import BaseMemory, Document from langchain.schema import BaseMemory, Document
@@ -260,7 +261,9 @@ class GenerativeAgentMemory(BaseMemory):
"""Input keys this memory class will load dynamically.""" """Input keys this memory class will load dynamically."""
return [] return []
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]: def load_memory_variables(
self, inputs: Dict[str, Any], callbacks: Callbacks = None
) -> Dict[str, str]:
"""Return key-value pairs given the text input to the chain.""" """Return key-value pairs given the text input to the chain."""
queries = inputs.get(self.queries_key) queries = inputs.get(self.queries_key)
now = inputs.get(self.now_key) now = inputs.get(self.now_key)

View File

@@ -2,6 +2,7 @@ from typing import Any, Dict, List, Optional
from pydantic import root_validator from pydantic import root_validator
from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory, BaseMemory from langchain.memory.chat_memory import BaseChatMemory, BaseMemory
from langchain.memory.utils import get_prompt_input_key from langchain.memory.utils import get_prompt_input_key
from langchain.schema.messages import get_buffer_string from langchain.schema.messages import get_buffer_string
@@ -34,7 +35,9 @@ class ConversationBufferMemory(BaseChatMemory):
""" """
return [self.memory_key] return [self.memory_key]
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]: def load_memory_variables(
self, inputs: Dict[str, Any], callbacks: Callbacks = None
) -> Dict[str, Any]:
"""Return history buffer.""" """Return history buffer."""
return {self.memory_key: self.buffer} return {self.memory_key: self.buffer}
@@ -66,7 +69,9 @@ class ConversationStringBufferMemory(BaseMemory):
""" """
return [self.memory_key] return [self.memory_key]
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]: def load_memory_variables(
self, inputs: Dict[str, Any], callbacks: Callbacks = None
) -> Dict[str, str]:
"""Return history buffer.""" """Return history buffer."""
return {self.memory_key: self.buffer} return {self.memory_key: self.buffer}

View File

@@ -1,5 +1,6 @@
from typing import Any, Dict, List from typing import Any, Dict, List
from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema.messages import BaseMessage, get_buffer_string from langchain.schema.messages import BaseMessage, get_buffer_string
@@ -26,7 +27,9 @@ class ConversationBufferWindowMemory(BaseChatMemory):
""" """
return [self.memory_key] return [self.memory_key]
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]: def load_memory_variables(
self, inputs: Dict[str, Any], callbacks: Callbacks = None
) -> Dict[str, str]:
"""Return history buffer.""" """Return history buffer."""
buffer: Any = self.buffer[-self.k * 2 :] if self.k > 0 else [] buffer: Any = self.buffer[-self.k * 2 :] if self.k > 0 else []

View File

@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Set
from pydantic import validator from pydantic import validator
from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import BaseMemory from langchain.schema import BaseMemory
@@ -54,13 +55,15 @@ class CombinedMemory(BaseMemory):
return memory_variables return memory_variables
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]: def load_memory_variables(
self, inputs: Dict[str, Any], callbacks: Callbacks = None
) -> Dict[str, str]:
"""Load all vars from sub-memories.""" """Load all vars from sub-memories."""
memory_data: Dict[str, Any] = {} memory_data: Dict[str, Any] = {}
# Collect vars from all sub-memories # Collect vars from all sub-memories
for memory in self.memories: for memory in self.memories:
data = memory.load_memory_variables(inputs) data = memory.load_memory_variables(inputs, callbacks=callbacks)
memory_data = { memory_data = {
**memory_data, **memory_data,
**data, **data,

View File

@@ -5,6 +5,7 @@ from typing import Any, Dict, Iterable, List, Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from langchain.callbacks.manager import Callbacks
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.memory.chat_memory import BaseChatMemory from langchain.memory.chat_memory import BaseChatMemory
from langchain.memory.prompt import ( from langchain.memory.prompt import (
@@ -285,7 +286,9 @@ class ConversationEntityMemory(BaseChatMemory):
""" """
return ["entities", self.chat_history_key] return ["entities", self.chat_history_key]
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]: def load_memory_variables(
self, inputs: Dict[str, Any], callbacks: Callbacks = None
) -> Dict[str, Any]:
""" """
Returns chat history and all generated entities with summaries if available, Returns chat history and all generated entities with summaries if available,
and updates or clears the recent entity cache. and updates or clears the recent entity cache.

View File

@@ -2,6 +2,7 @@ from typing import Any, Dict, List, Type, Union
from pydantic import Field from pydantic import Field
from langchain.callbacks.manager import Callbacks
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.graphs import NetworkxEntityGraph from langchain.graphs import NetworkxEntityGraph
from langchain.graphs.networkx_graph import KnowledgeTriple, get_entities, parse_triples from langchain.graphs.networkx_graph import KnowledgeTriple, get_entities, parse_triples
@@ -34,7 +35,9 @@ class ConversationKGMemory(BaseChatMemory):
"""Number of previous utterances to include in the context.""" """Number of previous utterances to include in the context."""
memory_key: str = "history" #: :meta private: memory_key: str = "history" #: :meta private:
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]: def load_memory_variables(
self, inputs: Dict[str, Any], callbacks: Callbacks = None
) -> Dict[str, Any]:
"""Return history buffer.""" """Return history buffer."""
entities = self._get_current_entities(inputs) entities = self._get_current_entities(inputs)

View File

@@ -2,6 +2,7 @@ from typing import Any, Dict, List, Optional
import requests import requests
from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema.messages import get_buffer_string from langchain.schema.messages import get_buffer_string
@@ -64,7 +65,9 @@ class MotorheadMemory(BaseChatMemory):
if context and context != "NONE": if context and context != "NONE":
self.context = context self.context = context
def load_memory_variables(self, values: Dict[str, Any]) -> Dict[str, Any]: def load_memory_variables(
self, inputs: Dict[str, Any], callbacks: Callbacks = None
) -> Dict[str, Any]:
if self.return_messages: if self.return_messages:
return {self.memory_key: self.chat_memory.messages} return {self.memory_key: self.chat_memory.messages}
else: else:

View File

@@ -1,5 +1,6 @@
from typing import Any, Dict, List from typing import Any, Dict, List
from langchain.callbacks.manager import Callbacks
from langchain.schema import BaseMemory from langchain.schema import BaseMemory
@@ -13,9 +14,11 @@ class ReadOnlySharedMemory(BaseMemory):
"""Return memory variables.""" """Return memory variables."""
return self.memory.memory_variables return self.memory.memory_variables
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]: def load_memory_variables(
self, inputs: Dict[str, Any], callbacks: Callbacks = None
) -> Dict[str, str]:
"""Load memory variables from memory.""" """Load memory variables from memory."""
return self.memory.load_memory_variables(inputs) return self.memory.load_memory_variables(inputs, callbacks=callbacks)
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Nothing should be saved or changed""" """Nothing should be saved or changed"""

View File

@@ -1,5 +1,6 @@
from typing import Any, Dict, List from typing import Any, Dict, List
from langchain.callbacks.manager import Callbacks
from langchain.schema import BaseMemory from langchain.schema import BaseMemory
@@ -14,7 +15,9 @@ class SimpleMemory(BaseMemory):
def memory_variables(self) -> List[str]: def memory_variables(self) -> List[str]:
return list(self.memories.keys()) return list(self.memories.keys())
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]: def load_memory_variables(
self, inputs: Dict[str, Any], callbacks: Callbacks = None
) -> Dict[str, str]:
return self.memories return self.memories
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:

View File

@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Type
from pydantic import BaseModel, root_validator from pydantic import BaseModel, root_validator
from langchain.callbacks.manager import Callbacks
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.memory.chat_memory import BaseChatMemory from langchain.memory.chat_memory import BaseChatMemory
from langchain.memory.prompt import SUMMARY_PROMPT from langchain.memory.prompt import SUMMARY_PROMPT
@@ -67,7 +68,9 @@ class ConversationSummaryMemory(BaseChatMemory, SummarizerMixin):
""" """
return [self.memory_key] return [self.memory_key]
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]: def load_memory_variables(
self, inputs: Dict[str, Any], callbacks: Callbacks = None
) -> Dict[str, Any]:
"""Return history buffer.""" """Return history buffer."""
if self.return_messages: if self.return_messages:
buffer: Any = [self.summary_message_cls(content=self.buffer)] buffer: Any = [self.summary_message_cls(content=self.buffer)]

View File

@@ -2,6 +2,7 @@ from typing import Any, Dict, List
from pydantic import root_validator from pydantic import root_validator
from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory from langchain.memory.chat_memory import BaseChatMemory
from langchain.memory.summary import SummarizerMixin from langchain.memory.summary import SummarizerMixin
from langchain.schema.messages import BaseMessage, get_buffer_string from langchain.schema.messages import BaseMessage, get_buffer_string
@@ -26,7 +27,9 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin):
""" """
return [self.memory_key] return [self.memory_key]
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]: def load_memory_variables(
self, inputs: Dict[str, Any], callbacks: Callbacks = None
) -> Dict[str, Any]:
"""Return history buffer.""" """Return history buffer."""
buffer = self.buffer buffer = self.buffer
if self.moving_summary_buffer != "": if self.moving_summary_buffer != "":

View File

@@ -1,5 +1,6 @@
from typing import Any, Dict, List from typing import Any, Dict, List
from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema.language_model import BaseLanguageModel from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.messages import BaseMessage, get_buffer_string from langchain.schema.messages import BaseMessage, get_buffer_string
@@ -27,7 +28,9 @@ class ConversationTokenBufferMemory(BaseChatMemory):
""" """
return [self.memory_key] return [self.memory_key]
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]: def load_memory_variables(
self, inputs: Dict[str, Any], callbacks: Callbacks = None
) -> Dict[str, Any]:
"""Return history buffer.""" """Return history buffer."""
buffer: Any = self.buffer buffer: Any = self.buffer
if self.return_messages: if self.return_messages:

View File

@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional, Sequence, Union
from pydantic import Field from pydantic import Field
from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseMemory from langchain.memory.chat_memory import BaseMemory
from langchain.memory.utils import get_prompt_input_key from langchain.memory.utils import get_prompt_input_key
from langchain.schema import Document from langchain.schema import Document
@@ -40,12 +41,12 @@ class VectorStoreRetrieverMemory(BaseMemory):
return self.input_key return self.input_key
def load_memory_variables( def load_memory_variables(
self, inputs: Dict[str, Any] self, inputs: Dict[str, Any], callbacks: Callbacks = None
) -> Dict[str, Union[List[Document], str]]: ) -> Dict[str, Union[List[Document], str]]:
"""Return history buffer.""" """Return history buffer."""
input_key = self._get_prompt_input_key(inputs) input_key = self._get_prompt_input_key(inputs)
query = inputs[input_key] query = inputs[input_key]
docs = self.retriever.get_relevant_documents(query) docs = self.retriever.get_relevant_documents(query, callbacks=callbacks)
result: Union[List[Document], str] result: Union[List[Document], str]
if not self.return_docs: if not self.return_docs:
result = "\n".join([doc.page_content for doc in docs]) result = "\n".join([doc.page_content for doc in docs])

View File

@@ -94,7 +94,9 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
Returns: Returns:
List of relevant documents List of relevant documents
""" """
inputs = self.llm_chain.prep_inputs({"query": query}) inputs = self.llm_chain.prep_inputs(
{"query": query}, callbacks=run_manager.get_child()
)
structured_query = cast( structured_query = cast(
StructuredQuery, StructuredQuery,
self.llm_chain.predict_and_parse( self.llm_chain.predict_and_parse(

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict, List from typing import Any, Dict, List
import langchain
from langchain.load.serializable import Serializable from langchain.load.serializable import Serializable
from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage
@@ -27,7 +28,7 @@ class BaseMemory(Serializable, ABC):
def memory_variables(self) -> List[str]: def memory_variables(self) -> List[str]:
return list(self.memories.keys()) return list(self.memories.keys())
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]: def load_memory_variables(self, inputs: Dict[str, Any], callbacks: Callbacks = None) -> Dict[str, Any]:
return self.memories return self.memories
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
@@ -48,7 +49,11 @@ class BaseMemory(Serializable, ABC):
"""The string keys this memory class will add to chain inputs.""" """The string keys this memory class will add to chain inputs."""
@abstractmethod @abstractmethod
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]: def load_memory_variables(
self,
inputs: Dict[str, Any],
callbacks: "langchain.callbacks.manager.Callbacks" = None,
) -> Dict[str, Any]:
"""Return key-value pairs given the text input to the chain.""" """Return key-value pairs given the text input to the chain."""
@abstractmethod @abstractmethod

View File

@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional
import pytest import pytest
from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.callbacks.manager import CallbackManagerForChainRun, Callbacks
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.schema import RUN_KEY, BaseMemory from langchain.schema import RUN_KEY, BaseMemory
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
@@ -18,7 +18,7 @@ class FakeMemory(BaseMemory):
return ["baz"] return ["baz"]
def load_memory_variables( def load_memory_variables(
self, inputs: Optional[Dict[str, Any]] = None self, inputs: Optional[Dict[str, Any]] = None, callbacks: Callbacks = None
) -> Dict[str, str]: ) -> Dict[str, str]:
"""Return baz variable.""" """Return baz variable."""
return {"baz": "foo"} return {"baz": "foo"}