Compare commits

...

5 Commits

Author SHA1 Message Date
William FH
60313c3f62 Delete Untitled.ipynb 2023-07-27 12:48:50 -07:00
William Fu-Hinthorn
b9bd33ff30 pass child 2023-07-24 12:51:32 -07:00
William Fu-Hinthorn
574d752c7e thread more 2023-07-24 12:11:13 -07:00
William Fu-Hinthorn
914489a0a1 update exp 2023-07-24 11:56:11 -07:00
William Fu-Hinthorn
506da41b97 Add callbacks to memory 2023-07-24 11:49:33 -07:00
22 changed files with 105 additions and 36 deletions

View File

@@ -1,5 +1,6 @@
from typing import Any, Dict, List
from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory, get_prompt_input_key
from langchain.vectorstores.base import VectorStoreRetriever
from pydantic import Field
@@ -19,10 +20,13 @@ class AutoGPTMemory(BaseChatMemory):
return get_prompt_input_key(inputs, self.memory_variables)
return self.input_key
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
def load_memory_variables(
self, inputs: Dict[str, Any], callbacks: Callbacks = None
) -> Dict[str, Any]:
"""Return key-value pairs given the text input to the chain."""
input_key = self._get_prompt_input_key(inputs)
query = inputs[input_key]
docs = self.retriever.get_relevant_documents(query)
docs = self.retriever.get_relevant_documents(query, callbacks=callbacks)
return {
"chat_history": self.chat_memory.messages[-10:],
"relevant_context": docs,

View File

@@ -3,6 +3,7 @@ import re
from datetime import datetime
from typing import Any, Dict, List, Optional
from langchain.callbacks.manager import Callbacks
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.retrievers import TimeWeightedVectorStoreRetriever
@@ -222,14 +223,21 @@ class GenerativeAgentMemory(BaseMemory):
return result
def fetch_memories(
self, observation: str, now: Optional[datetime] = None
self,
observation: str,
now: Optional[datetime] = None,
callbacks: Callbacks = None,
) -> List[Document]:
"""Fetch related memories."""
if now is not None:
with mock_now(now):
return self.memory_retriever.get_relevant_documents(observation)
return self.memory_retriever.get_relevant_documents(
observation, callbacks=callbacks
)
else:
return self.memory_retriever.get_relevant_documents(observation)
return self.memory_retriever.get_relevant_documents(
observation, callbacks=callbacks
)
def format_memories_detail(self, relevant_memories: List[Document]) -> str:
content = []
@@ -260,7 +268,9 @@ class GenerativeAgentMemory(BaseMemory):
"""Input keys this memory class will load dynamically."""
return []
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
def load_memory_variables(
self, inputs: Dict[str, Any], callbacks: Callbacks = None
) -> Dict[str, Any]:
"""Return key-value pairs given the text input to the chain."""
queries = inputs.get(self.queries_key)
now = inputs.get(self.now_key)

View File

@@ -95,7 +95,7 @@ class AgentExecutorIterator(BaseAgentExecutorIterator):
@inputs.setter
def inputs(self, inputs: Any) -> None:
self._inputs = self.agent_executor.prep_inputs(inputs)
self._inputs = self.agent_executor.prep_inputs(inputs, callbacks=None)
@property
def callbacks(self) -> Callbacks:

View File

@@ -217,7 +217,6 @@ class Chain(Serializable, ABC):
A dict of named outputs. Should contain all outputs specified in
`Chain.output_keys`.
"""
inputs = self.prep_inputs(inputs)
callback_manager = CallbackManager.configure(
callbacks,
self.callbacks,
@@ -232,6 +231,7 @@ class Chain(Serializable, ABC):
dumpd(self),
inputs,
)
inputs = self.prep_inputs(inputs, callbacks=run_manager.get_child())
try:
outputs = (
self._call(inputs, run_manager=run_manager)
@@ -284,7 +284,6 @@ class Chain(Serializable, ABC):
A dict of named outputs. Should contain all outputs specified in
`Chain.output_keys`.
"""
inputs = self.prep_inputs(inputs)
callback_manager = AsyncCallbackManager.configure(
callbacks,
self.callbacks,
@@ -299,6 +298,7 @@ class Chain(Serializable, ABC):
dumpd(self),
inputs,
)
inputs = self.prep_inputs(inputs, callbacks=run_manager.get_child())
try:
outputs = (
await self._acall(inputs, run_manager=run_manager)
@@ -342,7 +342,9 @@ class Chain(Serializable, ABC):
else:
return {**inputs, **outputs}
def prep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]:
def prep_inputs(
self, inputs: Union[Dict[str, Any], Any], *, callbacks: Callbacks
) -> Dict[str, str]:
"""Validate and prepare chain inputs, including adding inputs from memory.
Args:
@@ -369,7 +371,9 @@ class Chain(Serializable, ABC):
)
inputs = {list(_input_keys)[0]: inputs}
if self.memory is not None:
external_context = self.memory.load_memory_variables(inputs)
external_context = self.memory.load_memory_variables(
inputs, callbacks=callbacks
)
inputs = dict(inputs, **external_context)
self._validate_inputs(inputs)
return inputs

View File

@@ -262,11 +262,13 @@ The following is the expected answer. Use this to measure correctness:
return ["score", "reasoning"]
return ["score"]
def prep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]:
def prep_inputs(
self, inputs: Union[Dict[str, Any], Any], *, callbacks: Callbacks
) -> Dict[str, str]:
"""Validate and prep inputs."""
if "reference" not in inputs:
inputs["reference"] = self._format_reference(inputs.get("reference"))
return super().prep_inputs(inputs)
return super().prep_inputs(inputs, callbacks=callbacks)
def _call(
self,

View File

@@ -2,6 +2,7 @@ from typing import Any, Dict, List
from pydantic import Field
from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory, get_prompt_input_key
from langchain.vectorstores.base import VectorStoreRetriever
@@ -20,7 +21,9 @@ class AutoGPTMemory(BaseChatMemory):
return get_prompt_input_key(inputs, self.memory_variables)
return self.input_key
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
def load_memory_variables(
self, inputs: Dict[str, Any], callbacks: Callbacks = None
) -> Dict[str, Any]:
input_key = self._get_prompt_input_key(inputs)
query = inputs[input_key]
docs = self.retriever.get_relevant_documents(query)

View File

@@ -4,6 +4,7 @@ from datetime import datetime
from typing import Any, Dict, List, Optional
from langchain import LLMChain
from langchain.callbacks.manager import Callbacks
from langchain.prompts import PromptTemplate
from langchain.retrievers import TimeWeightedVectorStoreRetriever
from langchain.schema import BaseMemory, Document
@@ -260,7 +261,9 @@ class GenerativeAgentMemory(BaseMemory):
"""Input keys this memory class will load dynamically."""
return []
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
def load_memory_variables(
self, inputs: Dict[str, Any], callbacks: Callbacks = None
) -> Dict[str, str]:
"""Return key-value pairs given the text input to the chain."""
queries = inputs.get(self.queries_key)
now = inputs.get(self.now_key)

View File

@@ -2,6 +2,7 @@ from typing import Any, Dict, List, Optional
from pydantic import root_validator
from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory, BaseMemory
from langchain.memory.utils import get_prompt_input_key
from langchain.schema.messages import get_buffer_string
@@ -34,7 +35,9 @@ class ConversationBufferMemory(BaseChatMemory):
"""
return [self.memory_key]
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
def load_memory_variables(
self, inputs: Dict[str, Any], callbacks: Callbacks = None
) -> Dict[str, Any]:
"""Return history buffer."""
return {self.memory_key: self.buffer}
@@ -66,7 +69,9 @@ class ConversationStringBufferMemory(BaseMemory):
"""
return [self.memory_key]
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
def load_memory_variables(
self, inputs: Dict[str, Any], callbacks: Callbacks = None
) -> Dict[str, str]:
"""Return history buffer."""
return {self.memory_key: self.buffer}

View File

@@ -1,5 +1,6 @@
from typing import Any, Dict, List
from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema.messages import BaseMessage, get_buffer_string
@@ -26,7 +27,9 @@ class ConversationBufferWindowMemory(BaseChatMemory):
"""
return [self.memory_key]
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
def load_memory_variables(
self, inputs: Dict[str, Any], callbacks: Callbacks = None
) -> Dict[str, str]:
"""Return history buffer."""
buffer: Any = self.buffer[-self.k * 2 :] if self.k > 0 else []

View File

@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Set
from pydantic import validator
from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import BaseMemory
@@ -54,13 +55,15 @@ class CombinedMemory(BaseMemory):
return memory_variables
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
def load_memory_variables(
self, inputs: Dict[str, Any], callbacks: Callbacks = None
) -> Dict[str, str]:
"""Load all vars from sub-memories."""
memory_data: Dict[str, Any] = {}
# Collect vars from all sub-memories
for memory in self.memories:
data = memory.load_memory_variables(inputs)
data = memory.load_memory_variables(inputs, callbacks=callbacks)
memory_data = {
**memory_data,
**data,

View File

@@ -5,6 +5,7 @@ from typing import Any, Dict, Iterable, List, Optional
from pydantic import BaseModel, Field
from langchain.callbacks.manager import Callbacks
from langchain.chains.llm import LLMChain
from langchain.memory.chat_memory import BaseChatMemory
from langchain.memory.prompt import (
@@ -285,7 +286,9 @@ class ConversationEntityMemory(BaseChatMemory):
"""
return ["entities", self.chat_history_key]
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
def load_memory_variables(
self, inputs: Dict[str, Any], callbacks: Callbacks = None
) -> Dict[str, Any]:
"""
Returns chat history and all generated entities with summaries if available,
and updates or clears the recent entity cache.

View File

@@ -2,6 +2,7 @@ from typing import Any, Dict, List, Type, Union
from pydantic import Field
from langchain.callbacks.manager import Callbacks
from langchain.chains.llm import LLMChain
from langchain.graphs import NetworkxEntityGraph
from langchain.graphs.networkx_graph import KnowledgeTriple, get_entities, parse_triples
@@ -34,7 +35,9 @@ class ConversationKGMemory(BaseChatMemory):
"""Number of previous utterances to include in the context."""
memory_key: str = "history" #: :meta private:
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
def load_memory_variables(
self, inputs: Dict[str, Any], callbacks: Callbacks = None
) -> Dict[str, Any]:
"""Return history buffer."""
entities = self._get_current_entities(inputs)

View File

@@ -2,6 +2,7 @@ from typing import Any, Dict, List, Optional
import requests
from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema.messages import get_buffer_string
@@ -64,7 +65,9 @@ class MotorheadMemory(BaseChatMemory):
if context and context != "NONE":
self.context = context
def load_memory_variables(self, values: Dict[str, Any]) -> Dict[str, Any]:
def load_memory_variables(
self, inputs: Dict[str, Any], callbacks: Callbacks = None
) -> Dict[str, Any]:
if self.return_messages:
return {self.memory_key: self.chat_memory.messages}
else:

View File

@@ -1,5 +1,6 @@
from typing import Any, Dict, List
from langchain.callbacks.manager import Callbacks
from langchain.schema import BaseMemory
@@ -13,9 +14,11 @@ class ReadOnlySharedMemory(BaseMemory):
"""Return memory variables."""
return self.memory.memory_variables
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
def load_memory_variables(
self, inputs: Dict[str, Any], callbacks: Callbacks = None
) -> Dict[str, str]:
"""Load memory variables from memory."""
return self.memory.load_memory_variables(inputs)
return self.memory.load_memory_variables(inputs, callbacks=callbacks)
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Nothing should be saved or changed"""

View File

@@ -1,5 +1,6 @@
from typing import Any, Dict, List
from langchain.callbacks.manager import Callbacks
from langchain.schema import BaseMemory
@@ -14,7 +15,9 @@ class SimpleMemory(BaseMemory):
def memory_variables(self) -> List[str]:
return list(self.memories.keys())
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
def load_memory_variables(
self, inputs: Dict[str, Any], callbacks: Callbacks = None
) -> Dict[str, str]:
return self.memories
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:

View File

@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Type
from pydantic import BaseModel, root_validator
from langchain.callbacks.manager import Callbacks
from langchain.chains.llm import LLMChain
from langchain.memory.chat_memory import BaseChatMemory
from langchain.memory.prompt import SUMMARY_PROMPT
@@ -67,7 +68,9 @@ class ConversationSummaryMemory(BaseChatMemory, SummarizerMixin):
"""
return [self.memory_key]
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
def load_memory_variables(
self, inputs: Dict[str, Any], callbacks: Callbacks = None
) -> Dict[str, Any]:
"""Return history buffer."""
if self.return_messages:
buffer: Any = [self.summary_message_cls(content=self.buffer)]

View File

@@ -2,6 +2,7 @@ from typing import Any, Dict, List
from pydantic import root_validator
from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory
from langchain.memory.summary import SummarizerMixin
from langchain.schema.messages import BaseMessage, get_buffer_string
@@ -26,7 +27,9 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin):
"""
return [self.memory_key]
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
def load_memory_variables(
self, inputs: Dict[str, Any], callbacks: Callbacks = None
) -> Dict[str, Any]:
"""Return history buffer."""
buffer = self.buffer
if self.moving_summary_buffer != "":

View File

@@ -1,5 +1,6 @@
from typing import Any, Dict, List
from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.messages import BaseMessage, get_buffer_string
@@ -27,7 +28,9 @@ class ConversationTokenBufferMemory(BaseChatMemory):
"""
return [self.memory_key]
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
def load_memory_variables(
self, inputs: Dict[str, Any], callbacks: Callbacks = None
) -> Dict[str, Any]:
"""Return history buffer."""
buffer: Any = self.buffer
if self.return_messages:

View File

@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional, Sequence, Union
from pydantic import Field
from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseMemory
from langchain.memory.utils import get_prompt_input_key
from langchain.schema import Document
@@ -40,12 +41,12 @@ class VectorStoreRetrieverMemory(BaseMemory):
return self.input_key
def load_memory_variables(
self, inputs: Dict[str, Any]
self, inputs: Dict[str, Any], callbacks: Callbacks = None
) -> Dict[str, Union[List[Document], str]]:
"""Return history buffer."""
input_key = self._get_prompt_input_key(inputs)
query = inputs[input_key]
docs = self.retriever.get_relevant_documents(query)
docs = self.retriever.get_relevant_documents(query, callbacks=callbacks)
result: Union[List[Document], str]
if not self.return_docs:
result = "\n".join([doc.page_content for doc in docs])

View File

@@ -94,7 +94,9 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
Returns:
List of relevant documents
"""
inputs = self.llm_chain.prep_inputs({"query": query})
inputs = self.llm_chain.prep_inputs(
{"query": query}, callbacks=run_manager.get_child()
)
structured_query = cast(
StructuredQuery,
self.llm_chain.predict_and_parse(

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Dict, List
import langchain
from langchain.load.serializable import Serializable
from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage
@@ -27,7 +28,7 @@ class BaseMemory(Serializable, ABC):
def memory_variables(self) -> List[str]:
return list(self.memories.keys())
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
def load_memory_variables(self, inputs: Dict[str, Any], callbacks: Callbacks = None) -> Dict[str, Any]:
return self.memories
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
@@ -48,7 +49,11 @@ class BaseMemory(Serializable, ABC):
"""The string keys this memory class will add to chain inputs."""
@abstractmethod
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
def load_memory_variables(
self,
inputs: Dict[str, Any],
callbacks: "langchain.callbacks.manager.Callbacks" = None,
) -> Dict[str, Any]:
"""Return key-value pairs given the text input to the chain."""
@abstractmethod

View File

@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional
import pytest
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.callbacks.manager import CallbackManagerForChainRun, Callbacks
from langchain.chains.base import Chain
from langchain.schema import RUN_KEY, BaseMemory
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
@@ -18,7 +18,7 @@ class FakeMemory(BaseMemory):
return ["baz"]
def load_memory_variables(
self, inputs: Optional[Dict[str, Any]] = None
self, inputs: Optional[Dict[str, Any]] = None, callbacks: Callbacks = None
) -> Dict[str, str]:
"""Return baz variable."""
return {"baz": "foo"}